Module providing Consciousness Exploration Tools for PyTorch.
Source code for consciousnet.models.tdvae
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Definition of the Temporal Difference Variational Auto-Encoder (TD-VAE) model.
"""
# Imports
import math
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal
[docs]class TDVAE(torch.nn.Module):
""" Hierachical Temporal Difference Variational Auto-Encoder with
jumpy predictions.
Temporal Difference Variational Auto-Encoder, Karol Gregor, George
Papamakarios, Frederic Besse, Lars Buesing and Theophane Weber,
ICLR 2019, https://openreview.net/forum?id=S1x4ghC9tQ.
First, let's first go through some definitions which would help
understanding what is going on in the following code:
Observation: the observated variable x.
Belief: as the model is feed with a sequence of observations, x_t, the
model updates its belief state b_t, through a LSTM network. It
is a deterministic function of x_t. We call b_t the belief at
time t instead of belief state.
State: the latent hidden state variable z.
"""
[docs] def __init__(self, x_dim, b_dim, z_dim, t, d, n_layers=2, n_lstm_layers=1,
preproc_dim=None, add_sigmoid=True):
""" Init class.
Parameters
-----------
x_dim: int
the dimension of observed data.
b_dim: int
the belief code dimension.
z_dim: int
the dimension of latent space.
t: int
in jumpy state modeling, t1 can be chosen uniformly from the
sequence U(1,t).
d: int
in jumpy state modeling, t2 − t1 can be chosen uniformly over
some finite range U(1,d).
n_layers: int, default 2
the number of hierachical level in the model.
n_lstm_layers: int, default 1
the number of recurrent layers, eg setting this paramter to 2
would mean stacking two LSTMs together to form a stacked LSTM,
with the second LSTM taking in outputs of the first LSTM and
computing the final results.
preproc_dim: int, default None
the dimension of preprocessed observations. If not specified no
preprocessing is applied.
add_sigmoid: bool, default True
apply sigmoid activation fct to the decoder.
"""
# Inheritance
super(TDVAE, self).__init__()
# Parameters
self.x_dim = x_dim
self.b_dim = b_dim
self.z_dim = z_dim
self.t = t
self.d = d
self.n_layers = n_layers
self.n_lstm_layers = n_lstm_layers
self.preproc_dim = preproc_dim
# Input pre-process layer
if self.preproc_dim is not None:
self.preproc_x = PreprocBlock(
input_size=x_dim, preproc_size=preproc_dim)
else:
self.preproc_dim = x_dim
self.preproc_x = None
# N layer LSTM for aggregating belief states
self.lstm = nn.LSTM(
input_size=self.preproc_dim, hidden_size=b_dim, batch_first=True,
num_layers=n_lstm_layers)
# N layer state model is used. Sampling is done by sampling
# higher layer first.
_b_to_z, _infer_z, _transition_z = [], [], []
for idx in range(n_layers - 1, -1, -1):
extra_dim = (z_dim if idx < (n_layers - 1) else 0)
# - belief to state (b_to_z): this corresponds to the p_B
# distribution in the reference. Weights are shared across time
# but not across layers.
_b_to_z.append(
DBlock(
input_size=(b_dim + extra_dim),
hidden_size=50, output_size=z_dim))
# - given belief and state at time t2, infer the state at time t1
# (infer_z): this corresponds to the q_S distribution in the
# reference.
_infer_z.append(
DBlock(
input_size=(b_dim + n_layers * z_dim + extra_dim),
hidden_size=50, output_size=z_dim))
# - given the state at time t1, model state at time t2 through
# state transition (transition_z): this corresponds to the p_T
# distribution in the reference.
_transition_z.append(
DBlock(
input_size=(n_layers * z_dim + extra_dim),
hidden_size=50, output_size=z_dim))
self.b_to_z = nn.ModuleList(_b_to_z)
self.infer_z = nn.ModuleList(_infer_z)
self.transition_z = nn.ModuleList(_transition_z)
# - state to observation (z_to_x): this corresponds to the p_D
# distribution in the reference.
self.z_to_x = Decoder(
z_size=(n_layers * z_dim), hidden_size=200, x_size=x_dim,
add_sigmoid=add_sigmoid)
[docs] def reparameterize(self, q):
""" The reparametrization trick.
"""
if self.training:
z = q.rsample()
else:
z = q.loc
return z
[docs] def forward(self, x):
""" The forward method.
"""
# - pre-process image x
if self.preproc_x is not None:
x = self.preproc_x(x)
# - aggregate the belief b
b, (h_n, c_n) = self.lstm(x)
# - sample t1 and t2
t1 = np.random.randint(0, self.t)
t2 = t1 + np.random.randint(1, self.d + 1)
# Because the loss is based on variational inference, we need to
# draw samples from the variational distribution in order to estimate
# the loss function.
# - sample a state z at time t2 using the reparametralization trick
# in layer 2 & 1 respectively. The result state is obtained by
# concatenating results from layer 1 and layer 2.
p_B_t2_li_z = []
t2_li_z = []
for i in range(self.n_layers):
if i == 0:
p_B_t2_li_z.append(self.b_to_z[i](
b[:, t2, :]))
else:
p_B_t2_li_z.append(self.b_to_z[i](
torch.cat((b[:, t2, :], t2_li_z[-1]), dim=-1)))
t2_li_z.append(self.reparameterize(p_B_t2_li_z[-1]))
t2_z = torch.cat(t2_li_z[::-1], dim=-1)
# - sample a state at time t1: infer state at time t1 based on states
# at time t2. The result state is obtained by concatenating results
# from layer 1 and layer 2.
q_S_t1_li_z = []
t1_li_z = []
for i in range(self.n_layers):
if i == 0:
q_S_t1_li_z.append(self.infer_z[i](
torch.cat((b[:, t1, :], t2_z), dim=-1)))
else:
q_S_t1_li_z.append(self.infer_z[i](
torch.cat((b[:, t1, :], t2_z, t1_li_z[-1]), dim=-1)))
t1_li_z.append(self.reparameterize(q_S_t1_li_z[-1]))
t1_z = torch.cat(t1_li_z[::-1], dim=-1)
# - compute state distribution at time t1 based on belief at time 1
p_B_t1_li_z = []
for i in range(self.n_layers):
if i == 0:
p_B_t1_li_z.append(self.b_to_z[i](
b[:, t1, :]))
else:
p_B_t1_li_z.append(self.b_to_z[i](
torch.cat((b[:, t1, :], t1_li_z[i - 1]), dim=-1)))
# - compute state distribution at time t2 based on states at time t1
# and state transition
p_T_t2_li_z = []
for i in range(self.n_layers):
if i == 0:
p_T_t2_li_z.append(self.transition_z[i](
t1_z))
else:
p_T_t2_li_z.append(self.transition_z[i](
torch.cat((t1_z, t2_li_z[i - 1]), dim=-1)))
# - compute observation distribution at time t2 based on state at
# time t2
p_D_t2_x = self.z_to_x(t2_z)
return p_D_t2_x, {"b": b, "t1": t1, "t2": t2,
"p_B_t2_li_z": p_B_t2_li_z, "t2_li_z": t2_li_z,
"t2_z": t2_z, "q_S_t1_li_z": q_S_t1_li_z,
"t1_li_z": t1_li_z, "t1_z": t1_z,
"p_B_t1_li_z": p_B_t1_li_z,
"p_T_t2_li_z": p_T_t2_li_z}
[docs] def rollout(self, x, t1, t2):
""" Jumpy rollout.
Parameters
----------
x: Tensor
the input sequences.
t1: int
the time jump number of steps.
t2: int
the prediction interval t2 - t1.
Retruns
-------
rollout_x: Tensor
the predicted frames.
"""
# Compute belief
if self.preproc_x is not None:
x = self.preproc_x(x)
b, (h_n, c_n) = self.lstm(x)
# At time t1-1, we sample a state z based on belief at time t1-1
li_z = []
for i in range(self.n_layers):
if i == 0:
p_B_li_z = self.b_to_z[i](b[:, t1 - 1, :])
else:
p_B_li_z = self.b_to_z[i](
torch.cat((b[:, t1 - 1, :], li_z[-1]), dim=-1))
li_z.append(self.reparameterize(p_B_li_z))
current_z = torch.cat(li_z[::-1], dim=-1)
# Start rollout
rollout_x = []
for k in range(t2 - t1 + 1):
# - predict states after time t1 using state transition
tnext_li_z = []
for i in range(self.n_layers):
if i == 0:
p_T_tnext_li_z = self.transition_z[i](current_z)
else:
p_T_tnext_li_z = self.transition_z[i](
torch.cat((current_z, tnext_li_z[i - 1]), dim=-1))
tnext_li_z.append(self.reparameterize(p_T_tnext_li_z))
next_z = torch.cat(tnext_li_z[::-1], dim=-1)
# - generate an observation x_t1 at time t1 based on sampled
# state z_t1
next_x = self.z_to_x(next_z)
rollout_x.append(next_x)
current_z = next_z
rollout_x = torch.stack(rollout_x, dim=1)
return rollout_x
[docs]class DBlock(nn.Module):
""" A basic building block to parametrize a Normal distribution.
It is corresponding to the D operation in the reference Appendix.
"""
[docs] def __init__(self, input_size, hidden_size, output_size):
super(DBlock, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(input_size, hidden_size)
self.fc_mu = nn.Linear(hidden_size, output_size)
self.fc_logsigma = nn.Linear(hidden_size, output_size)
[docs] def forward(self, input):
t = torch.tanh(self.fc1(input))
t = t * torch.sigmoid(self.fc2(input))
mu = self.fc_mu(t)
logsigma = self.fc_logsigma(t)
return Normal(loc=mu, scale=logsigma.exp().pow(0.5))
[docs]class PreprocBlock(nn.Module):
""" The optional pre-process layer.
"""
[docs] def __init__(self, input_size, preproc_size):
super(PreprocBlock, self).__init__()
self.input_size = input_size
self.fc1 = nn.Linear(input_size, preproc_size)
self.fc2 = nn.Linear(preproc_size, preproc_size)
[docs] def forward(self, input):
t = torch.relu(self.fc1(input))
t = torch.relu(self.fc2(t))
return t
[docs]class Decoder(nn.Module):
""" The decoder layer that converts state to observation.
Because the observation is MNIST image whose elements are values
between 0 and 1, the output of this layer are probabilities of
elements being 1.
"""
[docs] def __init__(self, z_size, hidden_size, x_size, add_sigmoid=True):
super(Decoder, self).__init__()
self.add_sigmoid = add_sigmoid
self.fc1 = nn.Linear(z_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, x_size)
[docs] def forward(self, z):
t = torch.tanh(self.fc1(z))
t = torch.tanh(self.fc2(t))
p = self.fc3(t)
if self.add_sigmoid:
p = torch.sigmoid(p)
return p
Follow us