Train a GAN

Generative Adversarial Networks (GANs) can be used to generate data such as images and videos. This tutorial shows you how to train a GAN with MMEngine step by step.

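The tutorial is divided into the following steps: building a dataloader, building a generator network and a discriminator network, building a generative adversarial network model, building an optimizer, and training with Runner.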

Building a DataLoader

Building a Dataset

First, we build a dataset class, MNISTDataset, for the MNIST dataset. It inherits from the base dataset class BaseDataset and overrides load_data_list so that the return value is a list[dict], where each dict represents one data sample. For more details about using datasets in MMEngine, refer to the Dataset tutorial.

import numpy as np
from mmcv.transforms import to_tensor
from torch.utils.data import random_split
from torchvision.datasets import MNIST

from mmengine.dataset import BaseDataset

class MNISTDataset(BaseDataset):

    def __init__(self, data_root, pipeline, test_mode=False):
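        # Download the MNIST dataset. Note that, as written, test_mode=True
        # selects a 55,000-image subset of the training split, while the
        # default (test_mode=False) selects the 10,000-image test split.
        if test_mode:
            mnist_full = MNIST(data_root, train=True, download=True)
            self.mnist_dataset, _ = random_split(mnist_full, [55000, 5000])
        else:
            self.mnist_dataset = MNIST(data_root, train=False, download=True)

        super().__init__(
            data_root=data_root, pipeline=pipeline, test_mode=test_mode)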

    @staticmethod
    def totensor(img):
        # Add a channel axis to grayscale images, then convert HWC to CHW.
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)
        img = np.ascontiguousarray(img.transpose(2, 0, 1))
        return to_tensor(img)

    def load_data_list(self):
        return [
            dict(inputs=self.totensor(np.array(x[0]))) for x in self.mnist_dataset
        ]

dataset = MNISTDataset("./data", [])
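
The list[dict] contract can be illustrated without MMEngine or MNIST. The sketch below uses a hypothetical ToyDataset (not part of MMEngine) whose load_data_list returns one dict per sample, keyed by inputs as in MNISTDataset. The 2x2 nested lists are made-up stand-ins for images, and the eager loading in __init__ is a simplification of this sketch only.

```python
# Hypothetical stand-in for a BaseDataset subclass; it only mirrors the
# "load_data_list returns list[dict]" contract described above.
class ToyDataset:

    def __init__(self, images):
        self.images = images
        # In this sketch the sample list is built eagerly.
        self.data_list = self.load_data_list()

    def load_data_list(self):
        # One dict per sample, with the image stored under 'inputs'.
        return [dict(inputs=img) for img in self.images]

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        return self.data_list[idx]


# Two made-up 2x2 "images".
toy = ToyDataset([[[0, 255], [128, 64]], [[1, 2], [3, 4]]])
print(len(toy), toy[0])
```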

Use the function build_dataloader in Runner to build the dataloader.

import os
import torch
from mmengine.runner import Runner

NUM_WORKERS = int(os.cpu_count() / 2)
BATCH_SIZE = 256 if torch.cuda.is_available() else 64

train_dataloader = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dataset)
train_dataloader = Runner.build_dataloader(train_dataloader)

Building a Generator Network and a Discriminator Network

The following code builds and instantiates a Generator and a Discriminator.

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, noise_size, img_shape):
        super().__init__()
        self.img_shape = img_shape
        self.noise_size = noise_size

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(noise_size, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            # Tanh keeps outputs in [-1, 1], matching the normalized inputs.
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            # Sigmoid yields probabilities, as F.binary_cross_entropy requires.
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


generator = Generator(100, (1, 28, 28))
discriminator = Discriminator((1, 28, 28))

Build a Generative Adversarial Network Model

In MMEngine, we use ImgDataPreprocessor to normalize the data and convert the color channels.

from mmengine.model import ImgDataPreprocessor

data_preprocessor = ImgDataPreprocessor(mean=([127.5]), std=([127.5]))
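
With a mean and standard deviation of 127.5, the standardization (x - mean) / std maps 8-bit pixel values from [0, 255] to [-1, 1]. The sketch below reproduces that arithmetic for single pixel values in plain Python. The helper names normalize and denormalize are hypothetical and are not part of MMEngine.

```python
MEAN, STD = 127.5, 127.5


def normalize(pixel):
    # Same per-pixel arithmetic as standardization: (x - mean) / std.
    return (pixel - MEAN) / STD


def denormalize(value):
    # Inverse mapping, useful for turning model outputs back into pixels.
    return value * STD + MEAN


for pixel in (0, 127.5, 255):
    print(pixel, '->', normalize(pixel))
```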

The following code implements the basic GAN algorithm. To implement it with MMEngine, inherit from BaseModel and implement the training process in train_step. A GAN trains the generator and the discriminator alternately; these two steps are implemented by train_discriminator and train_generator, while disc_loss and gen_loss compute the discriminator and generator losses. For more details about BaseModel, refer to the Model tutorial.

import torch.nn.functional as F
from mmengine.model import BaseModel

class GAN(BaseModel):

    def __init__(self, generator, discriminator, noise_size,
                 data_preprocessor):
        super().__init__(data_preprocessor=data_preprocessor)
        assert generator.noise_size == noise_size
        self.generator = generator
        self.discriminator = discriminator
        self.noise_size = noise_size

    def train_step(self, data, optim_wrapper):
        # Acquiring and preprocessing data
        inputs_dict = self.data_preprocessor(data, True)
        # Training the discriminator
        disc_optimizer_wrapper = optim_wrapper['discriminator']
        with disc_optimizer_wrapper.optim_context(self.discriminator):
            log_vars = self.train_discriminator(inputs_dict,
                                                disc_optimizer_wrapper)

        # Training the generator
        set_requires_grad(self.discriminator, False)
        gen_optimizer_wrapper = optim_wrapper['generator']
        with gen_optimizer_wrapper.optim_context(self.generator):
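            log_vars_gen = self.train_generator(inputs_dict,
                                                gen_optimizer_wrapper)

        set_requires_grad(self.discriminator, True)
        # Merge the generator logs into the discriminator logs.
        log_vars.update(log_vars_gen)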

        return log_vars

    def forward(self, batch_inputs, data_samples=None, mode=None):
        return self.generator(batch_inputs)

    def disc_loss(self, disc_pred_fake, disc_pred_real):
        losses_dict = dict()
        losses_dict['loss_disc_fake'] = F.binary_cross_entropy(
            disc_pred_fake, 0. * torch.ones_like(disc_pred_fake))
        losses_dict['loss_disc_real'] = F.binary_cross_entropy(
            disc_pred_real, 1. * torch.ones_like(disc_pred_real))

        loss, log_var = self.parse_losses(losses_dict)
        return loss, log_var

    def gen_loss(self, disc_pred_fake):
        losses_dict = dict()
        losses_dict['loss_gen'] = F.binary_cross_entropy(
            disc_pred_fake, 1. * torch.ones_like(disc_pred_fake))
        loss, log_var = self.parse_losses(losses_dict)
        return loss, log_var

    def train_discriminator(self, inputs, optimizer_wrapper):
        real_imgs = inputs['inputs']
        z = torch.randn(
            (real_imgs.shape[0], self.noise_size)).type_as(real_imgs)
        with torch.no_grad():
            fake_imgs = self.generator(z)

        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)

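        parsed_losses, log_vars = self.disc_loss(disc_pred_fake,
                                                 disc_pred_real)
        # Backpropagate and update the discriminator's parameters.
        optimizer_wrapper.update_params(parsed_losses)
        return log_vars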

    def train_generator(self, inputs, optimizer_wrapper):
        real_imgs = inputs['inputs']
        z = torch.randn(real_imgs.shape[0], self.noise_size).type_as(real_imgs)

        fake_imgs = self.generator(z)

        disc_pred_fake = self.discriminator(fake_imgs)
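        parsed_loss, log_vars = self.gen_loss(disc_pred_fake)

        # Backpropagate and update the generator's parameters.
        optimizer_wrapper.update_params(parsed_loss)
        return log_vars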
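
The three loss terms above are all binary cross-entropy against a constant target: 0 for fake and 1 for real images in disc_loss, and 1 for fake images in gen_loss. The plain-Python sketch below works through that arithmetic on made-up discriminator outputs. The helpers bce and mean_bce are hypothetical; mean_bce averages over the batch, which matches the default mean reduction of F.binary_cross_entropy.

```python
import math


def bce(pred, target):
    """Binary cross-entropy for a single prediction in (0, 1)."""
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))


def mean_bce(preds, target):
    # Average over the batch.
    return sum(bce(p, target) for p in preds) / len(preds)


disc_pred_real = [0.9, 0.8]  # made-up discriminator outputs on real images
disc_pred_fake = [0.2, 0.1]  # made-up discriminator outputs on fake images

loss_disc_real = mean_bce(disc_pred_real, 1.0)  # real images labeled 1
loss_disc_fake = mean_bce(disc_pred_fake, 0.0)  # fake images labeled 0
loss_gen = mean_bce(disc_pred_fake, 1.0)        # generator wants label 1

print(loss_disc_real, loss_disc_fake, loss_gen)
```

In this example the discriminator is doing well, so both of its loss terms are small, while the generator loss is large because its samples are confidently rejected.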

The function set_requires_grad freezes the weights of the discriminator while the generator is being trained.

def set_requires_grad(nets, requires_grad=False):
    """Set requires_grad for all the networks.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single
            network.
        requires_grad (bool): Whether the networks require gradients or not.
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad
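
The effect of freezing can be checked on a toy pair of networks. In this sketch, gen and disc are hypothetical single-layer stand-ins, not the tutorial's models. With the discriminator's parameters frozen, a backward pass still sends gradients through it to the generator, but the discriminator accumulates no gradients of its own.

```python
import torch
import torch.nn as nn

# Hypothetical single-layer stand-ins for the generator and discriminator.
gen = nn.Linear(2, 4)
disc = nn.Linear(4, 1)

# Freeze the discriminator, as done before the generator update.
for param in disc.parameters():
    param.requires_grad = False

# Backpropagate a dummy loss through disc(gen(z)).
z = torch.randn(3, 2)
loss = disc(gen(z)).sum()
loss.backward()

disc_has_no_grads = all(p.grad is None for p in disc.parameters())
gen_has_grads = all(p.grad is not None for p in gen.parameters())
print(disc_has_no_grads, gen_has_grads)
```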

model = GAN(generator, discriminator, 100, data_preprocessor)

Building an Optimizer

MMEngine uses OptimWrapper to wrap optimizers. When a model needs multiple optimizers, OptimWrapperDict wraps the individual OptimWrapper instances. For more details about optimizers, refer to the Optimizer tutorial.

from mmengine.optim import OptimWrapper, OptimWrapperDict

opt_g = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
opt_g_wrapper = OptimWrapper(opt_g)

opt_d = torch.optim.Adam(
    discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
opt_d_wrapper = OptimWrapper(opt_d)

opt_wrapper_dict = OptimWrapperDict(
    generator=opt_g_wrapper, discriminator=opt_d_wrapper)
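
For a single optimizer, OptimWrapper.update_params(loss) roughly bundles the usual backward pass, optimizer step, and gradient reset. The sketch below spells those three calls out in plain PyTorch on a made-up one-parameter problem. It uses SGD rather than the Adam optimizers above so that the arithmetic is easy to follow, and it ignores extras such as gradient accumulation.

```python
import torch

# Made-up one-parameter problem: minimize w ** 2 starting from w = 1.
w = torch.nn.Parameter(torch.tensor([1.0]))
optimizer = torch.optim.SGD([w], lr=0.1)

loss = (w ** 2).sum()
loss.backward()        # gradient of w ** 2 at w = 1 is 2
optimizer.step()       # w <- w - lr * grad = 1 - 0.1 * 2
optimizer.zero_grad()  # clear the gradient before the next iteration

print(w.item())
```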

Training with Runner

The following code demonstrates how to use Runner for model training. For more details about Runner, refer to the Runner tutorial.

train_cfg = dict(by_epoch=True, max_epochs=220)
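runner = Runner(
    model,
    work_dir='./work_dir',  # placeholder output directory; choose any path
    train_dataloader=train_dataloader,
    train_cfg=train_cfg,
    optim_wrapper=opt_wrapper_dict)
runner.train()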

At this point, we have completed an example of training a GAN. The trained generator can now be sampled to view the results it produces.
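
One way to inspect the result is to sample noise vectors and pass them through the generator. The sketch below is an assumption rather than the tutorial's own code: toy_generator is an untrained, hypothetical stand-in so that the snippet runs by itself, and it assumes generator outputs lie in [-1, 1], as the normalization with mean 127.5 and standard deviation 127.5 suggests. With the trained model from above, model(z) would take its place.

```python
import torch
import torch.nn as nn

NOISE_SIZE = 100

# Hypothetical untrained stand-in; replace with the trained model.
toy_generator = nn.Sequential(nn.Linear(NOISE_SIZE, 28 * 28), nn.Tanh())
toy_generator.eval()

with torch.no_grad():
    # Sample 64 noise vectors and reshape the outputs to image tensors.
    z = torch.randn(64, NOISE_SIZE)
    img = toy_generator(z).view(64, 1, 28, 28)

# Map values from [-1, 1] back to 8-bit pixel intensities in [0, 255].
img_uint8 = ((img + 1) / 2 * 255).round().clamp(0, 255).to(torch.uint8)
print(img_uint8.shape, img_uint8.dtype)
```

To write the samples to disk as an image grid, torchvision.utils.save_image(img, 'samples.png', normalize=True) is one option; the file name here is a placeholder. If the model was trained on a GPU, z needs to be moved to the same device first.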

[Figure: an image generated by the GAN]

If you want to learn more about implementing GANs and other generative models with MMEngine, we highly recommend trying MMGeneration, a generative framework based on MMEngine.
