@rexlow · Created February 25, 2021 11:48
A quick and dirty implementation of a GAN in PyTorch that approximates a gaussian distribution

#!/usr/bin/python3
# Sample GAN implementation for learning purposes only
# The network models two datasets drawn from different distributions and
# trains the Generator to better approximate the target (real) distribution
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
# load ui libs
seaborn_available, matplotlib_available = True, True
try:
    import seaborn as sns
except ImportError:
    seaborn_available = False
try:
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    mpl.use('tkagg')  # fix possible segfault in macos
except ImportError:
    matplotlib_available = False
# cuda
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# target distribution
mean, stddev = 4, 1.25
# hyperparameters
g_in_size = 1
g_hidden_size = 10 # generator complexity
g_out_size = 1
d_in_size = 500
d_hidden_size = 20 # discriminator complexity
d_out_size = 1 # binary output, either true or false
minibatch_size = d_in_size
g_learning_rate = 1e-3
d_learning_rate = 1e-3
sgd_momentum = 0.9
num_epochs = 10000
num_steps = 20
print_interval = 100
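# minibatch_size is tied to d_in_size: the Discriminator below consumes an
# entire batch of 500 samples as a single 500-dimensional input vector
# rather than scoring samples one at a time.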
# activation functions
g_act = torch.tanh
d_act = torch.sigmoid
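# tanh keeps the Generator's hidden activations zero-centered; sigmoid squashes
# the Discriminator's output into (0, 1) so BCELoss can treat it as a probability.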
# network
class Generator(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, f):
        super(Generator, self).__init__()
        self.m1 = nn.Linear(in_features, hidden_features)
        self.m2 = nn.Linear(hidden_features, hidden_features)
        self.m3 = nn.Linear(hidden_features, out_features)
        self.f = f

    def forward(self, x):
        x = self.f(self.m1(x))
        x = self.f(self.m2(x))
        return self.m3(x)
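
# Shape sketch: with the hyperparameters above, G maps (minibatch_size, 1) noise
# to (minibatch_size, 1) samples, e.g. G(torch.rand(500, 1)) -> shape (500, 1).
# The final layer has no activation, so outputs can take any real value,
# matching the unbounded gaussian target.
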
class Discriminator(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, f):
        super(Discriminator, self).__init__()
        self.m1 = nn.Linear(in_features, hidden_features)
        self.m2 = nn.Linear(hidden_features, hidden_features)
        self.m3 = nn.Linear(hidden_features, out_features)
        self.f = f

    def forward(self, x):
        x = self.f(self.m1(x))
        x = self.f(self.m2(x))
        return self.f(self.m3(x))
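
# Because d_in_size == 500, D sees a whole sample set at once and emits a single
# real/fake probability per set; judging entire batches (rather than individual
# samples) can make it easier for D to notice when G collapses to a narrow
# range of values.
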
def get_gaussian_distribution(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))

def get_uniform_distribution():
    return lambda m, n: torch.rand(m, n)

def tensor_to_list(v):
    return v.data.storage().tolist()

def stats(d):
    return [round(np.mean(d), 4), round(np.std(d), 4)]

def histplot(data, label="", bins=25):
    # plotting needs both seaborn (histplot) and matplotlib (legend/show)
    if seaborn_available and matplotlib_available:
        sns.histplot(data, label=label, bins=bins)
        plt.legend()
        plt.show()
    else:
        print("Seaborn/Matplotlib not available")
if __name__ == "__main__":
    D_sampler = get_gaussian_distribution(mean, stddev)  # (1, 500)
    G_sampler = get_uniform_distribution()  # (500, 1)

    # Uncomment to visualize the input distributions
    # if seaborn_available:
    #     histplot(D_sampler(d_in_size).numpy()[0], label="Example Discriminator Input")
    #     histplot(G_sampler(minibatch_size, g_in_size).numpy(), label="Example Generator Input")

    G = Generator(g_in_size, g_hidden_size, g_out_size, g_act)
    D = Discriminator(d_in_size, d_hidden_size, d_out_size, d_act)
    G = G.to(device)
    D = D.to(device)
    # error rates
    D_real_err_rate, D_fake_err_rate, G_err_rate = 0, 0, 0

    # loss & optimizers
    criterion = nn.BCELoss()
    G_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)
    D_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
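
    # Alternating scheme: per epoch, run num_steps Discriminator updates (real
    # batches labeled 1, generated batches labeled 0), then num_steps Generator
    # updates against the fixed Discriminator, labeling fakes as 1 (the common
    # non-saturating trick instead of directly minimizing log(1 - D(G(z)))).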
    for epoch in range(num_epochs):
        for d_step in range(num_steps):
            D.zero_grad()

            # train real input on Discriminator
            D_real_x = D_sampler(d_in_size).to(device)
            D_real_y = D(D_real_x)
            D_real_err = criterion(D_real_y, torch.ones([1, 1], device=device))  # 1 - true
            D_real_err.backward()

            # train fake input on Discriminator; detach() keeps these gradients
            # out of the Generator while the Discriminator is being updated
            D_fake_input = G_sampler(d_in_size, g_in_size).to(device)
            D_fake_x = G(D_fake_input).detach()
            D_fake_y = D(D_fake_x.t())
            D_fake_err = criterion(D_fake_y, torch.zeros([1, 1], device=device))  # 0 - fake
            D_fake_err.backward()
            D_optimizer.step()

            # update discriminator error rates
            D_real_err_rate = tensor_to_list(D_real_err)[0]
            D_fake_err_rate = tensor_to_list(D_fake_err)[0]
        # train Generator with the Discriminator's responses, but do not update Discriminator
        for g_step in range(num_steps):
            G.zero_grad()
            G_input = G_sampler(minibatch_size, g_in_size).to(device)
            G_fake_x = G(G_input)
            G_fake_y = D(G_fake_x.t())
            G_err = criterion(G_fake_y, torch.ones([1, 1], device=device))  # pretend the fakes are real
            G_err.backward()
            G_optimizer.step()
            G_err_rate = tensor_to_list(G_err)[0]
        if epoch % print_interval == 0:
            print(f"Epoch: {epoch} Loss (DRE, DFE, GE): {round(D_real_err_rate, 4)} {round(D_fake_err_rate, 4)} {round(G_err_rate, 4)} Dist (Real | Fake): {stats(tensor_to_list(D_real_x))} {stats(tensor_to_list(D_fake_x))}")