Source code for torchgan.logging.visualize

import torch
import torchvision
from ..models.model import Generator, Discriminator
from .backends import *
    from tensorboardX import SummaryWriter
    import visdom

__all__ = [
    'Visualize',
    'LossVisualize',
    'MetricVisualize',
    'GradientVisualize',
    'ImageVisualize',
]

class Visualize(object):
    r"""Base class for all Visualizations.

    Args:
        visualize_list (list, optional): List of the functions needed for visualization. One
            (initially empty) log list is created per entry, keyed by the entry's class name.
            ``None`` is treated as an empty list.
        visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
            manually started at this port else an error will be thrown and the code will crash.
            This is ignored if ``VISDOM_LOGGING`` is ``0``.
        log_dir (str, optional): Directory where TensorboardX should store the logs. This is
            ignored if ``TENSORBOARD_LOGGING`` is ``0``.
        writer (tensorboardX.SummaryWriter, optional): Send a ``SummaryWriter`` if you don't want
            to start a new SummaryWriter.
    """

    def __init__(self, visualize_list=None, visdom_port=8097, log_dir=None, writer=None):
        if visualize_list is None:
            visualize_list = []
        # One history list per item, keyed by the item's class name.
        self.logs = {}
        for item in visualize_list:
            self.logs[type(item).__name__] = []
        self.step = 1
        if TENSORBOARD_LOGGING == 1:
            self._build_tensorboard(log_dir, writer)
        if VISDOM_LOGGING == 1:
            self._build_visdom(visdom_port)

    def _build_tensorboard(self, log_dir, writer):
        r"""Starts the tensorboard logging utilities.

        Args:
            log_dir (str, optional): Directory where TensorboardX should store the logs.
            writer (tensorboardX.SummaryWriter, optional): Send a ``SummaryWriter`` if you don't
                want to start a new SummaryWriter. When given, ``log_dir`` is not used.
        """
        self.writer = SummaryWriter(log_dir) if writer is None else writer

    def _build_visdom(self, port):
        r"""Starts the visdom logging utilities.

        Args:
            port (int, optional): Port to log using ``visdom``. A default server is started at
                port ``8097``, so a new server has to be started manually if the port is changed.
        """
        self.vis = visdom.Visdom(port=port)

    def step_update(self):
        r"""Helper function which updates the step at the end of one print iteration."""
        self.step += 1

    def log_tensorboard(self):
        r"""Tensorboard logging function. Needs to be defined in the subclass.

        :raises NotImplementedError:
        """
        raise NotImplementedError

    def log_console(self):
        r"""Console logging function. Needs to be defined in the subclass.

        :raises NotImplementedError:
        """
        raise NotImplementedError

    def log_visdom(self):
        r"""Visdom logging function. Needs to be defined in the subclass.

        :raises NotImplementedError:
        """
        raise NotImplementedError

    def __call__(self, *args, lock_console=False, lock_tensorboard=False, lock_visdom=False,
                 **kwargs):
        r"""Dispatches ``*args``/``**kwargs`` to every enabled, non-locked backend logger.

        A backend is used only when its module-level flag (``CONSOLE_LOGGING``,
        ``TENSORBOARD_LOGGING``, ``VISDOM_LOGGING``) equals ``1`` and the matching ``lock_*``
        argument is ``False``. The step counter is advanced afterwards in every case.
        """
        if not lock_console and CONSOLE_LOGGING == 1:
            self.log_console(*args, **kwargs)
        if not lock_tensorboard and TENSORBOARD_LOGGING == 1:
            self.log_tensorboard(*args, **kwargs)
        if not lock_visdom and VISDOM_LOGGING == 1:
            self.log_visdom(*args, **kwargs)
        self.step_update()
