In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from IPython import display

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG (torch.manual_seed seeds CPU and all CUDA devices) so weight
# init, DataLoader shuffling and latent sampling repeat across runs. Some CUDA kernels
# may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

In [None]:
batchsize = 100

# MNIST training split. ToTensor yields values in [0, 1]; Normalize with mean=std=0.5
# rescales them to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    transform=transform,
    download=True,
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batchsize,
    shuffle=True,
)

In [None]:
class Generator(nn.Module):
    """MLP generator mapping a latent vector to a flattened image.

    Layer widths are 256 -> 512 -> 1024 -> g_output_dim, with LeakyReLU(0.2)
    between hidden layers and tanh on the output, so values lie in [-1, 1].

    Args:
        g_input_dim: size of the latent noise vector.
        g_output_dim: size of the flattened generated image.
    """

    def __init__(self, g_input_dim, g_output_dim):
        super().__init__()
        width = 256
        # Layers are created in the same order (and with the same names) as
        # before, so parameter init and state_dict keys are unchanged.
        self.fc1 = nn.Linear(g_input_dim, width)
        self.fc2 = nn.Linear(width, width * 2)
        self.fc3 = nn.Linear(width * 2, width * 4)
        self.fc4 = nn.Linear(width * 4, g_output_dim)

    def forward(self, x):
        """Map latent input ``x`` to an image tensor squashed into [-1, 1]."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(hidden(x), 0.2)
        return torch.tanh(self.fc4(x))
    
class Discriminator(nn.Module):
    """MLP discriminator mapping a flattened image to the probability it is real.

    Layer widths are 1024 -> 512 -> 256 -> 1, with LeakyReLU(0.2) followed by
    dropout after each hidden layer and a sigmoid on the output.

    Args:
        d_input_dim: size of the flattened input image (e.g. 28*28 for MNIST).
        dropout_p: dropout probability applied after each hidden layer
            (default 0.3, the value previously hard-coded).
    """

    def __init__(self, d_input_dim, dropout_p=0.3):
        super(Discriminator, self).__init__()
        self.dropout_p = dropout_p
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    def forward(self, x):
        """For ``x`` of shape (batch, d_input_dim), return (batch, 1) probabilities in (0, 1)."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(hidden(x), 0.2)
            # F.dropout defaults to training=True; tie it to the module's mode so
            # D.eval() actually disables dropout and gives deterministic outputs.
            x = F.dropout(x, self.dropout_p, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [None]:

# build network
z_dim = 100
# Flattened image size. `MNIST.data` holds the raw images (indexed here as N x H x W);
# the old `train_data` name is a deprecated alias that emits a warning.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)

# Binary cross-entropy on D's sigmoid output: real -> 1, fake -> 0.
criterion = nn.BCELoss()

# optimizer (same learning rate for both networks)
lr = 0.0001
G_optimizer = optim.Adam(G.parameters(), lr = lr)
D_optimizer = optim.Adam(D.parameters(), lr = lr)

In [None]:
%matplotlib inline

n_epoch = 200
for epoch in range(1, n_epoch+1):           

    for batch_idx, (x, _) in enumerate(train_loader):
        D.zero_grad()

        # train discriminator on real
        x_real, y_real = x.view(-1, mnist_dim), torch.ones(batchsize, 1)
        x_real, y_real = x_real.to(device), y_real.to(device)

        D_output = D(x_real)
        D_real_loss = criterion(D_output, y_real)

        # train discriminator on facke
        z = torch.randn(batchsize, z_dim).to(device)
        x_fake, y_fake = G(z), torch.zeros(batchsize, 1).to(device)

        D_output = D(x_fake)
        D_fake_loss = criterion(D_output, y_fake)

        
        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_optimizer.step()    
        
        
        
        
        G.zero_grad()

        z = torch.randn(batchsize, z_dim).to(device)
        y = torch.ones(batchsize, 1).to(device)

        G_output = G(z)
        D_output = D(G_output)
        G_loss = criterion(D_output, y)

        # gradient backprop & optimize ONLY G's parameters
        G_loss.backward()
        G_optimizer.step()
        
        if (batch_idx+1)%100 == 0:
            z = torch.randn(1, z_dim).to(device)
            G_output = G(z).view(28,28)

            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.imshow(G_output.detach().cpu().numpy(), cmap='gray')
            display.display(fig)  

            print('Iterations: {}'.format(batch_idx + 1 + (epoch-1)*600))