Generative Adversarial Nets_GAN_CODE

2021. 3. 2. 17:38 · Data Science/05_Research paper

반응형

 

Generative Adversarial Nets_GAN_(CODE)

 

: 2014년 공개된 GAN(Generative Adversarial Nets) 기반 Conv를 사용한 DCGAN 구현 

: Python Version: 3.6.8, tensorflow Version: 2.4.0 사용
    

GAN(Generative Adversarial Nets)_Overview

today-1.tistory.com/46

 

Generative Adversarial Nets_GAN_(Overview)

Generative Adversarial Nets_GAN_(Overview) : 2014년 공개된 GAN(Generative Adversarial Nets) 논문 리뷰를 진행 : 개념을 설명하는 Overview와 TensorFlow 2.0 구현 Code 2개 챕터로 소개 Generative Adversarial..

today-1.tistory.com


GAN_structure


 

 

1. Generator

 

: 생성모델(G)의 경우 noise를 Dense layer로 투영해 reshape한 뒤, 4개의 Conv2DTranspose layer를 쌓음

: Conv로 upsampling을 하기 위해서 Conv2DTranspose, strides=2를 사용

 

Conv2DTranspose (TensorFlow 공식 문서)

www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2DTranspose

 

tf.keras.layers.Conv2DTranspose  |  TensorFlow Core v2.4.1

Transposed convolution layer (sometimes called Deconvolution).

www.tensorflow.org

# conv_transpose

def conv_transpose_block(inputs, num_filters, kernel_size, init, strides, treatment=True, alpha=0.2):
    """Apply a bias-free, "same"-padded Conv2DTranspose, optionally followed by BN + LeakyReLU.

    Args:
        inputs: Keras tensor to upsample.
        num_filters: Number of output channels of the transposed convolution.
        kernel_size: Kernel size of the transposed convolution.
        init: Kernel initializer passed as ``kernel_initializer``.
        strides: Strides of the transposed convolution (stride 2 with
            ``padding="same"`` doubles the spatial size).
        treatment: If True, append BatchNormalization and LeakyReLU.
        alpha: Negative slope of the LeakyReLU. The default of 0.2 matches
            the previously hard-coded value, so existing callers are unaffected.

    Returns:
        The output Keras tensor.
    """
    x = Conv2DTranspose(
        filters=num_filters,
        kernel_size=kernel_size,
        kernel_initializer=init,
        padding="same",
        strides=strides,
        # No bias: when BatchNormalization follows, its shift makes a bias redundant.
        use_bias=False,
    )(inputs)
    if treatment:
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=alpha)(x)
    return x



# generator

def build_generator(g_h, g_w, g_c, k_size, strides, latent_dim, mean, std, alpha):
    """Build the DCGAN generator mapping a noise vector to a tanh-activated image.

    The noise is projected with a bias-free Dense layer and reshaped to
    ``(g_h, g_w, g_c)``. It is then passed through four transposed-convolution
    layers with 512 -> 256 -> 128 -> 3 output channels.

    Every layer except the last is followed by BatchNormalization and
    LeakyReLU(alpha). The last layer feeds ``tanh`` directly, so the output
    lies in [-1, 1].

    Args:
        g_h, g_w, g_c: Height, width and channels the projected noise is
            reshaped into.
        k_size: Kernel size of every transposed convolution.
        strides: Strides of every transposed convolution.
        latent_dim: Length of the input noise vector.
        mean, std: Forwarded to ``norm_init`` to build each kernel initializer.
        alpha: Negative slope of every LeakyReLU in the generator.

    Returns:
        A Keras ``Model`` named ``"generator"``.
    """
    c_list = [512, 256, 128, 3]
    last_index = len(c_list) - 1

    noise = Input(shape=(latent_dim,), name="generator_noise_input")
    x = Dense(g_h * g_w * g_c, use_bias=False)(noise)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=alpha)(x)
    x = Reshape((g_h, g_w, g_c))(x)

    for i, filters in enumerate(c_list):
        x = conv_transpose_block(
            inputs=x,
            num_filters=filters,
            kernel_size=k_size,
            init=norm_init(mean=mean, std=std),
            strides=strides,
            treatment=False,
        )
        # DCGAN: no BatchNorm on the generator output layer. A LeakyReLU in front of
        # tanh would also squash the negative half of the output range, so the
        # last layer goes straight into tanh.
        if i < last_index:
            x = BatchNormalization()(x)
            x = LeakyReLU(alpha=alpha)(x)

    generator_output = Activation("tanh")(x)

    return Model(noise, generator_output, name="generator")

 

 

2. Discriminator

 

: 판별모델(D)의 경우 생성모델(G)과 같이 4개의 Conv2D layer를 쌓음

: Conv로 Downsampling을 하기 위해서 Conv2D, strides=2를 사용

# conv_block

def conv_block(inputs, num_filters, kernel_size, init, strides, alpha, treatment=True):
    """Apply a "same"-padded Conv2D (with bias), optionally followed by BN + LeakyReLU.

    Args:
        inputs: Keras tensor to convolve.
        num_filters: Number of output channels.
        kernel_size: Convolution kernel size.
        init: Kernel initializer passed as ``kernel_initializer``.
        strides: Convolution strides (stride 2 is used for downsampling).
        alpha: Negative slope of the LeakyReLU applied when ``treatment`` is True.
        treatment: If True, append BatchNormalization and then LeakyReLU(alpha).

    Returns:
        The output Keras tensor.
    """
    conv_layer = Conv2D(
        filters=num_filters,
        kernel_size=kernel_size,
        kernel_initializer=init,
        padding="same",
        strides=strides,
    )
    out = conv_layer(inputs)

    if not treatment:
        return out

    out = BatchNormalization()(out)
    return LeakyReLU(alpha=alpha)(out)