class LossVisualize(Visualize):
    r"""This class provides the Visualizations for Generator and Discriminator Losses.

    Args:
        visualize_list (list, optional): List of the functions needed for visualization.
        visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
            manually started at this port else an error will be thrown and the code will crash.
            This is ignored if ``VISDOM_LOGGING`` is ``0``.
        log_dir (str, optional): Directory where TensorboardX should store the logs. This is
            ignored if ``TENSORBOARD_LOGGING`` is ``0``.
        writer (tensorboardX.SummaryWriter, optional): Send a ``SummaryWriter`` if you don't want
            to start a new SummaryWriter.
    """

    def log_tensorboard(self, running_losses):
        r"""Tensorboard logging function. This function logs the following:

        - ``Running Discriminator Loss``
        - ``Running Generator Loss``
        - ``Running Losses``
        - Loss Values of the individual Losses (the latest entry of each history; losses with
          no recorded value yet are skipped).

        Args:
            running_losses (dict): A dict with 2 items namely, ``Running Discriminator Loss``,
                and ``Running Generator Loss``.
        """
        self.writer.add_scalar("Running Discriminator Loss",
                               running_losses["Running Discriminator Loss"], self.step)
        self.writer.add_scalar("Running Generator Loss",
                               running_losses["Running Generator Loss"], self.step)
        self.writer.add_scalars("Running Losses", running_losses, self.step)
        for name, value in self.logs.items():
            # A loss that has not been computed yet has no history to report.
            if not value:
                continue
            val = value[-1]
            if isinstance(val, tuple):
                # Tuple entries are logged as (generator part, discriminator part).
                self.writer.add_scalar('Losses/{}-Generator'.format(name), val[0], self.step)
                self.writer.add_scalar('Losses/{}-Discriminator'.format(name), val[1], self.step)
            else:
                self.writer.add_scalar('Losses/{}'.format(name), val, self.step)

    def log_console(self, running_losses):
        r"""Console logging function. This function logs the mean ``generator`` and
        ``discriminator`` losses.

        Args:
            running_losses (dict): A dict with 2 items namely, ``Running Discriminator Loss``,
                and ``Running Generator Loss``.
        """
        for name, val in running_losses.items():
            print('Mean {} : {}'.format(name, val))

    def log_visdom(self, running_losses):
        r"""Visdom logging function. This function logs the following:

        - ``Running Discriminator Loss``
        - ``Running Generator Loss``
        - ``Running Losses``
        - Loss Values of the individual Losses (the latest entry of each history; losses with
          no recorded value yet are skipped).

        Args:
            running_losses (dict): A dict with 2 items namely, ``Running Discriminator Loss``,
                and ``Running Generator Loss``.
        """
        self.vis.line([running_losses["Running Discriminator Loss"]], [self.step],
                      win="Running Discriminator Loss", update="append",
                      opts=dict(title="Running Discriminator Loss", xlabel="Time Step",
                                ylabel="Running Loss"))
        self.vis.line([running_losses["Running Generator Loss"]], [self.step],
                      win="Running Generator Loss", update="append",
                      opts=dict(title="Running Generator Loss", xlabel="Time Step",
                                ylabel="Running Loss"))
        self.vis.line([[running_losses["Running Discriminator Loss"],
                        running_losses["Running Generator Loss"]]], [self.step],
                      win="Running Losses", update="append",
                      opts=dict(title="Running Losses", xlabel="Time Step",
                                ylabel="Running Loss", legend=["Discriminator", "Generator"]))
        for name, value in self.logs.items():
            # A loss that has not been computed yet has no history to report.
            if not value:
                continue
            val = value[-1]
            if isinstance(val, tuple):
                name1 = "{}-Generator".format(name)
                name2 = "{}-Discriminator".format(name)
                self.vis.line([val[0]], [self.step], win=name1, update="append",
                              opts=dict(title=name1, xlabel="Time Step", ylabel="Loss Value"))
                self.vis.line([val[1]], [self.step], win=name2, update="append",
                              opts=dict(title=name2, xlabel="Time Step", ylabel="Loss Value"))
            else:
                self.vis.line([val], [self.step], win=name, update="append",
                              opts=dict(title=name, xlabel="Time Step", ylabel="Loss Value"))

    @staticmethod
    def _running_mean(total, iters):
        r"""Returns ``total / iters``, or ``0.0`` when no iterations have been run yet.

        Guards against ``ZeroDivisionError`` when logging happens before a model's first step.
        """
        if not iters:
            return 0.0
        return total / iters

    def __call__(self, trainer, **kwargs):
        r"""Computes the running mean losses from ``trainer.loss_information`` and logs them.

        Args:
            trainer (torchgan.trainer.Trainer): The base trainer used for training.
        """
        info = trainer.loss_information
        running_generator_loss = self._running_mean(info["generator_losses"],
                                                    info["generator_iters"])
        running_discriminator_loss = self._running_mean(info["discriminator_losses"],
                                                        info["discriminator_iters"])
        running_losses = {"Running Discriminator Loss": running_discriminator_loss,
                          "Running Generator Loss": running_generator_loss}
        super(LossVisualize, self).__call__(running_losses, **kwargs)
