torchgmm.bayes.GaussianMixture#
- class torchgmm.bayes.GaussianMixture(num_components=1, *, covariance_type='diag', init_strategy='kmeans', init_means=None, convergence_tolerance=0.001, covariance_regularization=1e-06, batch_size=None, trainer_params=None)#
Probabilistic model assuming that data is generated from a mixture of Gaussians.
The mixture is assumed to be composed of a fixed number of components with individual means and covariances. More information on Gaussian mixture models (GMMs) is available on Wikipedia.
Attributes table#
| Attribute | Description |
| --- | --- |
| persistent_attributes | Returns the list of fitted attributes that ought to be saved and loaded. |
| model_ | The fitted PyTorch module with all estimated parameters. |
| converged_ | A boolean indicating whether the model converged during training. |
| num_iter_ | The number of iterations the model was fitted for, excluding initialization. |
| nll_ | The average per-datapoint negative log-likelihood at the last training step. |
Methods table#
| Method | Description |
| --- | --- |
| clone() | Clones the estimator without copying any fitted attributes. |
| fit(data) | Fits the Gaussian mixture on the provided data, estimating component priors, means and covariances. |
| fit_predict(data) | Fits the estimator using the provided data and subsequently predicts the labels for the data using the fitted estimator. |
| get_params(deep=True) | Returns the estimator's parameters as passed to the initializer. |
| load(path) | Loads the estimator and (if available) the fitted model. |
| load_attributes(path) | Loads the fitted attributes that are stored at the fitted path. |
| load_parameters(path) | Initializes this estimator by loading its parameters. |
| predict(data) | Computes the most likely components for each of the provided datapoints. |
| predict_proba(data) | Computes a distribution over the components for each of the provided datapoints. |
| sample(num_datapoints) | Samples datapoints from the fitted Gaussian mixture. |
| save(path) | Saves the estimator to the provided directory. |
| save_attributes(path) | Saves the fitted attributes of this estimator. |
| save_parameters(path) | Saves the parameters of this estimator. |
| score(data) | Computes the average negative log-likelihood (NLL) of the provided datapoints. |
| score_samples(data) | Computes the negative log-likelihood (NLL) of each of the provided datapoints. |
| set_params(values) | Sets the provided values on the estimator. |
| trainer(**kwargs) | Returns the trainer as configured by the estimator. |
Attributes#
- GaussianMixture.persistent_attributes#
Returns the list of fitted attributes that ought to be saved and loaded.
By default, this encompasses all annotations.
- GaussianMixture.model_: GaussianMixtureModel#
The fitted PyTorch module with all estimated parameters.
Methods#
- GaussianMixture.clone()#
Clones the estimator without copying any fitted attributes. All parameters of this estimator are copied via copy.deepcopy().
- Return type:
TypeVar(E, bound= BaseEstimator)
- Returns:
The cloned estimator with the same parameters.
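The cloning behaviour described above can be illustrated with a toy class. This is only a sketch: `ToyEstimator` is a hypothetical stand-in, not torchgmm's `BaseEstimator`, and it assumes nothing beyond what the text states, namely that parameters are deep-copied and fitted attributes are left behind.

```python
import copy


class ToyEstimator:
    """Hypothetical stand-in for an estimator; not part of torchgmm."""

    def __init__(self, num_components=1, trainer_params=None):
        self.num_components = num_components
        self.trainer_params = trainer_params

    def get_params(self):
        # Parameters exactly as passed to the initializer.
        return {
            "num_components": self.num_components,
            "trainer_params": self.trainer_params,
        }

    def clone(self):
        # Deep-copy the parameters so the clone shares no mutable state,
        # then build a fresh, unfitted instance from them.
        return type(self)(**copy.deepcopy(self.get_params()))


original = ToyEstimator(num_components=3, trainer_params={"max_epochs": 10})
original.converged_ = True  # pretend the estimator has been fitted

cloned = original.clone()
cloned.trainer_params["max_epochs"] = 99  # does not affect the original
```

Because the parameters are deep-copied, mutating the clone's `trainer_params` leaves the original untouched, and the clone carries no fitted attribute such as `converged_`.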
- GaussianMixture.fit(data)#
Fits the Gaussian mixture on the provided data, estimating component priors, means and covariances. Parameters are estimated using the EM algorithm.
- Args:
data: The tabular data to fit on. The dimensionality of the Gaussian mixture is automatically inferred from this data.
- Return type:
- Returns:
The fitted Gaussian mixture.
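To make the EM estimation mentioned above concrete, here is a minimal pure-Python sketch for a one-dimensional mixture. It is not torchgmm's implementation: the helper names `normal_pdf` and `em_step` are invented for illustration, and the `reg` constant only loosely mirrors the role of a covariance regularization term.

```python
import math
import random


def normal_pdf(x, mean, var):
    # Density of a univariate Gaussian with the given mean and variance.
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)


def em_step(data, priors, means, variances, reg=1e-6):
    # E-step: responsibility of each component for each datapoint.
    resp = []
    for x in data:
        weighted = [p * normal_pdf(x, m, v) for p, m, v in zip(priors, means, variances)]
        total = sum(weighted)
        resp.append([w / total for w in weighted])

    # M-step: re-estimate priors, means and variances from the responsibilities.
    new_priors, new_means, new_vars = [], [], []
    for k in range(len(priors)):
        nk = sum(r[k] for r in resp)
        mean = sum(r[k] * x for r, x in zip(resp, data)) / nk
        var = sum(r[k] * (x - mean) ** 2 for r, x in zip(resp, data)) / nk
        new_priors.append(nk / len(data))
        new_means.append(mean)
        new_vars.append(var + reg)  # small constant keeps variances positive
    return new_priors, new_means, new_vars


random.seed(0)
data = [random.gauss(-4.0, 1.0) for _ in range(200)]
data += [random.gauss(4.0, 1.0) for _ in range(200)]

# Deliberately rough starting values; EM refines them iteratively.
priors, means, variances = [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0]
for _ in range(50):
    priors, means, variances = em_step(data, priors, means, variances)
```

After a few dozen iterations the estimated means settle near the two cluster centres used to generate the data, and the priors remain a valid probability distribution.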
- GaussianMixture.fit_predict(data)#
Fits the estimator using the provided data and subsequently predicts the labels for the data using the fitted estimator. It simply chains calls to fit() and predict().
- Args:
data: The data to use for fitting and to predict labels for. The data must have the same type as for the fit() method.
- GaussianMixture.get_params(deep=True)#
Returns the estimator’s parameters as passed to the initializer.
- Args:
deep: Ignored. For Scikit-learn compatibility.
- classmethod GaussianMixture.load(path)#
Loads the estimator and (if available) the fitted model. This method is only expected to work for an estimator that was previously saved via save().
- Args:
path: The directory from which to load the estimator.
- Return type:
TypeVar(E, bound= BaseEstimator)
- Returns:
The loaded estimator, either fitted or not.
- GaussianMixture.load_attributes(path)#
Loads the fitted attributes that are stored at the fitted path. If subclasses overwrite save_attributes(), this method should also be overwritten.
Typically, this method should not be called directly. It is called as part of load().
- Return type:
- Args:
path: The directory from which the parameters should be loaded.
- Raises:
FileNotFoundError – If no fitted attributes have been stored.
- classmethod GaussianMixture.load_parameters(path)#
Initializes this estimator by loading its parameters. If subclasses overwrite save_parameters(), this method should also be overwritten.
Typically, this method should not be called directly. It is called as part of load().
- Return type:
TypeVar(E, bound= BaseEstimator)
- Args:
path: The directory from which the parameters should be loaded.
- GaussianMixture.predict(data)#
Computes the most likely components for each of the provided datapoints.
- Args:
data: The datapoints for which to obtain the most likely components.
- Return type:
Tensor
- Returns:
A tensor of shape [num_datapoints] with the indices of the most likely components.
- Note:
Use predict_proba() to obtain probabilities for each component instead of the most likely component only.
- Attention:
When calling this function in a multi-process environment, each process receives only a subset of the predictions. If you want to aggregate predictions, make sure to gather the values returned from this method.
- GaussianMixture.predict_proba(data)#
Computes a distribution over the components for each of the provided datapoints.
- Args:
data: The datapoints for which to compute the component assignment probabilities.
- Return type:
Tensor
- Returns:
A tensor of shape [num_datapoints, num_components] with the assignment probabilities for each component and datapoint. Each row of the tensor sums to 1, i.e. the returned tensor provides a proper distribution over the components for each datapoint.
- Attention:
When calling this function in a multi-process environment, each process receives only a subset of the predictions. If you want to aggregate predictions, make sure to gather the values returned from this method.
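The relationship between the component distribution and the most likely component can be sketched in plain Python for a one-dimensional mixture. The function `component_posteriors` and the parameter values below are illustrative assumptions, not torchgmm code; the sketch applies Bayes' rule to the component priors and Gaussian densities, normalized in log space.

```python
import math


def component_posteriors(x, priors, means, variances):
    # Log joint probability log p(k) + log N(x | mean_k, var_k) per component.
    log_joint = [
        math.log(p) - 0.5 * (math.log(2 * math.pi * v) + (x - m) ** 2 / v)
        for p, m, v in zip(priors, means, variances)
    ]
    # Normalize with the log-sum-exp trick for numerical stability.
    max_log = max(log_joint)
    log_norm = max_log + math.log(sum(math.exp(l - max_log) for l in log_joint))
    return [math.exp(l - log_norm) for l in log_joint]


priors, means, variances = [0.3, 0.7], [-2.0, 3.0], [1.0, 2.0]
datapoints = [-2.5, 0.4, 3.1]

# One row per datapoint, one column per component.
proba = [component_posteriors(x, priors, means, variances) for x in datapoints]
# The most likely component is the argmax of each row.
labels = [row.index(max(row)) for row in proba]
```

Each row of `proba` sums to 1, and taking the argmax of a row yields the kind of hard assignment that a most-likely-component prediction returns.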
- GaussianMixture.sample(num_datapoints)#
Samples datapoints from the fitted Gaussian mixture.
- Args:
num_datapoints: The number of datapoints to sample.
- Return type:
Tensor
- Returns:
A tensor of shape [num_datapoints, dim] providing the samples.
- Note:
This method does not parallelize across multiple processes, i.e. performs no synchronization.
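Sampling from a Gaussian mixture is commonly done in two stages: draw a component from the priors, then draw a point from that component. The sketch below shows this for diagonal covariances in pure Python. The function `sample_mixture` and its arguments are hypothetical, and the sketch makes no claim about how torchgmm implements sampling internally.

```python
import random


def sample_mixture(num_datapoints, priors, means, stds, rng):
    samples = []
    for _ in range(num_datapoints):
        # Pick a component according to the prior probabilities ...
        k = rng.choices(range(len(priors)), weights=priors)[0]
        # ... then draw each dimension from that component's diagonal Gaussian.
        samples.append([rng.gauss(m, s) for m, s in zip(means[k], stds[k])])
    return samples


rng = random.Random(0)
priors = [0.5, 0.5]
means = [[-5.0, 0.0], [5.0, 0.0]]
stds = [[1.0, 1.0], [1.0, 1.0]]

# Result has one row per datapoint and one column per dimension.
samples = sample_mixture(1000, priors, means, stds, rng)
```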
- GaussianMixture.save(path)#
Saves the estimator to the provided directory. It saves a file named estimator.pickle for the configuration of the estimator and additional files for the fitted model (if applicable). For more information on the files saved for the fitted model or for more customization, look at get_params() and torchgmm.base.nn.Configurable.save().
- Return type:
- Args:
path: The directory to which all files should be saved.
- Note:
This method may be called regardless of whether the estimator has already been fitted.
- Attention:
If the dictionary returned by get_params() is not JSON-serializable, this method uses pickle, which is not necessarily backwards-compatible.
- GaussianMixture.save_attributes(path)#
Saves the fitted attributes of this estimator. By default, it uses JSON and falls back to pickle. Subclasses should overwrite this method if non-primitive attributes are fitted.
Typically, this method should not be called directly. It is called as part of save().
- Return type:
- Args:
path: The directory to which the fitted attributes should be saved.
- Raises:
NotFittedError – If the estimator has not been fitted.
- GaussianMixture.save_parameters(path)#
Saves the parameters of this estimator. By default, it uses JSON and falls back to pickle. If subclasses use non-primitive types as parameters, they should overwrite this method.
Typically, this method should not be called directly. It is called as part of save().
- Return type:
- Args:
path: The directory to which the parameters should be saved.
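The "JSON first, pickle as fallback" strategy can be sketched with the standard library alone. The helper functions and the file names `params.json` and `params.pickle` are assumptions made for this illustration; they are not the names torchgmm uses.

```python
import json
import pickle
import tempfile
from pathlib import Path


def save_params(path, params):
    # File names are hypothetical, chosen for this sketch only.
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    try:
        encoded = json.dumps(params)
    except TypeError:
        # Non-primitive values cannot be JSON-encoded, so fall back to pickle.
        with (path / "params.pickle").open("wb") as f:
            pickle.dump(params, f)
    else:
        (path / "params.json").write_text(encoded)


def load_params(path):
    path = Path(path)
    json_file = path / "params.json"
    if json_file.exists():
        return json.loads(json_file.read_text())
    # Raises FileNotFoundError if nothing has been saved at this path.
    with (path / "params.pickle").open("rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    save_params(Path(tmp) / "primitive", {"num_components": 3})
    # A set is not JSON-serializable, so this call takes the pickle branch.
    save_params(Path(tmp) / "complex", {"init_means": {1.0, 2.0}})
    primitive = load_params(Path(tmp) / "primitive")
    non_primitive = load_params(Path(tmp) / "complex")
```

JSON keeps the saved parameters human-readable and stable across versions, while pickle handles arbitrary Python objects at the cost of the backwards-compatibility caveat noted for save().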
- GaussianMixture.score(data)#
Computes the average negative log-likelihood (NLL) of the provided datapoints.
- Args:
data: The datapoints for which to evaluate the NLL.
- Return type:
- Returns:
The average NLL of all datapoints.
- Note:
See score_samples() to obtain NLL values for individual datapoints.
- GaussianMixture.score_samples(data)#
Computes the negative log-likelihood (NLL) of each of the provided datapoints.
- Args:
data: The datapoints for which to compute the NLL.
- Return type:
Tensor
- Returns:
A tensor of shape [num_datapoints] with the NLL for each datapoint.
- Attention:
When calling this function in a multi-process environment, each process receives only a subset of the predictions. If you want to aggregate predictions, make sure to gather the values returned from this method.
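As a sketch of what the per-datapoint NLL means, the following pure-Python example evaluates the negative log of the mixture density for one-dimensional data and then averages it, which corresponds to the quantity that score() describes. The function `nll_per_datapoint` is a hypothetical helper, not torchgmm code.

```python
import math


def nll_per_datapoint(data, priors, means, variances):
    nlls = []
    for x in data:
        # Log joint probability of the datapoint and each component.
        log_joint = [
            math.log(p) - 0.5 * (math.log(2 * math.pi * v) + (x - m) ** 2 / v)
            for p, m, v in zip(priors, means, variances)
        ]
        # Marginalize over components in log space (log-sum-exp).
        max_log = max(log_joint)
        log_likelihood = max_log + math.log(sum(math.exp(l - max_log) for l in log_joint))
        nlls.append(-log_likelihood)
    return nlls


data = [0.0, 1.0, -1.0]
# A single standard-normal component makes the result easy to check by hand.
nlls = nll_per_datapoint(data, priors=[1.0], means=[0.0], variances=[1.0])
average_nll = sum(nlls) / len(nlls)
```

With a single standard-normal component, the NLL at 0 equals half the logarithm of 2π, and each unit of distance from the mean adds 0.5 to it.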
- GaussianMixture.set_params(values)#
Sets the provided values on the estimator. The estimator is returned as well, but the estimator on which this function is called is also modified.
- Args:
values: The values to set.
- Return type:
TypeVar(E, bound= BaseEstimator)
- Returns:
The estimator where the values have been set.
- GaussianMixture.trainer(**kwargs)#
Returns the trainer as configured by the estimator. Typically, this method is only called by functions in the estimator.
- Args:
kwargs: Additional arguments that override the trainer arguments registered in the initializer of the estimator.
- Return type:
Trainer
- Returns:
A fully initialized PyTorch Lightning trainer.
- Note:
This function should be preferred over initializing the trainer directly. It ensures that the returned trainer correctly deals with TorchGMM components that may be introduced in the future.