Menu

Module providing Consciousness Exploration Tools for PyTorch.

Source code for consciousnet.losses.tdvae

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Definition of the Temporal Difference Variational Auto-Encoder (TD-VAE) loss.
"""

# Imports
import torch.nn.functional as func
from torch.distributions import kl_divergence


class TDVAELoss(object):
    """ TDVAE Loss function.

    This loss needs intermediate layers outputs. Use a callback function to
    set the 'layer_outputs' class parameter before each evaluation of the
    loss function. If you use an interface this parameter is updated
    automatically.
    """
    def __init__(self, obs_loss=None):
        """ Init class.

        Parameters
        ----------
        obs_loss: callable, default None
            observation loss, called as
            ``obs_loss(prediction, target, reduction="none")``; defaults to
            ``torch.nn.functional.binary_cross_entropy``.
        """
        # Dict of intermediate network outputs, set externally (e.g. by a
        # callback) before each call; see 'get_params' for required keys.
        self.layer_outputs = None
        self.obs_loss = obs_loss or func.binary_cross_entropy

    def get_params(self):
        """ Extract the intermediate outputs needed by the loss.

        Returns
        -------
        tuple
            the values stored under the keys 't2', 'q_S_t1_li_z',
            'p_B_t1_li_z', 'p_B_t2_li_z', 'p_T_t2_li_z' and 't2_li_z' of
            'layer_outputs', in that order.

        Raises
        ------
        ValueError
            if 'layer_outputs' has not been set.
        """
        if self.layer_outputs is None:
            raise ValueError(
                "This loss needs intermediate layers outputs. Please register "
                "an appropriate callback.")
        keys = ("t2", "q_S_t1_li_z", "p_B_t1_li_z", "p_B_t2_li_z",
                "p_T_t2_li_z", "t2_li_z")
        return tuple(self.layer_outputs[key] for key in keys)

    def __call__(self, p_D_t2_x, x):
        """ Calculate the jumpy TD-VAE loss, which corresponds to equations
        (6) and (8) in the reference paper.

        Parameters
        ----------
        p_D_t2_x: Tensor
            decoder prediction for the observation at time t2, passed as the
            first argument of 'obs_loss'.
        x: Tensor
            input sequences, indexed as ``x[:, t2, :]`` (time axis assumed
            to be dim 1 — TODO confirm against the data pipeline).

        Returns
        -------
        loss: Tensor
            per-sample loss terms summed, then averaged with ``.mean()``.
        extra: dict
            per-sample terms 'kl_qs_pb_loss', 'kl_shift_qb_pt_loss' and
            'obs_loss'.

        Raises
        ------
        ValueError
            if 'layer_outputs' is not set, describes no latent layer, or
            holds per-layer lists of different lengths.
        """
        # Get params
        (t2, q_S_t1_li_z, p_B_t1_li_z, p_B_t2_li_z, p_T_t2_li_z,
         t2_li_z) = self.get_params()
        n_layers = len(q_S_t1_li_z)
        per_layer = (q_S_t1_li_z, p_B_t1_li_z, p_B_t2_li_z, p_T_t2_li_z,
                     t2_li_z)
        if n_layers == 0 or any(len(item) != n_layers for item in per_layer):
            raise ValueError(
                "The intermediate layers outputs must describe at least one "
                "latent layer, with the same number of layers in each entry.")
        xt2 = x[:, t2, :]

        # KL divergence between z distribution at time t1 based on
        # variational distribution (inference model) and z distribution at
        # time t1 based on belief. Each layer is reduced over its last dim
        # before summing across layers, so layers may differ in latent size.
        kl_qs_pb_loss = sum(
            kl_divergence(q_S, p_B).sum(dim=-1)
            for q_S, p_B in zip(q_S_t1_li_z, p_B_t1_li_z))

        # Sample-based estimate of the KL divergence between the z
        # distribution at time t2 based on belief and the z distribution at
        # time t2 based on transition. Unlike the t1 term, it cannot be
        # computed analytically because the transition distribution depends
        # on z_t1, which is sampled after z_t2:
        # log p_B(z_t2 | b_t2) - log p_T(z_t2 | z_t1) for each layer.
        kl_shift_qb_pt_loss = sum(
            (p_B.log_prob(z) - p_T.log_prob(z)).sum(dim=-1)
            for p_B, p_T, z in zip(p_B_t2_li_z, p_T_t2_li_z, t2_li_z))

        # Observation loss at time t2, offset by the loss of a perfect
        # prediction (obs_loss(xt2, xt2)) so that its minimum is zero.
        _loss = self.obs_loss(p_D_t2_x, xt2, reduction="none").sum(dim=-1)
        _loss_optimal = self.obs_loss(xt2, xt2, reduction="none").sum(dim=-1)
        obs_loss = _loss - _loss_optimal

        # Final loss
        loss = kl_qs_pb_loss + kl_shift_qb_pt_loss + obs_loss
        loss = loss.mean()
        return loss, {"kl_qs_pb_loss": kl_qs_pb_loss,
                      "kl_shift_qb_pt_loss": kl_shift_qb_pt_loss,
                      "obs_loss": obs_loss}

Follow us

© 2021, consciousnet developers