class MetricVisualize(Visualize):
    r"""This class provides the Visualizations for Metrics.

    Args:
        visualize_list (list, optional): List of the functions needed for visualization.
        visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
            manually started at this port else an error will be thrown and the code will crash.
            This is ignored if ``VISDOM_LOGGING`` is ``0``.
        log_dir (str, optional): Directory where TensorboardX should store the logs. This is
            ignored if ``TENSORBOARD_LOGGING`` is ``0``.
        writer (tensorboardX.SummaryWriter, optional): Send a ``SummaryWriter`` if you don't want
            to start a new SummaryWriter.
    """

    def log_tensorboard(self):
        r"""Tensorboard logging function. This function logs the latest value of each
        individual metric. Metrics with no recorded value yet are skipped.
        """
        for name, value in self.logs.items():
            if not value:
                continue
            self.writer.add_scalar("Metrics/{}".format(name), value[-1], self.step)

    def log_console(self):
        r"""Console logging function. This function prints the latest value of each metric.
        Metrics with no recorded value yet are skipped.
        """
        for name, val in self.logs.items():
            if not val:
                continue
            print('{} : {}'.format(name, val[-1]))

    def log_visdom(self):
        r"""Visdom logging function. This function logs the latest value of each individual
        metric. Metrics with no recorded value yet are skipped.
        """
        for name, value in self.logs.items():
            if not value:
                continue
            self.vis.line([value[-1]], [self.step], win=name, update="append",
                          opts=dict(title=name, xlabel="Time Step", ylabel="Metric Value"))
class GradientVisualize(Visualize):
    r"""This class provides the Visualizations for the Gradients.

    Args:
        visualize_list (list): List of the model names whose gradients are tracked. Must be
            non-empty.
        visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
            manually started at this port else an error will be thrown and the code will crash.
            This is ignored if ``VISDOM_LOGGING`` is ``0``.
        log_dir (str, optional): Directory where TensorboardX should store the logs. This is
            ignored if ``TENSORBOARD_LOGGING`` is ``0``.
        writer (tensorboardX.SummaryWriter, optional): Send a ``SummaryWriter`` if you don't want
            to start a new SummaryWriter.

    Raises:
        ValueError: If ``visualize_list`` is ``None`` or empty.
    """

    def __init__(self, visualize_list, visdom_port=8097, log_dir=None, writer=None):
        if visualize_list is None or len(visualize_list) == 0:
            # ValueError subclasses Exception, so existing ``except Exception`` callers still work.
            raise ValueError('Gradient Visualizer requires list of model names')
        # Reuse the base initializer for step / backend setup, then install per-model logs.
        super(GradientVisualize, self).__init__([], visdom_port=visdom_port, log_dir=log_dir,
                                                writer=writer)
        # Each model starts with one accumulator (0.0) for the current print interval.
        self.logs = {item: [0.0] for item in visualize_list}

    def log_tensorboard(self, name):
        r"""Tensorboard logging function. This function logs the current gradient accumulator
        of one model.

        Args:
            name (str): Name of the model whose gradients are to be logged.
        """
        self.writer.add_scalar('Gradients/{}'.format(name), self.logs[name][-1], self.step)

    def log_console(self, name):
        r"""Console logging function. This function prints the current gradient accumulator
        of one model.

        Args:
            name (str): Name of the model whose gradients are to be logged.
        """
        print('{} Gradients : {}'.format(name, self.logs[name][-1]))

    def log_visdom(self, name):
        r"""Visdom logging function. This function logs the current gradient accumulator of
        one model.

        Args:
            name (str): Name of the model whose gradients are to be logged.
        """
        self.vis.line([self.logs[name][-1]], [self.step], win=name, update="append",
                      opts=dict(title=name, xlabel="Time Step", ylabel="Gradient"))

    def update_grads(self, name, model, eps=1e-5):
        r"""Updates the gradient logs.

        Adds the sum of squared entries of every available parameter gradient to the model's
        current accumulator (only if that sum exceeds ``eps``), then calls
        ``model.zero_grad()``.

        Args:
            name (str): Name of the model.
            model (torch.nn.Module): Either a ``torchgan.models.Generator`` or a
                ``torchgan.models.Discriminator`` or their subclass.
            eps (float, optional): Tolerance value below which the sum is not recorded.
        """
        gradsum = 0.0
        for p in model.parameters():
            if p.grad is not None:
                gradsum += torch.sum(p.grad ** 2).item()
        if gradsum > eps:
            self.logs[name][-1] += gradsum
        model.zero_grad()

    def report_end_epoch(self):
        r"""Prints to the console, for each model, the mean of all its gradient accumulators."""
        for key, val in self.logs.items():
            print('{} Mean Gradients : {}'.format(key, sum(val) / len(val)))

    def __call__(self, trainer, **kwargs):
        r"""Logs the current accumulator of every model in ``trainer.model_names`` and then
        opens a fresh ``0.0`` accumulator for each.

        Args:
            trainer (torchgan.trainer.Trainer): The base trainer used for training.
        """
        for name in trainer.model_names:
            super(GradientVisualize, self).__call__(name, **kwargs)
            self.logs[name].append(0.0)
