import torch
import torch.nn.functional as F
from .loss import GeneratorLoss, DiscriminatorLoss
# Public API of this module: the two functional losses and their class wrappers.
__all__ = ['minimax_generator_loss', 'minimax_discriminator_loss', 'MinimaxGeneratorLoss', 'MinimaxDiscriminatorLoss']
def minimax_generator_loss(dgz, nonsaturating=True, reduction='mean'):
    r"""Functional form of the minimax generator loss.

    ``dgz`` is treated as raw logits: the sigmoid is applied inside
    ``binary_cross_entropy_with_logits``.

    Args:
        dgz (torch.Tensor): Output of the discriminator with generated data.
        nonsaturating (bool, optional): If ``True`` (default) return the
            nonsaturating loss :math:`-log(\sigma(dgz))`; otherwise return the
            saturating minimax loss :math:`log(1 - \sigma(dgz))`.
        reduction (str, optional): Reduction forwarded unchanged to
            ``binary_cross_entropy_with_logits`` (``'none'``, ``'mean'`` or
            ``'sum'``).

    Returns:
        torch.Tensor: scalar if a reduction is applied, else a tensor shaped
        like ``dgz``.
    """
    if not nonsaturating:
        # BCE against all-zero targets is -log(1 - sigmoid(dgz)); flipping the
        # sign yields log(1 - sigmoid(dgz)), the quantity the generator minimises.
        fake_labels = torch.zeros_like(dgz)
        bce = F.binary_cross_entropy_with_logits(dgz, fake_labels,
                                                 reduction=reduction)
        return -1.0 * bce

    # BCE against all-one targets is -log(sigmoid(dgz)).
    real_labels = torch.ones_like(dgz)
    return F.binary_cross_entropy_with_logits(dgz, real_labels,
                                              reduction=reduction)
def minimax_discriminator_loss(dx, dgz, label_smoothing=0.0, reduction='mean'):
    r"""Functional form of the minimax discriminator loss.

    Both inputs are treated as raw logits: the sigmoid is applied inside
    ``binary_cross_entropy_with_logits``. The result is the BCE of ``dx``
    against (optionally smoothed) "real" targets plus the BCE of ``dgz``
    against all-zero "fake" targets.

    Args:
        dx (torch.Tensor): Output of the discriminator with real data.
        dgz (torch.Tensor): Output of the discriminator with generated data.
        label_smoothing (float, optional): Amount subtracted from the real
            label, e.g. ``0.2`` turns the real target ``1`` into ``0.8``.
        reduction (str, optional): ``'none'``, ``'mean'`` or ``'sum'``. The
            legacy spelling ``'elementwise_mean'`` is still accepted and is
            handled as ``'mean'``.

    Returns:
        torch.Tensor: scalar if a reduction is applied (each term is reduced
        separately, then the two are added); with ``'none'`` the sum of the
        two element-wise loss tensors.
    """
    if reduction == 'elementwise_mean':
        # Deprecated PyTorch alias of 'mean'; normalise it so that older
        # callers keep working without triggering a deprecation warning.
        reduction = 'mean'

    # Each target must be shaped like the logits it is paired with, so the
    # real and fake batches are allowed to differ in size.
    real_target = torch.ones_like(dx) * (1.0 - label_smoothing)
    fake_target = torch.zeros_like(dgz)

    real_loss = F.binary_cross_entropy_with_logits(dx, real_target,
                                                   reduction=reduction)
    fake_loss = F.binary_cross_entropy_with_logits(dgz, fake_target,
                                                   reduction=reduction)
    return real_loss + fake_loss
class MinimaxGeneratorLoss(GeneratorLoss):
    r"""Minimax game generator loss from the original GAN paper `"Generative Adversarial Networks
    by Goodfellow et al." <https://arxiv.org/abs/1406.2661>`_

    The loss can be described as:

    .. math:: L(G) = log(1 - D(G(z)))

    The nonsaturating heuristic is also supported:

    .. math:: L(G) = -log(D(G(z)))

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator
    - :math:`z` : A sample from the noise prior

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the mean of the output.
            If ``sum`` the elements of the output will be summed.
        nonsaturating (bool, optional): Specifies whether to use the nonsaturating heuristic
            loss for the generator.
        override_train_ops (function, optional): Function to be used in place of the default ``train_ops``
    """

    def __init__(self, reduction='mean', nonsaturating=True, override_train_ops=None):
        # ``reduction`` and ``override_train_ops`` are handed to the base class;
        # only the nonsaturating flag is stored by this subclass.
        super(MinimaxGeneratorLoss, self).__init__(reduction, override_train_ops)
        self.nonsaturating = nonsaturating

    def forward(self, dgz):
        r"""Computes the loss for the given input.

        Args:
            dgz (torch.Tensor) : Output of the Discriminator with generated data. It must have the
                                 dimensions (N, \*) where \* means any number of additional
                                 dimensions.

        Returns:
            scalar if reduction is applied else Tensor with dimensions (N, \*).
        """
        # NOTE(review): ``self.reduction`` is presumably set by
        # ``GeneratorLoss.__init__`` - confirm against the base class.
        return minimax_generator_loss(dgz, self.nonsaturating, self.reduction)
class MinimaxDiscriminatorLoss(DiscriminatorLoss):
    r"""Minimax game discriminator loss from the original GAN paper `"Generative Adversarial Networks
    by Goodfellow et al." <https://arxiv.org/abs/1406.2661>`_

    The loss can be described as:

    .. math:: L(D) = -[log(D(x)) + log(1 - D(G(z)))]

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator
    - :math:`x` : A sample from the data distribution
    - :math:`z` : A sample from the noise prior

    Args:
        label_smoothing (float, optional): The factor by which the labels (1 in this case) needs
            to be smoothened. For example, label_smoothing = 0.2 changes the value of the real
            labels to 0.8.
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the mean of the output.
            If ``sum`` the elements of the output will be summed.
        override_train_ops (function, optional): A function is passed to this argument,
            if the default ``train_ops`` is not to be used.
    """

    def __init__(self, label_smoothing=0.0, reduction='mean', override_train_ops=None):
        # ``reduction`` and ``override_train_ops`` are handed to the base class;
        # only the smoothing factor is stored by this subclass.
        super(MinimaxDiscriminatorLoss, self).__init__(reduction, override_train_ops)
        self.label_smoothing = label_smoothing

    def forward(self, dx, dgz):
        r"""Computes the loss for the given input.

        Args:
            dx (torch.Tensor) : Output of the Discriminator with real data. It must have the
                                dimensions (N, \*) where \* means any number of additional
                                dimensions.
            dgz (torch.Tensor) : Output of the Discriminator with generated data. It must have the
                                 dimensions (N, \*) where \* means any number of additional
                                 dimensions.

        Returns:
            scalar if reduction is applied else Tensor with dimensions (N, \*).
        """
        # NOTE(review): ``self.reduction`` is presumably set by
        # ``DiscriminatorLoss.__init__`` - confirm against the base class.
        return minimax_discriminator_loss(
            dx,
            dgz,
            label_smoothing=self.label_smoothing,
            reduction=self.reduction,
        )