# discriminator

def build_discriminator(d_h, d_w, d_c, k_size, strides, mean, std, alpha):
    """Build the DCGAN discriminator mapping an image to a single unbounded score.

    Four convolution layers (64 -> 128 -> 256 -> 512 channels), each followed
    by LeakyReLU(alpha), are flattened into one Dense unit. BatchNormalization
    is applied after every convolution except the first.

    Args:
        d_h, d_w, d_c: Height, width and channels of the input image.
        k_size: Kernel size of every convolution.
        strides: Strides of every convolution.
        mean, std: Forwarded to ``norm_init`` to build each kernel initializer.
        alpha: Negative slope of every LeakyReLU in the discriminator.

    Returns:
        A Keras ``Model`` named ``"discriminator"``.
    """
    c_list = [64, 128, 256, 512]
    input_image = Input(shape=(d_h, d_w, d_c))
    x = input_image

    for i, filters in enumerate(c_list):
        # DCGAN guideline: no BatchNorm on the discriminator's input layer.
        is_input_layer = i == 0
        x = conv_block(
            inputs=x,
            num_filters=filters,
            kernel_size=k_size,
            init=norm_init(mean=mean, std=std),
            strides=strides,
            alpha=alpha,
            treatment=not is_input_layer,
        )
        if is_input_layer:
            # treatment=False also skips the activation, so add it back here.
            x = LeakyReLU(alpha=alpha)(x)

    x = Flatten()(x)
    # No sigmoid: this is a raw logit. The loss should treat it as such,
    # e.g. BinaryCrossentropy(from_logits=True). TODO confirm for d_loss_fn/g_loss_fn.
    discriminator_output = Dense(1)(x)

    return Model(input_image, discriminator_output, name="discriminator")

 

 

3. Generative Adversarial Nets(GAN)

 

class gan(Model):
    """Keras ``Model`` bundling a GAN's discriminator, generator and training objects.

    It stores the two sub-models and the latent size used to sample noise.
    The separate optimizers and loss functions for the discriminator and the
    generator are supplied through :meth:`compile`.
    """

    def __init__(self, discriminator, generator, latent_dim):
        # Call Model.__init__ without positional arguments. Positional args are
        # reserved for the functional API (inputs/outputs), and newer Keras
        # versions reject them.
        super().__init__()
        self.discriminator = discriminator  # Keras model scoring real vs. generated images
        self.generator = generator  # Keras model mapping noise to images
        self.latent_dim = latent_dim  # length of the noise vectors fed to the generator

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Store separate optimizers and loss functions for D and G.

        ``Model.compile`` itself is called without arguments. The stored
        objects are read by ``train_step``.

        Args:
            d_optimizer: Optimizer applied to the discriminator's variables.
            g_optimizer: Optimizer applied to the generator's variables.
            d_loss_fn: Callable ``(targets, predictions) -> loss`` for the discriminator.
            g_loss_fn: Callable ``(targets, predictions) -> loss`` for the generator.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

# Not wrapped in @tf.function: the loss reporting below calls ``.numpy()``,
# which only works under eager execution.
def train_step(self, images):
    """Run one simultaneous discriminator/generator update on a batch of real images.

    ``self`` must provide ``generator``, ``discriminator``, ``latent_dim``,
    ``d_optimizer``, ``g_optimizer``, ``d_loss_fn`` and ``g_loss_fn`` (the
    attributes a ``gan`` instance sets).
    NOTE(review): this function sits at module level here; confirm whether it
    was meant to be a method of ``gan``.

    Args:
        self: Object holding the GAN sub-models, optimizers and loss functions.
        images: Batch of real images; the batch size is read from its first dimension.

    Returns:
        ``(None, [disc_loss, gen_loss])`` with the losses as NumPy values.
        The leading ``None`` is kept for callers that unpack two values; it was
        the return value of the ``print`` call.
    """
    # tf.shape(...)[0] works both eagerly and inside a tf.function graph,
    # unlike tf.shape(...).numpy().
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, self.latent_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = self.generator(noise, training=True)

        real_output = self.discriminator(images, training=True)
        fake_output = self.discriminator(generated_images, training=True)

        # G is trained so that D labels its samples as real (target ones).
        gen_loss = self.g_loss_fn(tf.ones_like(fake_output), fake_output)
        # D is trained to label real images as ones and generated images as zeros.
        real_loss = self.d_loss_fn(tf.ones_like(real_output), real_output)
        fake_loss = self.d_loss_fn(tf.zeros_like(fake_output), fake_output)
        disc_loss = real_loss + fake_loss

    # Gradients are taken after leaving the tape context, so neither tape
    # records the other's gradient computation.
    gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
    gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)

    self.d_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))
    self.g_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))

    disc_value = disc_loss.numpy()
    gen_value = gen_loss.numpy()
    print(
        "real_loss :", real_loss.numpy(),
        "fake_loss :", fake_loss.numpy(),
        "disc_loss :", disc_value,
        "gen_loss :", gen_value,
    )
    return None, [disc_value, gen_value]

 

 

반응형