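For this demo we will not generate images directly. Instead, we prepare a set of 2-D "points" in advance and use them in place of images. We then train a GAN and check whether its Generator can fit the distribution of these prepared points. The points are drawn from an 8-Gaussian mixture distribution, but we pretend not to know that (just as we do not know what distribution images follow in high-dimensional space) and let the GAN learn the distribution.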
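For reference, the target distribution can be written out explicitly. This is only a sketch read off from the constants in the data-generating code of this post (scale 2, noise factor 0.02, final division by 1.414); the symbols $\mu_k$, $\theta_k$ and $\sigma$ are notation introduced here, not the author's.

```latex
p_{\text{data}}(x) = \frac{1}{8} \sum_{k=0}^{7} \mathcal{N}\!\left(x;\ \mu_k,\ \sigma^2 I_2\right),
\qquad
\mu_k = \frac{2}{1.414}\,\bigl(\cos\theta_k,\ \sin\theta_k\bigr),
\quad
\theta_k = \frac{k\pi}{4},
\quad
\sigma = \frac{0.02}{1.414} \approx 0.014
```

So the eight modes sit evenly on a circle of radius about 1.414, and each mode is very narrow compared with the distance between neighbouring modes.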
First, define two variables:

    h_dim = 400    # hidden-layer width
    batchsz = 512  # batch size

The data-generating code is below; these "points" play the role of real images:

    import random

    import numpy as np


    def data_generator():
        # 8-Gaussian mixture model
        scale = 2.
        centers = [
            (1, 0),
            (-1, 0),
            (0, 1),
            (0, -1),
            (1. / np.sqrt(2), 1. / np.sqrt(2)),
            (1. / np.sqrt(2), -1. / np.sqrt(2)),
            (-1. / np.sqrt(2), 1. / np.sqrt(2)),
            (-1. / np.sqrt(2), -1. / np.sqrt(2))
        ]
        centers = [(scale * x, scale * y) for x, y in centers]
        while True:
            dataset = []
            for i in range(batchsz):
                # small Gaussian noise around a randomly chosen center
                point = np.random.randn(2) * 0.02
                center = random.choice(centers)
                point[0] += center[0]
                point[1] += center[1]
                dataset.append(point)
            dataset = np.array(dataset, dtype='float32')
            dataset /= 1.414  # stdev
            yield dataset

Next, we stack a few layers to build a Generator and a Discriminator:

    from torch import nn


    class Generator(nn.Module):

        def __init__(self):
            super(Generator, self).__init__()
            self.net = nn.Sequential(
                # the 2 here is just the noise dimension and could be changed
                nn.Linear(2, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                # the "points" are 2-D, so the output must be 2-D
                nn.Linear(h_dim, 2),
            )

        def forward(self, z):
            output = self.net(z)
            return output


    class Discriminator(nn.Module):

        def __init__(self):
            super(Discriminator, self).__init__()
            self.net = nn.Sequential(
                nn.Linear(2, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, 1),
                nn.Sigmoid()
            )

        def forward(self, x):
            output = self.net(x)
            # flatten [batch, 1] scores to [batch]
            return output.view(-1)

First, we get the data generator:

    data_iter = data_generator()

Following the previous post (train the Discriminator more, the Generator less), we run five Discriminator updates for every Generator update. The code:

    import torch
    from torch import optim

    G = Generator().cuda()
    D = Discriminator().cuda()
    optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))

    for epoch in range(50000):
        # train Discriminator
        for _ in range(5):
            # ---------- real data ----------
            # fetch a batch of real data
            x = next(data_iter)
            xr = torch.from_numpy(x).cuda()
            # score it
            predr = D(xr)
            # push the scores of real data up
            lossr = - (predr.mean())

            # ---------- fake data ----------
            # noise
            z = torch.randn(batchsz, 2).cuda()
            # generated data; we are training the Discriminator here,
            # so detach to keep gradients out of the Generator
            xf = G(z).detach()
            # score it
            predf = (D(xf))
            # push the scores of generated data down
            lossf = (predf.mean())

            # ---------- Discriminator loss ----------
            loss_D = lossr + lossf

            # ---------- update parameters ----------
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # train Generator
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        # make the Discriminator give fake data a high score
        loss_G = - (predf.mean())
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            print(loss_D.item(), loss_G.item())

That completes the code demo of the most naive GAN. The next post covers the problems WGAN solves, along with the WGAN code.
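For reference, the two losses in the training loop above can be written down explicitly. They are linear in the Discriminator's sigmoid output rather than log-likelihoods. The notation ($\mathcal{L}_D$, $\mathcal{L}_G$, $p_{\text{data}}$, $p_z$) is introduced here for illustration, and the second pair of formulas is the commonly used non-saturating form of the original GAN losses, shown only for comparison, not something this demo implements.

```latex
% Losses as implemented in the loop above
\mathcal{L}_D = -\,\mathbb{E}_{x \sim p_{\text{data}}}\bigl[D(x)\bigr]
              + \mathbb{E}_{z \sim p_z}\bigl[D(G(z))\bigr],
\qquad
\mathcal{L}_G = -\,\mathbb{E}_{z \sim p_z}\bigl[D(G(z))\bigr]

% Original GAN losses (non-saturating generator loss), for comparison
\mathcal{L}_D^{\text{GAN}} = -\,\mathbb{E}_{x \sim p_{\text{data}}}\bigl[\log D(x)\bigr]
                           - \mathbb{E}_{z \sim p_z}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr],
\qquad
\mathcal{L}_G^{\text{GAN}} = -\,\mathbb{E}_{z \sim p_z}\bigl[\log D(G(z))\bigr]
```

In both versions the Discriminator is rewarded for scoring real points high and generated points low, and the Generator is rewarded when its points receive high scores.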

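To answer the question posed at the start (can the Generator fit the distribution of the prepared points?), one rough check is to count how many of the eight modes the generated points land on. The sketch below is not part of the original demo: the helper names `mixture_centers` and `mode_coverage` are hypothetical, and the `radius=0.1` threshold is an arbitrary assumption (several times the per-mode standard deviation of about 0.014).

```python
import numpy as np


def mixture_centers(scale=2.0, norm=1.414):
    """Centers of the 8 Gaussians after the same scaling as data_generator."""
    d = 1.0 / np.sqrt(2)
    base = np.array([(1, 0), (-1, 0), (0, 1), (0, -1),
                     (d, d), (d, -d), (-d, d), (-d, -d)], dtype=np.float32)
    return base * scale / norm


def mode_coverage(samples, radius=0.1):
    """Return (number of modes hit, fraction of samples near some center).

    `radius` is an assumed threshold, not a value from the original post.
    """
    centers = mixture_centers()
    # dists[i, k] = Euclidean distance from sample i to center k
    dists = np.linalg.norm(samples[:, None, :] - centers[None, :, :], axis=-1)
    nearest = dists.argmin(axis=1)
    close = dists.min(axis=1) < radius
    # count distinct centers that have at least one close sample
    modes_hit = len(np.unique(nearest[close]))
    return modes_hit, float(close.mean())
```

With the trained model from this post, the samples could be obtained with something like `G(torch.randn(1000, 2).cuda()).detach().cpu().numpy()` and passed to `mode_coverage`. A result well below 8 modes would suggest mode collapse.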