!wget https://raw.githubusercontent.com/diegoalejogm/gans/master/utils.py
--2024-02-08 19:24:45-- https://raw.githubusercontent.com/diegoalejogm/gans/master/utils.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 4866 (4.8K) [text/plain] Saving to: 'utils.py' utils.py 100%[===================>] 4.75K --.-KB/s in 0s 2024-02-08 19:24:45 (27.2 MB/s) - 'utils.py' saved [4866/4866]
!pip install tensorboardX
Requirement already satisfied: tensorboardX in /opt/conda/lib/python3.10/site-packages (2.6.2.2) Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from tensorboardX) (1.24.4) Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from tensorboardX) (21.3) Requirement already satisfied: protobuf>=3.20 in /opt/conda/lib/python3.10/site-packages (from tensorboardX) (3.20.3) Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging->tensorboardX) (3.1.1)
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from utils import Logger
# Select the training device once; every tensor/model below is moved to it.
# Falls back to CPU when no CUDA GPU is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cpu')
## loading mnist data
# Pixels are scaled to [-1, 1] so real images match the generator's Tanh
# output range. Mean/std are passed as 1-element sequences (one per
# channel), the form torchvision's Normalize documents for grayscale.
mnist = datasets.MNIST('./data', download=True, train=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,)),
                       ]))
