Skip to content

Commit

Permalink
Change the default weight_params in WeightedRips (#595)
Browse files Browse the repository at this point in the history
* Change the default weight_params in WeightedRips, do not allow None
  • Loading branch information
wreise authored Jul 8, 2021
1 parent 1d757d5 commit b429a66
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions gtda/homology/simplicial.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ class WeightedRipsPersistence(BaseEstimator, TransformerMixin, PlotterMixin):
is a parameter (see `weight_params`). If a callable, it must return
non-negative 1D arrays.
weight_params : dict, optional, default: ``None``
weight_params : dict, optional, default: ``{}``
Additional parameters for the weighted filtration. ``"p"`` determines
the power to be used in computing edge weights from vertex weights. It
can be one of ``1``, ``2`` or ``np.inf`` and defaults to ``1``. If
Expand Down Expand Up @@ -525,7 +525,7 @@ class WeightedRipsPersistence(BaseEstimator, TransformerMixin, PlotterMixin):
"of": {"type": int, "in": Interval(0, np.inf, closed="left")}
},
"weights": {"type": (str, FunctionType)},
"weight_params": {"type": (dict, type(None))},
"weight_params": {"type": dict},
"collapse_edges": {"type": bool},
"coeff": {"type": int, "in": Interval(2, np.inf, closed="left")},
"max_edge_weight": {"type": Real},
Expand All @@ -534,7 +534,7 @@ class WeightedRipsPersistence(BaseEstimator, TransformerMixin, PlotterMixin):
}

def __init__(self, metric="euclidean", metric_params={},
homology_dimensions=(0, 1), weights="DTM", weight_params=None,
homology_dimensions=(0, 1), weights="DTM", weight_params={},
collapse_edges=False, coeff=2, max_edge_weight=np.inf,
infinity_values=None, reduced_homology=True, n_jobs=None):
self.metric = metric
Expand Down Expand Up @@ -616,7 +616,7 @@ def fit(self, X, y=None):
self.effective_weight_params_.update({"n_neighbors": 3, "r": 2})
else:
key = "general"
if self.weight_params is not None:
if self.weight_params:
self.effective_weight_params_.update(self.weight_params)
validate_params(self.effective_weight_params_,
_AVAILABLE_RIPS_WEIGHTS[key])
Expand Down

0 comments on commit b429a66

Please sign in to comment.