Generative Adversarial Nets_GAN_CODE
: An implementation of DCGAN, a convolutional model based on GAN (Generative Adversarial Nets), published in 2014
: Uses Python 3.6.8 and TensorFlow 2.4.0
GAN (Generative Adversarial Nets) Overview
1. Generator
: The generator (G) stacks four Conv layers
: Uses Conv2DTranspose with stride=2 to upsample with convolutions
Conv2DTranspose (TensorFlow documentation): www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2DTranspose
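As a quick shape check (a minimal sketch, not from the original post), a stride-2 Conv2DTranspose with "same" padding doubles the spatial dimensions:

import tensorflow as tf

x = tf.random.normal([1, 8, 8, 128])
y = tf.keras.layers.Conv2DTranspose(filters=64, kernel_size=5, strides=2, padding="same")(x)
print(y.shape)  # (1, 16, 16, 64): height and width are doubled by the stride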
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Dense, Reshape, Flatten, Activation,
    Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU,
)
from tensorflow.keras.models import Model

# norm_init is not defined in the original post; the DCGAN paper draws weights
# from N(0, 0.02), so a RandomNormal initializer is assumed here.
def norm_init(mean, std):
    return tf.keras.initializers.RandomNormal(mean=mean, stddev=std)

# conv_transpose: stride-2 transposed convolution -> (optional) BN + LeakyReLU
def conv_transpose_block(inputs, num_filters, kernel_size, init, strides, alpha=0.2, treatment=True):
    x = Conv2DTranspose(
        filters=num_filters,
        kernel_size=kernel_size,
        kernel_initializer=init,
        padding="same",
        strides=strides,
        use_bias=False,
    )(inputs)
    if treatment:
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=alpha)(x)
    return x
# generator: noise -> Dense -> Reshape -> 4 stride-2 transposed convs -> tanh
def build_generator(g_h, g_w, g_c, k_size, strides, latent_dim, mean, std, alpha):
    c_list = [512, 256, 128, 3]
    times = len(c_list)
    noise = Input(shape=(latent_dim,), name="generator_noise_input")
    x = Dense(g_h * g_w * g_c, use_bias=False)(noise)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=alpha)(x)
    x = Reshape((g_h, g_w, g_c))(x)
    for i in range(times):
        x = conv_transpose_block(
            inputs=x,
            num_filters=c_list[i],
            kernel_size=k_size,
            init=norm_init(mean=mean, std=std),
            strides=strides,
            alpha=alpha,
            # DCGAN applies no BatchNorm to the generator output layer,
            # so BN/LeakyReLU are skipped on the last block (tanh follows).
            treatment=(i < times - 1),
        )
    generator_output = Activation("tanh")(x)
    return Model(noise, generator_output, name="generator")
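A usage sketch (the hyperparameter values below are illustrative assumptions, not from the post): a 4x4x512 seed feature map upsampled x2 four times yields 64x64x3 images.

generator = build_generator(
    g_h=4, g_w=4, g_c=512,  # 4x4x512 seed feature map (assumed)
    k_size=5, strides=2,    # stride-2 upsampling as described above
    latent_dim=100,         # 100-dim noise vector, as in the DCGAN paper
    mean=0.0, std=0.02,     # N(0, 0.02) weight init from the DCGAN paper
    alpha=0.2,
)
generator.summary()  # final output shape: (None, 64, 64, 3)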
2. Discriminator
: The discriminator (D), like the generator (G), stacks four Conv layers
: Uses Conv2D with stride=2 to downsample with convolutions
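Mirroring the upsampling check above (again a minimal sketch, not from the original post), a stride-2 Conv2D with "same" padding halves the spatial dimensions:

import tensorflow as tf

x = tf.random.normal([1, 64, 64, 3])
y = tf.keras.layers.Conv2D(filters=64, kernel_size=5, strides=2, padding="same")(x)
print(y.shape)  # (1, 32, 32, 64): height and width are halved by the stride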
# conv_block: stride-2 convolution -> (optional) BN + LeakyReLU
def conv_block(inputs, num_filters, kernel_size, init, strides, alpha, treatment=True):
    x = Conv2D(
        filters=num_filters,
        kernel_size=kernel_size,
        kernel_initializer=init,
        padding="same",
        strides=strides,
    )(inputs)
    if treatment:
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=alpha)(x)
    return x
# discriminator: image -> 4 stride-2 convs -> Flatten -> Dense(1)
def build_discriminator(d_h, d_w, d_c, k_size, strides, mean, std, alpha):
    c_list = [64, 128, 256, 512]
    times = len(c_list)
    input_image = Input(shape=(d_h, d_w, d_c))
    x = input_image
    for i in range(times):
        x = conv_block(
            inputs=x,
            num_filters=c_list[i],
            kernel_size=k_size,
            init=norm_init(mean=mean, std=std),
            strides=strides,
            alpha=alpha,
            treatment=True,
        )
    x = Flatten()(x)
    # No sigmoid: the output is a logit, so the loss should use from_logits=True
    discriminator_output = Dense(1)(x)
    return Model(input_image, discriminator_output, name="discriminator")
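A matching usage sketch (values assumed, chosen to pair with the generator above):

discriminator = build_discriminator(
    d_h=64, d_w=64, d_c=3,  # matches the 64x64x3 generator output (assumed)
    k_size=5, strides=2,
    mean=0.0, std=0.02,
    alpha=0.2,
)
discriminator.summary()  # final output shape: (None, 1)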
3. Generative Adversarial Nets (GAN)
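For reference, the train_step below implements the minimax objective from the 2014 GAN paper, using the non-saturating generator loss (G is trained to maximize $\log D(G(z))$ rather than minimize $\log(1 - D(G(z)))$):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

Both expectations are estimated with binary cross-entropy: the discriminator sees real images labeled 1 and generated images labeled 0, while the generator is trained against labels of 1 on its own outputs.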
class gan(Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(gan, self).__init__()  # fixed: the original passed self twice
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(gan, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    # @tf.function
    def train_step(self, images):
        # tf.shape keeps this graph-compatible; tf.shape(...).numpy() only works eagerly
        batch_size = tf.shape(images)[0]
        noise = tf.random.normal([batch_size, self.latent_dim])
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = self.generator(noise, training=True)
            real_output = self.discriminator(images, training=True)
            fake_output = self.discriminator(generated_images, training=True)
            # non-saturating generator loss: push D(G(z)) toward the "real" label
            gen_loss = self.g_loss_fn(tf.ones_like(fake_output), fake_output)
            # discriminator loss: real images -> 1, generated images -> 0
            real_loss = self.d_loss_fn(tf.ones_like(real_output), real_output)
            fake_loss = self.d_loss_fn(tf.zeros_like(fake_output), fake_output)
            disc_loss = real_loss + fake_loss
        gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
        gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.d_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))
        self.g_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        # returning a metrics dict (instead of the original return print(...)) is
        # the Keras train_step convention; fit() logs these values per batch
        return {"real_loss": real_loss, "fake_loss": fake_loss,
                "disc_loss": disc_loss, "gen_loss": gen_loss}
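Putting it together, a minimal training sketch (the Adam settings follow the DCGAN paper; `dataset` is an assumed tf.data pipeline yielding image batches scaled to [-1, 1] to match the tanh output):

from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

dcgan = gan(discriminator=discriminator, generator=generator, latent_dim=100)
dcgan.compile(
    d_optimizer=Adam(learning_rate=2e-4, beta_1=0.5),
    g_optimizer=Adam(learning_rate=2e-4, beta_1=0.5),
    d_loss_fn=BinaryCrossentropy(from_logits=True),  # Dense(1) outputs logits
    g_loss_fn=BinaryCrossentropy(from_logits=True),
)
dcgan.fit(dataset, epochs=50)  # `dataset` and the epoch count are placeholders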