# NOTE(review): the original `mnist.data.to(device)` was a no-op — `.to`
# returns a new tensor that was discarded, and `.data` bypasses the
# transform pipeline anyway — so it has been removed. Batches are moved
# to `device` in the training loop instead.
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 216176490.82it/s]
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 39038251.31it/s]
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 60019712.36it/s]
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 11166781.22it/s]
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
# Mini-batches of 100, reshuffled every epoch for SGD.
data_loader = DataLoader(mnist, batch_size=100, shuffle=True)
# Batches per epoch — used by the logger to index progress.
num_batches = len(data_loader)
class GeneratorNet(torch.nn.Module):
    """Fully connected generator with three hidden layers.

    Maps a 100-d noise vector to a 784-d (flattened 28x28) image whose
    values lie in [-1, 1] thanks to the final Tanh.
    """

    def __init__(self):
        super().__init__()
        noise_dim = 100
        image_dim = 784

        # Widths grow 256 -> 512 -> 1024; each hidden layer uses LeakyReLU.
        self.hidden0 = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
        )
        # Tanh squashes the output into the normalized pixel range.
        self.out = nn.Sequential(
            nn.Linear(1024, image_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        # Feed the activation through each stage in order.
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = stage(x)
        return x
# Instantiate the generator and move its parameters to the training device.
generator = GeneratorNet().to(device)
class DiscriminatorNet(torch.nn.Module):
    """Fully connected discriminator with three hidden layers.

    Takes a flattened 28x28 image (784 values) and returns a single
    probability in (0, 1) that the image is real (Sigmoid output).
    """

    def __init__(self):
        super().__init__()
        image_dim = 784
        score_dim = 1

        # Widths shrink 1024 -> 512 -> 256, mirroring the generator;
        # dropout regularizes each hidden layer.
        self.hidden0 = nn.Sequential(
            nn.Linear(image_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        # Sigmoid maps the final score to a probability for BCELoss.
        self.out = nn.Sequential(
            torch.nn.Linear(256, score_dim),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        # Feed the activation through each stage in order.
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = stage(x)
        return x
# Instantiate the discriminator and move its parameters to the training device.
discriminator = DiscriminatorNet().to(device)
# Binary cross-entropy: both networks are trained against 0/1 labels.
loss = nn.BCELoss()
# One Adam optimizer per network — standard GAN setup so the two updates
# never touch each other's parameters.
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
def train_generator(optimizer, fake_data):
    """One generator update: push D(fake) toward the "real" label (1).

    `fake_data` must NOT be detached here — gradients have to flow back
    through the generator that produced it.
    """
    batch_size = fake_data.size(0)
    optimizer.zero_grad()
    # Score the generated samples with the current discriminator.
    prediction = discriminator(fake_data)
    # The generator's loss is low when D is fooled into predicting 1.
    target_real = torch.ones(batch_size, 1).to(device)
    error = loss(prediction, target_real)
    error.backward()
    # Apply the generator gradients (optimizer holds generator params).
    optimizer.step()
    return error
def train_discriminator(optimizer, real_data, fake_data):
    """One discriminator update on a real batch and a fake batch.

    Returns the combined error plus the raw predictions on each batch
    (used by the logger to display D's accuracy).
    """
    batch_size = real_data.size(0)
    optimizer.zero_grad()

    # Real samples should score 1.
    prediction_real = discriminator(real_data)
    target_real = torch.ones(batch_size, 1).to(device)
    error_real = loss(prediction_real, target_real)
    error_real.backward()

    # Fake samples should score 0 (fake_data arrives detached, so no
    # gradient reaches the generator here).
    prediction_fake = discriminator(fake_data)
    target_fake = torch.zeros(batch_size, 1).to(device)
    error_fake = loss(prediction_fake, target_fake)
    error_fake.backward()

    # Single step over the accumulated gradients from both passes.
    optimizer.step()

    return error_real + error_fake, prediction_real, prediction_fake
# Create logger instance (writes TensorBoard scalars and sample-image grids).
logger = Logger(model_name='VGAN', data_name='MNIST')
# Total number of epochs to train
num_epochs = 200
num_test_samples = 16

for epoch in range(num_epochs):
    for n_batch, (real_batch, _) in enumerate(data_loader):
        N = real_batch.size(0)

        # 1. Train Discriminator
        # Flatten images to 784-vectors AND move them to the model's
        # device — the original omitted .to(device), which fails as soon
        # as CUDA is available (models/noise on GPU, real batch on CPU).
        real_data = real_batch.view(N, 784).to(device)
        # Generate fake data and detach so this pass leaves the
        # generator's gradients untouched.
        fake_data = generator(torch.randn(N, 100).to(device)).detach()
        d_error, d_pred_real, d_pred_fake = \
            train_discriminator(d_optimizer, real_data, fake_data)

        # 2. Train Generator on a fresh batch of noise (not detached).
        fake_data = generator(torch.randn(N, 100).to(device))
        g_error = train_generator(g_optimizer, fake_data)

        # Log batch error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)

        # Display progress every 100 batches.
        if n_batch % 100 == 0:
            # Inference only — no autograd graph needed for sample images.
            with torch.no_grad():
                gen_output = generator(
                    torch.randn(num_test_samples, 100).to(device))
            test_images = gen_output.view(-1, 1, 28, 28)
            logger.log_images(
                test_images.cpu(), num_test_samples,
                epoch, n_batch, num_batches
            )
            # Display status logs
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )
from PIL import Image
import os

# Directory where Logger.log_images wrote the generated-sample grids.
directory = 'data/images/VGAN/MNIST'

# Gather the paths of every logged sample image, in training order:
# epochs 0..200, snapshots taken every 100 batches.
image_paths = []
for epoch in range(201):
    for batch in [0, 100, 200, 300, 400, 500]:
        candidate = os.path.join(
            directory, f'_epoch_{epoch}_batch_{batch}.png')
        # Not every (epoch, batch) pair exists — keep only real files.
        if os.path.exists(candidate):
            image_paths.append(candidate)

# Load each image eagerly; copying inside the `with` lets the file
# handle close immediately instead of staying open per image.
images = []
for img_path in image_paths:
    with Image.open(img_path) as img:
        images.append(img.copy())
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Flip-book of generator progress: one frame per 50th logged snapshot,
# shown for 1 s each.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
frames = [[plt.imshow(frame, animated=True)] for frame in images[::50]]
ani = animation.ArtistAnimation(
    fig, frames, interval=1000, repeat_delay=1000, blit=True)
# Close the figure so only the JS animation (not a static image) renders.
plt.close(fig)
HTML(ani.to_jshtml())
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Saved loss-curve screenshots, displayed side by side for comparison.
error_curves = [
    ('generator', mpimg.imread("generator_error.png")),
    ('discriminator', mpimg.imread("discriminator_error.png")),
]

fig, axes = plt.subplots(1, 2, figsize=(10, 5))  # Adjust figsize as needed
for ax, (title, img) in zip(axes, error_curves):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')  # Remove axis — these are images, not plots

# Adjust layout and show.
plt.tight_layout()
plt.show()