Skip to content

Commit 4383899

Browse files
authored
Merge pull request #1351 from xiezl/patch-3
Add the generative model for the peft example
2 parents 33ad565 + 00a836b commit 4383899

1 file changed

Lines changed: 175 additions & 0 deletions

File tree

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
from singa import device
21+
from singa import opt
22+
from singa import tensor
23+
24+
import argparse
25+
import matplotlib.pyplot as plt
26+
import numpy as np
27+
import os
28+
from model import gan_mlp
29+
from utils import load_data
30+
from utils import print_log
31+
32+
33+
class VANILLA():
34+
35+
def __init__(self,
36+
dev,
37+
rows=28,
38+
cols=28,
39+
channels=1,
40+
noise_size=100,
41+
hidden_size=128,
42+
batch=128,
43+
interval=1000,
44+
learning_rate=0.001,
45+
iterations=1000000,
46+
dataset_filepath='mnist.pkl.gz',
47+
file_dir='vanilla_images/'):
48+
self.dev = dev
49+
self.rows = rows
50+
self.cols = cols
51+
self.channels = channels
52+
self.feature_size = self.rows * self.cols * self.channels
53+
self.noise_size = noise_size
54+
self.hidden_size = hidden_size
55+
self.batch = batch
56+
self.batch_size = self.batch // 2
57+
self.interval = interval
58+
self.learning_rate = learning_rate
59+
self.iterations = iterations
60+
self.dataset_filepath = dataset_filepath
61+
self.file_dir = file_dir
62+
self.model = gan_mlp.create_model(noise_size=self.noise_size,
63+
feature_size=self.feature_size,
64+
hidden_size=self.hidden_size)
65+
66+
def train(self):
67+
train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
68+
dev = device.create_cuda_gpu_on(0)
69+
dev.SetRandSeed(0)
70+
np.random.seed(0)
71+
72+
# sgd = opt.SGD(lr=self.learning_rate, momentum=0.9, weight_decay=1e-5)
73+
sgd = opt.Adam(lr=self.learning_rate)
74+
75+
noise = tensor.Tensor((self.batch_size, self.noise_size), dev,
76+
tensor.float32)
77+
real_images = tensor.Tensor((self.batch_size, self.feature_size), dev,
78+
tensor.float32)
79+
real_labels = tensor.Tensor((self.batch_size, 1), dev, tensor.float32)
80+
fake_labels = tensor.Tensor((self.batch_size, 1), dev, tensor.float32)
81+
82+
# attached model to graph
83+
self.model.set_optimizer(sgd)
84+
self.model.compile([noise],
85+
is_train=True,
86+
use_graph=False,
87+
sequential=True)
88+
89+
real_labels.set_value(1.0)
90+
fake_labels.set_value(0.0)
91+
92+
for iteration in range(self.iterations):
93+
idx = np.random.randint(0, train_data.shape[0], self.batch_size)
94+
real_images.copy_from_numpy(train_data[idx])
95+
96+
self.model.train()
97+
98+
# Training the Discriminative Net
99+
_, d_loss_real = self.model.train_one_batch_dis(
100+
real_images, real_labels)
101+
102+
noise.uniform(-1, 1)
103+
fake_images = self.model.forward_gen(noise)
104+
_, d_loss_fake = self.model.train_one_batch_dis(
105+
fake_images, fake_labels)
106+
107+
d_loss = tensor.to_numpy(d_loss_real)[0] + tensor.to_numpy(
108+
d_loss_fake)[0]
109+
110+
# Training the Generative Net
111+
noise.uniform(-1, 1)
112+
_, g_loss_tensor = self.model.train_one_batch(
113+
noise, real_labels)
114+
115+
g_loss = tensor.to_numpy(g_loss_tensor)[0]
116+
117+
if iteration % self.interval == 0:
118+
self.model.eval()
119+
self.save_image(iteration)
120+
print_log(' The {} iteration, G_LOSS: {}, D_LOSS: {}'.format(
121+
iteration, g_loss, d_loss))
122+
123+
def save_image(self, iteration):
124+
demo_row = 5
125+
demo_col = 5
126+
if not hasattr(self, "demo_noise"):
127+
self.demo_noise = tensor.Tensor(
128+
(demo_col * demo_row, self.noise_size), dev, tensor.float32)
129+
self.demo_noise.uniform(-1, 1)
130+
gen_imgs = self.model.forward_gen(self.demo_noise)
131+
gen_imgs = tensor.to_numpy(gen_imgs)
132+
show_imgs = np.reshape(
133+
gen_imgs, (gen_imgs.shape[0], self.rows, self.cols, self.channels))
134+
fig, axs = plt.subplots(demo_row, demo_col)
135+
cnt = 0
136+
for r in range(demo_row):
137+
for c in range(demo_col):
138+
axs[r, c].imshow(show_imgs[cnt, :, :, 0], cmap='gray')
139+
axs[r, c].axis('off')
140+
cnt += 1
141+
fig.savefig("{}{}.png".format(self.file_dir, iteration))
142+
plt.close()
143+
144+
145+
if __name__ == '__main__':
146+
parser = argparse.ArgumentParser(description='Train GAN over MNIST')
147+
parser.add_argument('filepath', type=str, help='the dataset path')
148+
parser.add_argument('--use_gpu', action='store_true')
149+
args = parser.parse_args()
150+
151+
if args.use_gpu:
152+
print('Using GPU')
153+
dev = device.create_cuda_gpu()
154+
else:
155+
print('Using CPU')
156+
dev = device.get_default_device()
157+
158+
if not os.path.exists('vanilla_images/'):
159+
os.makedirs('vanilla_images/')
160+
161+
rows = 28
162+
cols = 28
163+
channels = 1
164+
noise_size = 100
165+
hidden_size = 128
166+
batch = 128
167+
interval = 1000
168+
learning_rate = 0.0005
169+
iterations = 1000000
170+
dataset_filepath = 'mnist.pkl.gz'
171+
file_dir = 'vanilla_images/'
172+
vanilla = VANILLA(dev, rows, cols, channels, noise_size, hidden_size, batch,
173+
interval, learning_rate, iterations, dataset_filepath,
174+
file_dir)
175+
vanilla.train()

0 commit comments

Comments
 (0)