In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from IPython import display
from torch import autograd

# Run on the GPU when one is available; the models and tensors below are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNGs (CPU and CUDA) so weight initialisation, latent noise,
# DataLoader shuffling and dropout masks are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Mini-batch size shared by the data loader and the training loop.
batchsize = 100

# MNIST preprocessing: ToTensor scales pixels to [0, 1], then Normalize with
# mean 0.5 / std 0.5 shifts them to [-1, 1], matching the generator's tanh range.
_mnist_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(_mnist_steps)

train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    transform=transform,
    download=True,
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batchsize,
    shuffle=True,
)

In [None]:
class Generator(nn.Module):
    """MLP generator mapping a latent vector to a flattened image.

    Hidden widths double layer by layer (256 -> 512 -> 1024) before the final
    projection to ``g_output_dim``; the output passes through tanh, so it lies
    in [-1, 1].

    Args:
        g_input_dim: size of the latent vector.
        g_output_dim: size of the flattened output.
    """

    def __init__(self, g_input_dim, g_output_dim):
        super().__init__()
        width1 = 256
        width2 = width1 * 2
        width3 = width2 * 2
        # Layers are created in this exact order so parameter initialisation
        # consumes the RNG the same way and state_dict keys remain fc1..fc4.
        self.fc1 = nn.Linear(g_input_dim, width1)
        self.fc2 = nn.Linear(width1, width2)
        self.fc3 = nn.Linear(width2, width3)
        self.fc4 = nn.Linear(width3, g_output_dim)

    def forward(self, x):
        """Apply fc1..fc3 each followed by LeakyReLU(0.2), then tanh(fc4(.))."""
        h = x
        for hidden in (self.fc1, self.fc2, self.fc3):
            h = F.leaky_relu(hidden(h), negative_slope=0.2)
        return torch.tanh(self.fc4(h))
    
class Critic(nn.Module):
    """MLP critic for WGAN-GP: maps a flattened sample to one unbounded score.

    Hidden widths halve layer by layer (1024 -> 512 -> 256), each followed by
    LeakyReLU(0.2) and dropout; ``fc4`` has no output activation, so raw scores
    are returned.

    Args:
        d_input_dim: size of the flattened input.
        dropout_p: dropout probability applied after each hidden layer
            (default 0.3).
    """

    def __init__(self, d_input_dim, dropout_p=0.3):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)
        self.dropout_p = dropout_p

    def forward(self, x):
        """Return critic scores with a trailing dimension of size 1."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(hidden(x), 0.2)
            # F.dropout defaults to training=True, which would keep dropping
            # units even after C.eval(); follow the module's train/eval mode.
            x = F.dropout(x, self.dropout_p, training=self.training)
        return self.fc4(x)

In [None]:
def calc_gradient_penalty(netD, real_data, fake_data, LAMBDA = 10.0):
    """WGAN-GP gradient penalty (Gulrajani et al., 2017).

    Samples one random point on the segment between each paired real and fake
    sample and penalises the squared deviation of the critic's input-gradient
    L2 norm from 1.

    Args:
        netD: critic, called on a tensor shaped like ``real_data``; presumably
            returns one score per sample.
        real_data: batch of real samples, shape (B, ...).
        fake_data: batch of generated samples, same shape as ``real_data``.
        LAMBDA: penalty weight.

    Returns:
        Scalar tensor ``LAMBDA * mean_i((||grad_i||_2 - 1) ** 2)``. It is built
        with ``create_graph=True`` so backpropagating it reaches netD's params.

    Raises:
        ValueError: if ``real_data`` and ``fake_data`` have different shapes.
    """
    if real_data.shape != fake_data.shape:
        raise ValueError(
            'real_data and fake_data must have the same shape, got {} and {}'.format(
                tuple(real_data.shape), tuple(fake_data.shape)))

    # Take batch size, device and dtype from the data itself (not globals) so
    # short batches and other devices work.
    batch_size = real_data.size(0)
    # One mixing coefficient per sample, shaped (B, 1, ..., 1) so it broadcasts
    # over every non-batch dimension regardless of input rank.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device, dtype=real_data.dtype)

    interpolates = alpha * real_data + (1 - alpha) * fake_data
    # Detach from G's graph (the penalty only trains the critic) and make the
    # mix a leaf we can differentiate with respect to.
    interpolates = interpolates.detach().requires_grad_(True)
    disc_interpolates = netD(interpolates)

    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(batch_size, -1)

    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA




# build network
# Dimension of the latent noise vector fed to the generator.
z_dim = 100
# Flattened image size (28 * 28 for MNIST). `.data` replaces the deprecated
# `train_data` alias, which only emits a warning and forwards to `.data`.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
C = Critic(mnist_dim).to(device)

# NOTE(review): unused — the WGAN-GP losses in the training loop work on raw
# critic scores, not probabilities. Kept only in case other cells refer to it.
criterion = nn.BCELoss()

# optimizer
lr = 0.0001
# NOTE(review): the WGAN-GP paper uses Adam with betas=(0.0, 0.9); PyTorch's
# defaults (0.9, 0.999) are kept here to preserve the current training setup.
G_optimizer = optim.Adam(G.parameters(), lr = lr)
C_optimizer = optim.Adam(C.parameters(), lr = lr)

In [None]:
%matplotlib inline

# Training hyperparameters.
n_epoch = 200
n_critic = 5        # critic updates per generator update (n_critic in WGAN-GP)
sample_every = 100  # show a generated sample every this many batches
iters_per_epoch = len(train_loader)

for epoch in range(1, n_epoch + 1):
    for batch_idx, (x, _) in enumerate(train_loader):
        # Flatten and move the real batch once; it is reused for every critic step.
        x_real = x.view(-1, mnist_dim).to(device)
        # Actual batch size, so generated batches always pair 1:1 with real ones.
        cur_batch = x_real.size(0)

        # ---- Critic: n_critic steps minimising E[C(fake)] - E[C(real)] + GP ----
        for _ in range(n_critic):
            C.zero_grad()

            # Critic on real samples.
            C_real_loss = C(x_real)

            # Critic on fake samples. Detached: this step only updates C, so
            # backpropagating into G would be wasted work.
            z = torch.randn(cur_batch, z_dim, device=device)
            x_fake = G(z).detach()
            C_fake_loss = C(x_fake)

            gp = calc_gradient_penalty(C, x_real, x_fake)

            C_loss = C_fake_loss.mean() - C_real_loss.mean() + gp
            C_loss.backward()
            C_optimizer.step()

        # ---- Generator: maximise E[C(G(z))] ----
        G.zero_grad()
        z = torch.randn(cur_batch, z_dim, device=device)
        G_loss = -C(G(z)).mean()
        # This backward also fills C's .grad, but only G_optimizer steps here and
        # C.zero_grad() clears C's gradients before its next update.
        G_loss.backward()
        G_optimizer.step()

        if (batch_idx + 1) % sample_every == 0:
            with torch.no_grad():
                z = torch.randn(1, z_dim, device=device)
                sample = G(z).view(28, 28).cpu().numpy()

            fig, ax = plt.subplots()
            ax.imshow(sample, cmap='gray')
            ax.set_title('Epoch {}, batch {}'.format(epoch, batch_idx + 1))
            display.display(fig)
            # Close the figure: otherwise every sample figure stays open
            # (hundreds over the full run) and the inline backend re-renders
            # them all when the cell finishes.
            plt.close(fig)

            print('Iterations: {}'.format(batch_idx + 1 + (epoch - 1) * iters_per_epoch))