torchgmm.bayes.GaussianMixture#

class torchgmm.bayes.GaussianMixture(num_components=1, *, covariance_type='diag', init_strategy='kmeans', init_means=None, convergence_tolerance=0.001, covariance_regularization=1e-06, batch_size=None, trainer_params=None)#

Probabilistic model assuming that data is generated from a mixture of Gaussians.

The mixture is assumed to be composed of a fixed number of components with individual means and covariances. More information on Gaussian mixture models (GMMs) is available on Wikipedia.

Attributes table#

persistent_attributes

Returns the list of fitted attributes that ought to be saved and loaded.

model_

The fitted PyTorch module with all estimated parameters.

converged_

A boolean indicating whether the model converged during training.

num_iter_

The number of iterations the model was fitted for, excluding initialization.

nll_

The average per-datapoint negative log-likelihood at the last training step.

Methods table#

clone()

Clones the estimator without copying any fitted attributes.

fit(data)

Fits the Gaussian mixture on the provided data, estimating component priors, means and covariances.

fit_predict(data)

Fits the estimator using the provided data and subsequently predicts the labels for the data using the fitted estimator.

get_params([deep])

Returns the estimator's parameters as passed to the initializer.

load(path)

Loads the estimator and (if available) the fitted model.

load_attributes(path)

Loads the fitted attributes that are stored at the provided path.

load_parameters(path)

Initializes this estimator by loading its parameters.

predict(data)

Computes the most likely components for each of the provided datapoints.

predict_proba(data)

Computes a distribution over the components for each of the provided datapoints.

sample(num_datapoints)

Samples datapoints from the fitted Gaussian mixture.

save(path)

Saves the estimator to the provided directory.

save_attributes(path)

Saves the fitted attributes of this estimator.

save_parameters(path)

Saves the parameters of this estimator.

score(data)

Computes the average negative log-likelihood (NLL) of the provided datapoints.

score_samples(data)

Computes the negative log-likelihood (NLL) of each of the provided datapoints.

set_params(values)

Sets the provided values on the estimator.

trainer(**kwargs)

Returns the trainer as configured by the estimator.

Attributes#

GaussianMixture.persistent_attributes#

Returns the list of fitted attributes that ought to be saved and loaded.

By default, this encompasses all annotations.

GaussianMixture.model_: GaussianMixtureModel#

The fitted PyTorch module with all estimated parameters.

GaussianMixture.converged_: bool#

A boolean indicating whether the model converged during training.

GaussianMixture.num_iter_: int#

The number of iterations the model was fitted for, excluding initialization.

GaussianMixture.nll_: float#

The average per-datapoint negative log-likelihood at the last training step.

Methods#

GaussianMixture.clone()#

Clones the estimator without copying any fitted attributes. All parameters of this estimator are copied via copy.deepcopy().

Return type:

TypeVar(E, bound= BaseEstimator)

Returns:

The cloned estimator with the same parameters.
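
The cloning behavior described above can be sketched with a small stand-in class. `TinyEstimator` is hypothetical and only mirrors the documented contract (parameters are deep-copied, fitted attributes are not carried over); it is not the library's implementation.

```python
import copy


class TinyEstimator:
    """Hypothetical stand-in for an estimator with init parameters."""

    def __init__(self, num_components=1, trainer_params=None):
        self.num_components = num_components
        self.trainer_params = trainer_params

    def get_params(self):
        # Mapping from init parameter names to their current values.
        return {
            "num_components": self.num_components,
            "trainer_params": self.trainer_params,
        }

    def clone(self):
        # Deep-copy the init parameters and build a fresh, unfitted instance.
        return type(self)(**copy.deepcopy(self.get_params()))


original = TinyEstimator(num_components=3, trainer_params={"max_epochs": 10})
original.converged_ = True  # pretend the estimator has been fitted

duplicate = original.clone()
```

In this sketch, `duplicate` shares no mutable state with `original`, and the fitted attribute `converged_` is absent on the clone.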

GaussianMixture.fit(data)#

Fits the Gaussian mixture on the provided data, estimating component priors, means and covariances. Parameters are estimated using the EM algorithm.

Args:

data: The tabular data to fit on. The dimensionality of the Gaussian mixture is automatically inferred from this data.

Return type:

GaussianMixture

Returns:

The fitted Gaussian mixture.
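
To make the EM procedure concrete, the following is a minimal pure-Python sketch for one-dimensional data. It is an illustration under stated assumptions, not the library's code: the helper names (`gaussian_pdf`, `em_step`, `fit_1d`), the toy data, and the stopping rule (stop once the average NLL changes by less than a tolerance) are all assumptions made for this example.

```python
import math


def gaussian_pdf(x, mean, var):
    """Density of a 1-D Gaussian with the given mean and variance."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def em_step(data, priors, means, variances, reg=1e-6):
    """Run one EM iteration; also return the average NLL before the update."""
    num_components = len(priors)
    responsibilities, nll = [], 0.0
    for x in data:
        # E-step: posterior probability of each component for this datapoint.
        joint = [
            priors[j] * gaussian_pdf(x, means[j], variances[j])
            for j in range(num_components)
        ]
        total = sum(joint)
        responsibilities.append([value / total for value in joint])
        nll -= math.log(total)
    new_priors, new_means, new_variances = [], [], []
    for j in range(num_components):
        # M-step: re-estimate each component from its soft assignments.
        weight = sum(row[j] for row in responsibilities)
        mean = sum(row[j] * x for row, x in zip(responsibilities, data)) / weight
        var = sum(
            row[j] * (x - mean) ** 2 for row, x in zip(responsibilities, data)
        ) / weight
        new_priors.append(weight / len(data))
        new_means.append(mean)
        new_variances.append(var + reg)  # small constant keeps variances positive
    return new_priors, new_means, new_variances, nll / len(data)


def fit_1d(data, priors, means, variances, tol=1e-3, max_iter=100):
    """Iterate EM until the average NLL changes by less than tol (assumed rule)."""
    previous = math.inf
    for num_iter in range(1, max_iter + 1):
        priors, means, variances, nll = em_step(data, priors, means, variances)
        if abs(previous - nll) < tol:
            break
        previous = nll
    return priors, means, variances, num_iter


data = [-2.1, -1.9, -2.0, -1.8, -2.2, 1.9, 2.1, 2.0, 1.8, 2.2]
priors, means, variances, num_iter = fit_1d(
    data, priors=[0.5, 0.5], means=[-1.0, 1.0], variances=[1.0, 1.0]
)
```

On this toy data the two estimated means settle close to the two cluster centers. The `reg` term plays a role similar to the `covariance_regularization` parameter of the class, although the exact way the library applies it is not shown here.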

GaussianMixture.fit_predict(data)#

Fits the estimator using the provided data and subsequently predicts the labels for the data using the fitted estimator. It simply chains calls to fit() and predict().

Args:

data: The data to use for fitting and to predict labels for. The data must have the same type as for the fit() method.

Return type:

TypeVar(R_co, covariant=True)

Returns:

The predicted labels. Consult the predict() documentation for more information on the return type.

GaussianMixture.get_params(deep=True)#

Returns the estimator’s parameters as passed to the initializer.

Args:

deep: Ignored. For Scikit-learn compatibility.

Return type:

dict[str, Any]

Returns:

The mapping from init parameters to values.

classmethod GaussianMixture.load(path)#

Loads the estimator and (if available) the fitted model. This method is only expected to work for estimators that have previously been saved via save().

Args:

path: The directory from which to load the estimator.

Return type:

TypeVar(E, bound= BaseEstimator)

Returns:

The loaded estimator, either fitted or not.

GaussianMixture.load_attributes(path)#

Loads the fitted attributes that are stored at the provided path. If subclasses overwrite save_attributes(), this method should also be overwritten.

Typically, this method should not be called directly. It is called as part of load().

Return type:

None

Args:

path: The directory from which the parameters should be loaded.

Raises:

FileNotFoundError – If no fitted attributes have been stored.

classmethod GaussianMixture.load_parameters(path)#

Initializes this estimator by loading its parameters. If subclasses overwrite save_parameters(), this method should also be overwritten.

Typically, this method should not be called directly. It is called as part of load().

Return type:

TypeVar(E, bound= BaseEstimator)

Args:

path: The directory from which the parameters should be loaded.

GaussianMixture.predict(data)#

Computes the most likely components for each of the provided datapoints.

Args:

data: The datapoints for which to obtain the most likely components.

Return type:

Tensor

Returns:

A tensor of shape [num_datapoints] with the indices of the most likely components.

Note:

Use predict_proba() to obtain probabilities for each component instead of the most likely component only.

Attention:

When calling this function in a multi-process environment, each process receives only a subset of the predictions. If you want to aggregate predictions, make sure to gather the values returned from this method.

GaussianMixture.predict_proba(data)#

Computes a distribution over the components for each of the provided datapoints.

Args:

data: The datapoints for which to compute the component assignment probabilities.

Return type:

Tensor

Returns:

A tensor of shape [num_datapoints, num_components] with the assignment probabilities for each component and datapoint. Note that each row of the tensor sums to 1, i.e. the returned tensor provides a proper distribution over the components for each datapoint.

Attention:

When calling this function in a multi-process environment, each process receives only a subset of the predictions. If you want to aggregate predictions, make sure to gather the values returned from this method.
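
The relationship between predict_proba() and predict() can be illustrated with a pure-Python sketch for a one-dimensional mixture. The functions `log_gaussian` and `responsibilities` and the parameter values are assumptions made for this example; the sketch applies Bayes' rule in log space and is not the library's code.

```python
import math


def log_gaussian(x, mean, var):
    """Log-density of a 1-D Gaussian."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def responsibilities(x, priors, means, variances):
    """Posterior probability of each component given the datapoint x."""
    log_joint = [
        math.log(p) + log_gaussian(x, m, v)
        for p, m, v in zip(priors, means, variances)
    ]
    # Log-sum-exp normalization for numerical stability.
    peak = max(log_joint)
    log_norm = peak + math.log(sum(math.exp(value - peak) for value in log_joint))
    return [math.exp(value - log_norm) for value in log_joint]


priors, means, variances = [0.3, 0.7], [-1.0, 2.0], [1.0, 0.5]
data = [-1.5, 0.0, 2.5]

# One row per datapoint, one column per component (like predict_proba).
proba = [responsibilities(x, priors, means, variances) for x in data]

# Index of the most likely component per datapoint (like predict).
labels = [max(range(len(row)), key=row.__getitem__) for row in proba]
```

Each row of `proba` sums to 1, and taking the argmax of a row yields the hard assignment.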

GaussianMixture.sample(num_datapoints)#

Samples datapoints from the fitted Gaussian mixture.

Args:

num_datapoints: The number of datapoints to sample.

Return type:

Tensor

Returns:

A tensor of shape [num_datapoints, dim] providing the samples.

Note:

This method does not parallelize across multiple processes, i.e. performs no synchronization.
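
Sampling from a Gaussian mixture is commonly done in two stages: draw a component according to the priors, then draw a point from that component's Gaussian. The sketch below shows this for one dimension with the standard library; `sample_mixture` and its `seed` argument are hypothetical and not part of the documented API.

```python
import random


def sample_mixture(num_datapoints, priors, means, stds, seed=0):
    """Draw 1-D samples by first picking a component, then sampling from it."""
    rng = random.Random(seed)
    # Stage 1: choose a component index for every sample, weighted by the priors.
    components = rng.choices(range(len(priors)), weights=priors, k=num_datapoints)
    # Stage 2: sample from the Gaussian of the chosen component.
    return [rng.gauss(means[c], stds[c]) for c in components]


samples = sample_mixture(5, priors=[0.3, 0.7], means=[-1.0, 2.0], stds=[1.0, 0.5])
```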

GaussianMixture.save(path)#

Saves the estimator to the provided directory. It saves a file named estimator.pickle for the configuration of the estimator and additional files for the fitted model (if applicable). For more information on the files saved for the fitted model or for more customization, look at get_params() and torchgmm.base.nn.Configurable.save().

Return type:

None

Args:

path: The directory to which all files should be saved.

Note:

This method may be called regardless of whether the estimator has already been fitted.

Attention:

If the dictionary returned by get_params() is not JSON-serializable, this method uses pickle which is not necessarily backwards-compatible.

GaussianMixture.save_attributes(path)#

Saves the fitted attributes of this estimator. By default, it uses JSON and falls back to pickle. Subclasses should overwrite this method if non-primitive attributes are fitted.

Typically, this method should not be called directly. It is called as part of save().

Return type:

None

Args:

path: The directory to which the fitted attributes should be saved.

Raises:

NotFittedError – If the estimator has not been fitted.

GaussianMixture.save_parameters(path)#

Saves the parameters of this estimator. By default, it uses JSON and falls back to pickle. If subclasses use non-primitive types as parameters, they should overwrite this method.

Typically, this method should not be called directly. It is called as part of save().

Return type:

None

Args:

path: The directory to which the parameters should be saved.
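
A JSON-first strategy with a pickle fallback, as described above, might look roughly like the following. This is a sketch under assumptions: the helper names `save_params` and `load_params` and the file names `params.json` and `params.pickle` are hypothetical and may differ from what the library writes.

```python
import json
import pickle
import tempfile
from pathlib import Path


def save_params(path, params):
    """Write params as JSON if possible, otherwise as pickle; return the format."""
    directory = Path(path)
    directory.mkdir(parents=True, exist_ok=True)
    try:
        text = json.dumps(params)
    except TypeError:
        # Non-primitive values (e.g. a set) are not JSON-serializable.
        with (directory / "params.pickle").open("wb") as handle:
            pickle.dump(params, handle)
        return "pickle"
    (directory / "params.json").write_text(text)
    return "json"


def load_params(path):
    """Read params from whichever file save_params produced."""
    directory = Path(path)
    json_file = directory / "params.json"
    if json_file.exists():
        return json.loads(json_file.read_text())
    with (directory / "params.pickle").open("rb") as handle:
        return pickle.load(handle)


with tempfile.TemporaryDirectory() as tmp:
    format_plain = save_params(Path(tmp) / "plain", {"num_components": 3})
    restored_plain = load_params(Path(tmp) / "plain")
    format_rich = save_params(Path(tmp) / "rich", {"tags": {"a", "b"}})
    restored_rich = load_params(Path(tmp) / "rich")
```

The JSON path is readable and stable across versions, whereas pickled files should only be loaded from trusted sources and may not remain compatible across versions.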

GaussianMixture.score(data)#

Computes the average negative log-likelihood (NLL) of the provided datapoints.

Args:

data: The datapoints for which to evaluate the NLL.

Return type:

float

Returns:

The average NLL of all datapoints.

Note:

See score_samples() to obtain NLL values for individual datapoints.

GaussianMixture.score_samples(data)#

Computes the negative log-likelihood (NLL) of each of the provided datapoints.

Args:

data: The datapoints for which to compute the NLL.

Return type:

Tensor

Returns:

A tensor of shape [num_datapoints] with the NLL for each datapoint.

Attention:

When calling this function in a multi-process environment, each process receives only a subset of the predictions. If you want to aggregate predictions, make sure to gather the values returned from this method.
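
The per-datapoint NLL and its average, as returned by score_samples() and score(), can be illustrated for a one-dimensional mixture. The helper names and parameter values below are assumptions made for this example; the sketch only demonstrates the quantities involved.

```python
import math


def log_gaussian(x, mean, var):
    """Log-density of a 1-D Gaussian."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def nll_per_datapoint(data, priors, means, variances):
    """Negative log-likelihood of each datapoint under the mixture."""
    result = []
    for x in data:
        log_joint = [
            math.log(p) + log_gaussian(x, m, v)
            for p, m, v in zip(priors, means, variances)
        ]
        # Log-sum-exp over the components gives the log-likelihood of x.
        peak = max(log_joint)
        log_likelihood = peak + math.log(
            sum(math.exp(value - peak) for value in log_joint)
        )
        result.append(-log_likelihood)
    return result


def average_nll(data, priors, means, variances):
    """Mean of the per-datapoint NLL values."""
    values = nll_per_datapoint(data, priors, means, variances)
    return sum(values) / len(values)


priors, means, variances = [0.3, 0.7], [-1.0, 2.0], [1.0, 0.5]
data = [-1.5, 0.0, 2.5]
per_point = nll_per_datapoint(data, priors, means, variances)
mean_nll = average_nll(data, priors, means, variances)
```

Lower values indicate a better fit, since the quantity is a negated log-likelihood.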

GaussianMixture.set_params(values)#

Sets the provided values on the estimator. The estimator is returned as well, but the estimator on which this function is called is also modified.

Args:

values: The values to set.

Return type:

TypeVar(E, bound= BaseEstimator)

Returns:

The estimator where the values have been set.
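
The in-place semantics described above can be shown with a small stand-in class. `TinyEstimator` is hypothetical, and rejecting unknown keys with a `ValueError` is an assumption of this sketch rather than documented behavior.

```python
class TinyEstimator:
    """Hypothetical stand-in for an estimator with two init parameters."""

    def __init__(self, num_components=1, covariance_type="diag"):
        self.num_components = num_components
        self.covariance_type = covariance_type

    def get_params(self, deep=True):
        # `deep` is accepted for Scikit-learn compatibility and ignored.
        return {
            "num_components": self.num_components,
            "covariance_type": self.covariance_type,
        }

    def set_params(self, values):
        known = self.get_params()
        for key, value in values.items():
            if key not in known:
                # Assumption of this sketch: unknown parameters are rejected.
                raise ValueError(f"unknown parameter: {key}")
            setattr(self, key, value)
        return self  # the same object, now modified


estimator = TinyEstimator()
returned = estimator.set_params({"num_components": 4})
```

Because the same object is returned, calls can be chained, but the original estimator is modified as well.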

GaussianMixture.trainer(**kwargs)#

Returns the trainer as configured by the estimator. Typically, this method is only called by functions in the estimator.

Args:

kwargs: Additional arguments that override the trainer arguments registered in the initializer of the estimator.

Return type:

Trainer

Returns:

A fully initialized PyTorch Lightning trainer.

Note:

This function should be preferred over initializing the trainer directly. It ensures that the returned trainer correctly deals with TorchGMM components that may be introduced in the future.
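
The override behavior of kwargs can be pictured as a dictionary merge in which call-time arguments take precedence over those registered in the initializer. The helper `merge_trainer_kwargs` is hypothetical and only illustrates that precedence; it does not construct a trainer.

```python
def merge_trainer_kwargs(registered, **overrides):
    """Combine registered trainer arguments with call-time overrides."""
    merged = dict(registered or {})  # copy so the registered values stay intact
    merged.update(overrides)  # call-time values win on conflicts
    return merged


registered = {"max_epochs": 100, "accelerator": "cpu"}
merged = merge_trainer_kwargs(registered, max_epochs=5)
```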