Source code for torchgan.losses.mutualinfo

import torch
from .loss import GeneratorLoss, DiscriminatorLoss
from ..utils import reduce

# Public names exported when this module is star-imported.
__all__ = ['mutual_information_penalty', 'MutualInformationPenalty']

def mutual_information_penalty(c_dis, c_cont, dist_dis, dist_cont, reduction='mean'):
    r"""Functional form of the mutual information penalty.

    For each (distribution, code) pair the mean of ``dist.log_prob(code)`` is
    taken; the two means are negated and handed to ``reduce``.

    Args:
        c_dis: Discrete latent code; passed to ``dist_dis.log_prob``.
        c_cont: Continuous latent code; passed to ``dist_cont.log_prob``.
        dist_dis: Object exposing ``log_prob`` for the discrete code
            (presumably a ``torch.distributions.Distribution``).
        dist_cont: Object exposing ``log_prob`` for the continuous code
            (presumably a ``torch.distributions.Distribution``).
        reduction (str, optional): Reduction mode forwarded unchanged to
            ``reduce``. Defaults to ``'mean'``.

    Returns:
        Whatever ``reduce`` returns for the 2-element tensor of negated mean
        log-probabilities (discrete term first, continuous term second).
    """
    # torch.stack keeps both terms attached to the autograd graph and on their
    # original device/dtype. Building the vector with torch.Tensor([...])
    # copies the values into a fresh leaf tensor, so no gradient could flow
    # back through the penalty.
    log_probs = torch.stack([
        torch.mean(dist.log_prob(code))
        for dist, code in zip((dist_dis, dist_cont), (c_dis, c_cont))
    ])
    return reduce(-1.0 * log_probs, reduction)

class MutualInformationPenalty(GeneratorLoss, DiscriminatorLoss):
    r"""Mutual Information Penalty as defined in
    `"InfoGAN : Interpretable Representation Learning by Information Maximising
    Generative Adversarial Nets by Chen et al." <https://arxiv.org/abs/1606.03657>`_ paper

    The loss is the variational lower bound of the mutual information between
    the latent codes and the generator distribution and is defined as

    .. math:: L(G,Q) = log(Q|x)

    where

    - :math:`x` is drawn from the generator distribution G(z,c)
    - :math:`c` drawn from the latent code prior :math:`P(c)`

    Args:
        lambd (float, optional): The scaling factor for the loss.
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the mean of the output.
            If ``sum`` the elements of the output will be summed.
        override_train_ops (function, optional): A function is passed to this argument,
            if the default ``train_ops`` is not to be used.
    """

    def __init__(self, lambd=1.0, reduction='mean', override_train_ops=None):
        super(MutualInformationPenalty, self).__init__(reduction, override_train_ops)
        # Multiplier applied to the loss inside ``train_ops`` before backward().
        self.lambd = lambd

    def forward(self, c_dis, c_cont, dist_dis, dist_cont):
        r"""Computes the loss for the given input.

        Args:
            c_dis: The discrete latent code sampled from the prior; passed to
                ``dist_dis.log_prob``.
            c_cont: The continuous latent code sampled from the prior; passed to
                ``dist_cont.log_prob``.
            dist_dis (torch.distributions.Distribution): The auxiliary distribution
                :math:`Q(c|x)` over the discrete latent code output by the discriminator.
            dist_cont (torch.distributions.Distribution): The auxiliary distribution
                :math:`Q(c|x)` over the continuous latent code output by the discriminator.

        Returns:
            The result of ``reduce`` applied, with ``self.reduction``, to the
            2-element tensor of negated mean log-probabilities (discrete term
            first, continuous term second).
        """
        # torch.stack preserves the autograd graph (and device/dtype) of the two
        # terms. torch.Tensor([...]) would copy the values into a new leaf
        # tensor, and the backward() call in ``train_ops`` could not reach the
        # generator or discriminator parameters.
        log_probs = torch.stack([
            torch.mean(dist.log_prob(code))
            for dist, code in zip((dist_dis, dist_cont), (c_dis, c_cont))
        ])
        return reduce(-1.0 * log_probs, self.reduction)

    def train_ops(self, generator, discriminator, optimizer_generator,
                  optimizer_discriminator, dis_code, cont_code, device, batch_size):
        r"""Runs one optimisation step of the penalty on both networks.

        If ``override_train_ops`` was supplied at construction it is called with
        the same arguments and its result is returned. Otherwise noise of shape
        ``(batch_size, generator.encoding_dims)`` is sampled, a fake batch is
        generated from the noise and the two codes, the discriminator's two
        auxiliary distributions are scored with :meth:`forward`, and the loss
        scaled by ``lambd`` is back-propagated before stepping both optimizers.

        Args:
            generator: Callable as ``generator(noise, dis_code, cont_code)`` and
                exposing ``encoding_dims``.
            discriminator: Callable as ``discriminator(fake, True)``; must return
                three values, of which the last two are the discrete and
                continuous auxiliary distributions.
            optimizer_generator: Optimizer for the generator parameters.
            optimizer_discriminator: Optimizer for the discriminator parameters.
            dis_code: Discrete latent code fed to the generator and scored.
            cont_code: Continuous latent code fed to the generator and scored.
            device: Device on which the noise is created.
            batch_size (int): Number of noise samples to draw.

        Returns:
            The scaled loss as a Python number (``weighted_loss.item()``), or
            whatever ``override_train_ops`` returns when it is set.
        """
        if self.override_train_ops is not None:
            # Propagate the override's result so callers receive a value on
            # both paths instead of an implicit None.
            return self.override_train_ops(generator, discriminator, optimizer_generator,
                                           optimizer_discriminator, dis_code, cont_code,
                                           device, batch_size)

        noise = torch.randn(batch_size, generator.encoding_dims, device=device)
        optimizer_discriminator.zero_grad()
        optimizer_generator.zero_grad()
        fake = generator(noise, dis_code, cont_code)
        # The first discriminator output is not needed for this penalty.
        _, dist_dis, dist_cont = discriminator(fake, True)
        loss = self.forward(dis_code, cont_code, dist_dis, dist_cont)
        weighted_loss = self.lambd * loss
        weighted_loss.backward()
        optimizer_discriminator.step()
        optimizer_generator.step()
        return weighted_loss.item()