class ImageVisualize(Visualize):
    r"""This class provides the Logging for the Images.

    Args:
        trainer (torchgan.trainer.Trainer): The base trainer used for training.
        visdom_port (int, optional): Port to log using ``visdom``. The visdom server needs to be
            manually started at this port else an error will be thrown and the code will crash.
            This is ignored if ``VISDOM_LOGGING`` is ``0``.
        log_dir (str, optional): Directory where TensorboardX should store the logs. This is
            ignored if ``TENSORBOARD_LOGGING`` is ``0``.
        writer (tensorboardX.SummaryWriter, optional): Send a ``SummaryWriter`` if you don't want
            to start a new SummaryWriter.
        test_noise (torch.Tensor or sequence, optional): If provided then it will be used as the
            input for image sampling by every generator. A single tensor is passed to the
            generator as its only argument; a sequence is unpacked as positional arguments.
        nrow (int, optional): Number of images displayed in each row of the generated grid.
    """

    def __init__(self, trainer, visdom_port=8097, log_dir=None, writer=None, test_noise=None,
                 nrow=8):
        super(ImageVisualize, self).__init__([], visdom_port=visdom_port, log_dir=log_dir,
                                             writer=writer)
        # The generator is called as ``generator(*noise)``; splatting a bare tensor would split
        # it along its first (batch) dimension, so wrap it as a single argument.
        if isinstance(test_noise, torch.Tensor):
            test_noise = [test_noise]
        self.test_noise = []
        for model in trainer.model_names:
            if isinstance(getattr(trainer, model), Generator):
                if test_noise is None:
                    self.test_noise.append(
                        getattr(trainer, model).sampler(trainer.sample_size, trainer.device))
                else:
                    self.test_noise.append(test_noise)
        self.nrow = nrow

    def log_tensorboard(self, trainer, image, model):
        r"""Logs a generated image in tensorboard at the end of an epoch.

        Args:
            trainer (torchgan.trainer.Trainer): The base trainer used for training.
            image (Image): The generated image.
            model (str): The name of the model which generated the ``image``.
        """
        self.writer.add_image("Generated Samples/{}".format(model), image, self.step)

    def log_console(self, trainer, image, model):
        r"""Saves a generated image at the end of an epoch. The directory where the image is
        stored is taken from ``trainer.recon``.

        Args:
            trainer (torchgan.trainer.Trainer): The base trainer used for training.
            image (Image): The generated image.
            model (str): The name of the model which generated the ``image``.
        """
        save_path = "{}/epoch{}_{}.png".format(trainer.recon, self.step, model)
        print("Generating and Saving Images to {}".format(save_path))
        torchvision.utils.save_image(image, save_path, nrow=self.nrow)

    def log_visdom(self, trainer, image, model):
        r"""Logs a generated image in visdom at the end of an epoch.

        Args:
            trainer (torchgan.trainer.Trainer): The base trainer used for training.
            image (Image): The generated image.
            model (str): The name of the model which generated the ``image``.
        """
        self.vis.image(image, opts=dict(caption="Generated Samples/{}".format(model)))

    def __call__(self, trainer, **kwargs):
        r"""Samples an image grid from every generator of ``trainer`` and logs it.

        All generators are logged under the same step; the step advances once afterwards, and
        only if at least one generator was found.

        Args:
            trainer (torchgan.trainer.Trainer): The base trainer used for training.
        """
        pos = 0
        for model in trainer.model_names:
            generator = getattr(trainer, model)
            if not isinstance(generator, Generator):
                continue
            with torch.no_grad():
                image = generator(*self.test_noise[pos])
                # Honour the user's ``nrow`` when laying out the grid (previously ignored).
                image = torchvision.utils.make_grid(image, nrow=self.nrow)
                super(ImageVisualize, self).__call__(trainer, image, model, **kwargs)
            # The base ``__call__`` advanced the step; undo it so every generator of this
            # epoch shares one step.
            self.step -= 1
            pos += 1
        if pos > 0:
            self.step += 1