Module providing Consciousness Exploration Tools for PyTorch.

Note

This page is reference documentation. It explains only the class signature, not how to use it. Please refer to the gallery for the big picture.

class consciousnet.models.gmvae.GMVAE(input_dim, latent_dim, n_mix_components, dense_hidden_dims=None, sigma_min=0.001, raw_sigma_bias=0.25, dropout=0, temperature=1, gen_bias_init=0.0, prior_gmm=None, decoder=None, encoder_y=None, encoder_gmm=None, random_seed=None)[source]

The Gaussian Mixture VAE architecture.

Meta-GMVAE: Mixture of Gaussian VAE for Unsupervised Meta-Learning, Dong Bok Lee et al., ICLR 2021.

Gaussian Mixture VAE: Lessons in Variational Inference, Generative Models, and Deep Nets: http://ruishu.io/2016/12/25/gmvae

Deep Unsupervised Clustering with Gaussian Mixture Variational Autoencoders, Nat Dilokthanakul et al., arXiv 2017.

__init__(input_dim, latent_dim, n_mix_components, dense_hidden_dims=None, sigma_min=0.001, raw_sigma_bias=0.25, dropout=0, temperature=1, gen_bias_init=0.0, prior_gmm=None, decoder=None, encoder_y=None, encoder_gmm=None, random_seed=None)[source]

Initialize the class.

Parameters

input_dim : int

the input size.

latent_dim : int

the size of the stochastic latent state of the GMVAE.

n_mix_components : int

the number of mixture components.

dense_hidden_dims : list of int, default None

the sizes of the hidden layers of the fully connected network used to condition the distribution on the inputs. If None, a single-layer dense network is used.

sigma_min : float, default 0.001

the minimum value that the standard deviation of the distribution over the latent state can take.

raw_sigma_bias : float, default 0.25

a scalar that is added to the raw standard deviation output from the neural networks that parameterize the prior and approximate posterior. Useful for preventing standard deviations close to zero.

dropout : float, default 0

the dropout rate.

temperature : float, default 1

controls how closely the relaxed distribution approximates a discrete one. The closer the temperature is to 0, the more discrete the samples; the closer it is to infinity, the more uniform.

gen_bias_init : float, default 0

a bias added to the raw output of the fully connected network that parameterizes the generative distribution. Useful for initialising the mean to a sensible starting point, e.g. the mean of the training set.

prior_gmm : @callable, default None

a callable that implements the prior distribution p(z | y). Must accept the discrete variable y as argument and return a multivariate normal distribution with diagonal covariance (the prior_gmm method documents its return type as MultivariateNormal).

decoder : @callable, default None

a callable that implements the generative distribution p(x | z). Must accept the encoded latent state z as argument and return a distribution object that can be used to evaluate the log_prob of the targets (the decoder method documents its return type as Bernoulli).

encoder_y : @callable, default None

a callable that implements the inference q(y | x) over the discrete latent variable y.

encoder_gmm : @callable, default None

a callable that implements the inference q(z | x, y) over the continuous latent variable z.

random_seed : int, default None

the seed for the random operations.
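The interaction between sigma_min and raw_sigma_bias described above can be sketched as follows. This is only an illustration and not the library's code: the softplus activation and the helper name constrain_sigma are assumptions made for the example, since this page does not state how the raw network output is mapped to a positive standard deviation.

```python
import torch
import torch.nn.functional as F


def constrain_sigma(raw_sigma, sigma_min=0.001, raw_sigma_bias=0.25):
    """Hypothetical helper: map an unconstrained output to a standard deviation."""
    # Shift the raw output by the bias, then apply softplus so the result is
    # positive (softplus is an assumption, not taken from this page).
    sigma = F.softplus(raw_sigma + raw_sigma_bias)
    # Enforce the lower bound so the standard deviation never collapses to zero.
    return torch.clamp(sigma, min=sigma_min)


raw = torch.tensor([-20.0, 0.0, 2.0])
sigma = constrain_sigma(raw)
print(sigma)
```

Under these assumptions, a very negative raw output is floored at sigma_min, while the bias nudges moderate outputs away from zero.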

decoder(z)[source]

Computes the generative distribution p(x | z).

Parameters

z : torch.Tensor (num_samples, mix_components, latent_size)

the stochastic latent state z.

Returns

p(x | z) : Bernoulli (batch_size, data_size)

a Bernoulli distribution.
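To make the role of a Bernoulli output distribution concrete, the sketch below builds one directly with torch.distributions and evaluates the log-likelihood of binary targets. It does not call the GMVAE decoder: the logits tensor stands in for a network output, and the shapes are chosen only for illustration.

```python
import torch
from torch.distributions import Bernoulli

batch_size, data_size = 2, 5

# Stand-in for the raw decoder output (assumption: the network emits logits).
logits = torch.zeros(batch_size, data_size)
p_x = Bernoulli(logits=logits)

# Binary targets with the same shape as the distribution's batch shape.
x = torch.tensor([[1.0, 0.0, 1.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0, 0.0, 1.0]])

# Per-example log-likelihood: sum the independent per-feature terms.
log_lik = p_x.log_prob(x).sum(dim=-1)

# The mean of the distribution is the per-feature probability.
recon_mean = p_x.mean
print(log_lik.shape, recon_mean.shape)
```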

encoder_gmm(x, y)[source]

Computes the inference distribution q(z | x, y).

Parameters

x : torch.Tensor (batch_size, data_size)

the input data.

y : torch.Tensor (batch_size, mix_components)

discrete variable.

Returns

q(z | x, y) : MultivariateNormal (batch_size, latent_size)

a multivariate normal distribution with diagonal covariance.
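For reference, a diagonal-covariance multivariate normal with the documented shape can be built in PyTorch as sketched below. The loc and sigma tensors are placeholders for encoder outputs, and constructing the distribution through scale_tril is one possible choice rather than a statement about how the library does it.

```python
import torch
from torch.distributions import MultivariateNormal

batch_size, latent_size = 4, 2

# Placeholder encoder outputs: a mean and a standard deviation per dimension.
loc = torch.zeros(batch_size, latent_size)
sigma = torch.full((batch_size, latent_size), 0.5)

# A diagonal scale matrix gives a diagonal covariance (sigma ** 2 on the diagonal).
q_z = MultivariateNormal(loc, scale_tril=torch.diag_embed(sigma))

# rsample() uses the reparameterization trick, so gradients can flow through z.
z = q_z.rsample()
log_q = q_z.log_prob(z)
print(q_z.batch_shape, q_z.event_shape, z.shape, log_q.shape)
```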

encoder_y(x)[source]

Computes the inference distribution q(y | x).

Parameters

x : torch.Tensor (batch_size, data_size)

the input data to the inference network.

Returns

q(y | x) : RelaxedOneHotCategorical (batch_size, mix_components)

a relaxed one-hot categorical distribution.
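The effect of the temperature on a relaxed one-hot categorical distribution can be checked with torch.distributions alone. In this sketch the logits are made up, and the helper mean_max_prob is a hypothetical name used only to summarize how peaked the samples are.

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

torch.manual_seed(0)

# Made-up logits for one input and three mixture components.
logits = torch.tensor([[2.0, 0.5, -1.0]])


def mean_max_prob(temperature, n=2000):
    """Average of the largest entry per sample; values near 1 mean near one-hot."""
    dist = RelaxedOneHotCategorical(torch.tensor(temperature), logits=logits)
    samples = dist.sample((n,))  # shape (n, 1, 3); each sample sums to 1
    return samples.max(dim=-1).values.mean().item()


sharp = mean_max_prob(0.1)    # low temperature: samples close to one-hot
smooth = mean_max_prob(10.0)  # high temperature: samples close to uniform
print(sharp, smooth)
```

A low temperature gives samples that behave almost like hard cluster assignments, while a high temperature spreads the mass across components.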

forward(x)[source]

The forward method.

generate_sample_data(z=None, num_samples=1)[source]

Generates mean sample data from the model.

A latent variable ‘z’ can be provided to generate data for that point in the latent space; otherwise ‘z’ is drawn from the prior.

Parameters

z : torch.Tensor (num_samples, mix_components, latent_size)

the stochastic latent state z.

Returns

recon : torch.Tensor (batch_size, data_size)

the mean of the generated sample data.

generate_samples(num_samples, clusters=None)[source]

Samples components from the static latent GMM prior.

Parameters

num_samples : int

number of samples to draw from the static GMM prior.

clusters : list of int, default None

if provided, samples are drawn only from the specified clusters.

Returns

z : Tensor (num_samples, mix_components, latent_size)

samples drawn from each component of the GMM if clusters is None. If clusters is given, the Tensor has shape (num_samples, batch_size, latent_size), where batch_size is the first dimension of clusters, i.e. the number of clusters supplied.
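The two output shapes described above can be reproduced with a small stand-alone sketch. The function sample_gmm_prior and the loc and sigma tensors are hypothetical; they mimic a GMM prior with one diagonal Gaussian per component and are not the library's implementation.

```python
import torch
from torch.distributions import Independent, Normal


def sample_gmm_prior(loc, sigma, num_samples, clusters=None):
    """Hypothetical helper: draw num_samples from each selected component."""
    if clusters is not None:
        # Keep only the requested components.
        idx = torch.as_tensor(clusters, dtype=torch.long)
        loc, sigma = loc[idx], sigma[idx]
    # One diagonal Gaussian per component: batch shape (K,), event shape (D,).
    components = Independent(Normal(loc, sigma), 1)
    return components.sample((num_samples,))


torch.manual_seed(0)
mix_components, latent_size = 3, 2
loc = torch.randn(mix_components, latent_size)
sigma = torch.ones(mix_components, latent_size)

z_all = sample_gmm_prior(loc, sigma, num_samples=5)
z_sub = sample_gmm_prior(loc, sigma, num_samples=5, clusters=[0, 2])
print(z_all.shape, z_sub.shape)
```

With clusters=None the second dimension equals mix_components; with two clusters supplied it equals 2.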

prior_gmm(y)[source]

Computes the GMM prior distribution p(z | y).

Parameters

y : torch.Tensor (batch_size, mix_components)

the discrete intermediate variable y.

Returns

p(z | y) : MultivariateNormal (batch_size, latent_size)

a GMM distribution.

reconstruct(x)[source]

Reconstruct the data from the model.

Parameters

x : torch.Tensor (batch_size, data_size)

the input data.

Returns

recon : torch.Tensor (batch_size, data_size)

the reconstructed data.

transform(x)[source]

Transforms the inputs ‘x’ into their mean latent code.

Parameters

x : torch.Tensor (batch_size, data_size)

the input data.

Returns

z : torch.Tensor (num_samples, mix_components, latent_size)

the stochastic latent state z.

© 2021, consciousnet developers