

[TensorFlow Keras] Image generation using a diffusion model


 

Goal

 

We will build a simple diffusion model that uses the TensorFlow Keras library to generate images like those in the CIFAR-10 dataset.

๋‹ค์Œ์€ diffusion ๋ชจ๋ธ์ด ๋™์ž‘ํ•˜๋Š” ๋ฐฉ์‹์ด๊ณ , fully noised ์ด๋ฏธ์ง€์—์„œ๋ถ€ํ„ฐ ์ ์ง„์ ์œผ๋กœ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•ด๋‚˜๊ฐ„๋‹ค.


์œ„์—์„œ ๋ดค๋“ฏ์ด, ์ด ๋ชจ๋ธ์€ ์ ์ง„์ ์œผ๋กœ ๋…ธ์ด์ฆˆ๋ฅผ ์ค„์—ฌ๊ฐ€๋ฉด์„œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•œ๋‹ค. ๋”ฐ๋ผ์„œ ๋ณธ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์—๊ฒŒ noise ๋ฅผ ์ค„์ด๋Š” ๋ฐฉ๋ฒ•์„ ํ•™์Šต์‹œํ‚ฌ ๊ฒƒ์ด๋ฉฐ ์ด๋ฅผ ์œ„ํ•ด ๋‘๊ฐ€์ง€ ์ž…๋ ฅ์ด ํ•„์š”ํ•˜๋‹ค.

 

- Input: the initial noisy image to be processed

- Timestep: how much noise the image currently contains

 

Importing necessary libraries

 

import numpy as np

from tqdm.auto import trange, tqdm
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import layers

 

Prepare the dataset

Load the training data from the CIFAR-10 dataset.

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train[y_train.squeeze() == 1]  # keep only images of class 1 (automobile)
X_train = (X_train / 127.5) - 1.0          # scale pixel values from [0, 255] to [-1, 1]
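As a quick standalone sanity check, the (x / 127.5) - 1.0 transform maps pixel values from [0, 255] to [-1, 1]:

```python
import numpy as np

pixels = np.array([0.0, 127.5, 255.0])  # darkest, middle, brightest pixel values
scaled = (pixels / 127.5) - 1.0         # same transform as applied to X_train
print(scaled)                           # [-1.  0.  1.]
```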

 

Define variables

IMG_SIZE = 32     # input image size; CIFAR-10 images are 32x32
BATCH_SIZE = 128
timesteps = 16    # number of steps to go from pure noise to a clear image
time_bar = 1 - np.linspace(0, 1.0, timesteps + 1) # noise schedule: 1.0 at t=0 down to 0.0 at t=timesteps

 

time_bar holds the time-scaling values used by the diffusion model.

A core idea in this process is adjusting the noise scale over time:

time_bar is a vector of values between 0 and 1 that represents the noise scale at each step. Note that it decreases from 1 to 0: time_bar[0] is the noise strength at the initial step and time_bar[timesteps] is the strength at the final step. Because the strength is designed to decrease monotonically, heavy noise is present at the early steps and only a low level of noise remains toward the end.
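A small standalone check of the definition above confirms that time_bar runs from 1 (pure noise) down to 0 (clean):

```python
import numpy as np

timesteps = 16
time_bar = 1 - np.linspace(0, 1.0, timesteps + 1)  # same definition as above

print(time_bar[0])          # 1.0 -> full noise at the first step
print(time_bar[timesteps])  # 0.0 -> no noise at the last step
```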

Some utility functions for previewing data

These functions normalize and visualize image data.

def cvtImg(img):
    img = img - img.min() # Shift the values so the minimum becomes zero.
    img = (img / img.max()) # Normalize the image to be between 0 and 1.
    return img.astype(np.float32) # Return the normalized image as float32 type.

def show_examples(x):
    plt.figure(figsize=(10, 10)) # Create a new plot window of size 10x10 inches.
    for i in range(25): # Generate 25 subplots and plot each image on each subplot.
        plt.subplot(5, 5, i+1)
        img = cvtImg(x[i]) # Normalize the image using the cvtImg function.
        plt.imshow(img) # Plot the normalized image.
        plt.axis('off') # Remove axes to display the image cleanly.

show_examples(X_train)

 

1. cvtImg(img): normalizes an image into the [0, 1] range.

2. show_examples(X_train): normalizes the first 25 images in X_train and displays them.

Here, a plot generally refers to a single figure, while a subplot
arranges several smaller figures in a grid inside one plot area.

Forward process

def forward_noise(x, t):
    a = time_bar[t]      # base on t
    b = time_bar[t + 1]  # image for t + 1

    # Make a Gaussian noise to be applied on x
    noise = np.random.randn(*x.shape)

    a = a.reshape((-1, 1, 1, 1))
    b = b.reshape((-1, 1, 1, 1))
    img_a = x * (1 - a) + noise * a # computes the image at time t
    img_b = x * (1 - b) + noise * b # computes the image at time t + 1
    return img_a, img_b
    
    
def generate_ts(num):
    return np.random.randint(0, timesteps, size=num)

t = generate_ts(25)             # random for training data
a, b = forward_noise(X_train[:25], t)
show_examples(a)

 

 

- a, b: determine the amount of noise applied at times t and t + 1, respectively.

- np.random.randn(): draws random numbers from a standard normal (Gaussian) distribution with mean 0 and standard deviation 1. The * is used to unpack the x.shape tuple into separate arguments for np.random.randn. For example, if x.shape is (32, 32, 3), then *x.shape expands to 32, 32, 3, so np.random.randn(*x.shape) is equivalent to calling np.random.randn(32, 32, 3).
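The unpacking can be illustrated in isolation (a minimal sketch, using a dummy array in place of the training batch):

```python
import numpy as np

x = np.zeros((25, 32, 32, 3))      # dummy batch: 25 images of 32x32x3
noise = np.random.randn(*x.shape)  # equivalent to np.random.randn(25, 32, 32, 3)
print(noise.shape)                 # (25, 32, 32, 3)
```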

The images on the left are at time Xt, and those on the right at time Xt+1.

 

 

Q. Why reshape a and b with reshape((-1, 1, 1, 1))?

A. Because t is an array with one timestep per image in the batch, a and b are 1-D arrays of length batch-size, not scalars. Reshaping them to (-1, 1, 1, 1) places one value per sample along the batch axis, so NumPy broadcasting can multiply them element-wise against the 4-D input batch x of shape (batch, height, width, channels).
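The broadcasting can be checked on its own (a sketch with a dummy batch, mirroring the shapes used in forward_noise):

```python
import numpy as np

x = np.zeros((25, 32, 32, 3))  # dummy image batch
a = np.random.rand(25)         # one noise level per image, shape (25,)
a = a.reshape((-1, 1, 1, 1))   # shape (25, 1, 1, 1): one value per sample
img_a = x * (1 - a)            # broadcasts across height, width, channels
print(img_a.shape)             # (25, 32, 32, 3)
```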

Building a block

๊ฐ๊ฐ์˜ ๋ธ”๋Ÿญ์€ Time Parameter ์™€ ํ•จ๊ป˜ ๋‘๊ฐœ์˜ convolutional networks๋ฅผ ํฌํ•จํ•˜๋ฉฐ, ์ด๊ฒƒ์€ information์— ๋”ฐ๋ผ ํ˜„์žฌ time step ๊ณผ ๊ฒฐ๊ณผ๋ฅผ ๊ฒฐ์ •ํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•œ๋‹ค. ์•„๋ž˜์™€ ๊ฐ™์€ ํ”Œ๋กœ์šฐ ์ฐจํŠธ ์ด๋ฏธ์ง€๋ฅผ ํ™•์ธํ•  ์ˆ˜ ์žˆ๊ณ , x_img ๋Š” noizy ํ•œ input image ์ด๊ณ  x_ts ๋Š” time step ๋ณ„ ์ž…๋ ฅ์ด๋‹ค. ์ฆ‰, ๋ณธ block ์€ ์ž…๋ ฅ ์ด๋ฏธ์ง€ (x_img)์—์„œ CNN์„ ์‚ฌ์šฉํ•˜์—ฌ ํŠน์ง•์„ ์ถ”์ถœํ•˜๊ณ , ์ด๋ฅผ ์‹œ๊ฐ„์ ์ธ ํŒŒ๋ผ๋ฏธํ„ฐ (x_ts)์™€ ๊ฒฐํ•ฉํ•˜์—ฌ ์ตœ์ข…์ ์ธ ์ถœ๋ ฅ (x_out)์„ ์ƒ์„ฑํ•˜๋Š” ์—ญํ• ์„ ํ•œ๋‹ค. ์ด๋Ÿฌํ•œ ๊ณผ์ •์„ ํ†ตํ•ด diffusion ๋ชจ๋ธ์€ ์ด๋ฏธ์ง€์˜ ํŠน์ง•์„ ์‹œ๊ฐ„์ ์ธ ์š”์†Œ์™€ ํ•จ๊ป˜ ๊ณ ๋ คํ•˜์—ฌ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋œ๋‹ค.

 

 

  • Conv2D: 128 channels, 3x3 conv, padding same, ReLU activation (input: x_img, output: x_parameter)
  • Time Parameter: dense layer with 128 hidden units (input: x_ts, output: time_parameter)
def block(x_img, x_ts):
    
    x_parameter = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x_img)
    x_out = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x_img)
    
    # Time parameter
    time_parameter = layers.Dense(128)(x_ts)
    
    # New x_parameter
    x_parameter = layers.Multiply()([x_parameter, time_parameter])
   
    # New x_out
    x_out = layers.Add()([x_out, x_parameter])
    
    
    x_out = layers.LayerNormalization()(x_out)
    x_out = layers.Activation('relu')(x_out)


    return x_out

 

 

Building a U-Net

While generating images, a diffusion model must preserve the structure of the original image while being able to modify or add parts of objects.

U-Net supports this: it can segment objects precisely and clearly delineate their boundaries. During image generation this helps ensure accurate recognition and placement of objects, ultimately contributing to high-quality generated images.

 

 

What is a U-Net?

 

U-Net is a deep learning architecture used in image analysis, specialized for image segmentation (identifying and separating specific objects in an image at the pixel level). It is based on an encoder-decoder structure and can be seen as a variant of the Fully Convolutional Network (FCN). The encoder progressively shrinks the spatial resolution of the image while extracting features, and the decoder upsamples them back to the original image size while performing segmentation.

 

1) Contracting Path (encoder): progressively shrinks the input image while extracting features, mainly using convolutional and pooling layers to reduce the spatial resolution.

 

2) Expanding Path (decoder): upsamples the features obtained in the contracting path back to the input image size, producing the segmentation map.

 

3) Skip Connections: connect the output of each encoder stage to the corresponding decoder stage so that fine-grained feature information is reflected in the segmentation result. This lets the network localize its segmentation accurately.

 

4) Final Layer: the last decoder layer produces an output matching the number of segmentation classes, typically using a softmax activation that gives the probability of each pixel belonging to each class.


def make_model():
    x = x_input = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name='x_input')

    x_ts = x_ts_input = layers.Input(shape=(1,), name='x_ts_input')
    x_ts = layers.Dense(192)(x_ts)
    x_ts = layers.LayerNormalization()(x_ts)
    x_ts = layers.Activation('relu')(x_ts)

    # ----- left ( down ) -----
    x = x32 = block(x, x_ts)
    x = layers.MaxPool2D(2)(x)

    x = x16 = block(x, x_ts)
    x = layers.MaxPool2D(2)(x)

    x = x8 = block(x, x_ts)
    x = layers.MaxPool2D(2)(x)

    x = x4 = block(x, x_ts)

    # ----- MLP -----
    x = layers.Flatten()(x)
    x = layers.Concatenate()([x, x_ts])
    x = layers.Dense(128)(x)
    x = layers.LayerNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Dense(4 * 4 * 32)(x)
    x = layers.LayerNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Reshape((4, 4, 32))(x)

    # ----- right ( up ) -----
    x = layers.Concatenate()([x, x4])
    x = block(x, x_ts)
    x = layers.UpSampling2D(2)(x)

    x = layers.Concatenate()([x, x8])
    x = block(x, x_ts)
    x = layers.UpSampling2D(2)(x)

    x = layers.Concatenate()([x, x16])
    x = block(x, x_ts)
    x = layers.UpSampling2D(2)(x)

    x = layers.Concatenate()([x, x32])
    x = block(x, x_ts)

    # ----- output -----
    x = layers.Conv2D(3, kernel_size=1, padding='same')(x)
    model = tf.keras.models.Model([x_input, x_ts_input], x)
    return model

model = make_model()

 

- tf.keras.models.Model([x_input, x_ts_input], x): defines a model that takes the two inputs x_input and x_ts_input and returns x as its output.

def mse_loss(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0008)
loss_func = mse_loss
model.compile(loss=loss_func, optimizer=optimizer)

 

A loss function commonly suited to image generation is mean squared error (MSE),

computed by averaging the squared per-pixel differences between the generated image and the target image. During training, Keras passes y_true and y_pred to the loss function automatically.

 

 

Predict the result

๋…ธ์ด์ฆˆ ์ด๋ฏธ์ง€๋กœ๋ถ€ํ„ฐ ์‹ค์ œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๊ณผ์ •์„ ๋‚˜ํƒ€๋‚ด๊ณ  ์žˆ๋‹ค. 

def predict(x_idx=None):
    x = np.random.normal(size=(32, IMG_SIZE, IMG_SIZE, 3))  # start from pure Gaussian noise
    for i in trange(timesteps):
        t = i
        # Denoise one step, applying the same timestep t to all 32 images
        x = model.predict([x, np.full((32), t)], verbose=0)
    show_examples(x)

predict()

 

- ๋ชจ๋ธ์— ์ž…๋ ฅ๋  ์ดˆ๊ธฐ ์ž…๋ ฅ ์ด๋ฏธ์ง€ x๋ฅผ ์ •๊ทœ ๋ถ„ํฌ๋ฅผ ๋”ฐ๋ฅด๋Š” ๋‚œ์ˆ˜๋กœ ์ƒ์„ฑํ•œ๋‹ค.

- timesteps ์€ ์˜ˆ์ธกํ•˜๋Š” ๋™์•ˆ ์ง„ํ–‰ํ•  ์‹œ๊ฐ„ ๋‹จ๊ณ„์ด๋‹ค. 

- ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ์ž…๋ ฅ ์ด๋ฏธ์ง€ x์™€ ํ˜„์žฌ ์‹œ๊ฐ„ t๋ฅผ ์ด์šฉํ•ด ๋‹ค์Œ ์‹œ๊ฐ„ ๋‹จ๊ณ„์˜ ์ด๋ฏธ์ง€๋ฅผ ์˜ˆ์ธกํ•œ๋‹ค. 

- ์ด ๋•Œ, np.full((32), t) ๋Š” ํฌ๊ธฐ๊ฐ€ 32์ด๊ณ  ๋ชจ๋“  ์š”์†Œ๊ฐ€  t๋กœ ์ฑ„์›Œ์ง„ ๋ฐฐ์—ด์ด๋‹ค. ๋ชจ๋“  ๋ฐฐ์น˜ ์ƒ˜ํ”Œ์— ๋™์ผํ•œ ์‹œ๊ฐ„ ๋‹จ๊ณ„ t๋ฅผ ์ ์šฉํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉ๋œ๋‹ค.

- ๊ฒฐ๋ก ์ ์œผ๋กœ ์ดˆ๊ธฐ ๋žœ๋ค์ด๋ฏธ์ง€์—์„œ t๋ฒˆ ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•œ ๊ฒฐ๊ณผ๊ฐ€ ์ถœ๋ ฅ๋œ๋‹ค.
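The np.full call can be verified on its own (standalone snippet):

```python
import numpy as np

t = 3
ts = np.full((32,), t)  # shape (32,), every element equals t
print(ts.shape)         # (32,)
print(ts[0], ts[-1])    # 3 3
```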

 

 

Visualize intermediate generation steps

 

This function runs the prediction steps of image generation and visualizes the generated image at every second time step.

def predict_step():
    xs = []
    x = np.random.normal(size=(8, IMG_SIZE, IMG_SIZE, 3))

    for i in trange(timesteps):
        t = i
        x = model.predict([x, np.full((8),  t)], verbose=0)
        if i % 2 == 0:
            xs.append(x[0])

    plt.figure(figsize=(20, 2))
    for i in range(len(xs)):
        plt.subplot(1, len(xs), i+1)
        plt.imshow(cvtImg(xs[i]))
        plt.title(f'{i}')
        plt.axis('off')

predict_step()

The result of running prediction after training.

Training model

def train_one(x_img):
    t = generate_ts(BATCH_SIZE)  # Generate random timesteps for the batch
    x_a, x_b = forward_noise(x_img, t)  # Generate noisy images
    x_ts = np.array(t).reshape(-1, 1)  # Reshape timesteps for model input

    loss = model.train_on_batch([x_a, x_ts], x_b)  # single gradient update on a single batch of data
    return loss

 

- ๋ชจ๋ธ์„ ํ•œ ๋ฐฐ์น˜(batch)๋งŒํผ ํ•™์Šต์‹œํ‚ค๊ณ  ๊ทธ ์†์‹ค ๊ฐ’์„ ๋ฐ˜ํ™˜ํ•˜๋Š” ํ•จ์ˆ˜์ด๋‹ค.

- ์ด๋ฅผ ์œ„ํ•ด์„œ forward_noise ํ•จ์ˆ˜๋กœ๋ถ€ํ„ฐ ๋…ธ์ด์ฆˆ๊ฐ€ ์ถ”๊ฐ€๋œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๊ณ , ํ•ด๋‹น ์ด๋ฏธ์ง€์™€ ํƒ€์ž„์Šคํ…์„ ์ž…๋ ฅ์œผ๋กœ ํ•˜์—ฌ ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚จ๋‹ค.

- ์ฆ‰ ๋ชจ๋ธ์€ Xt ์‹œ์ ์˜ ์ด๋ฏธ์ง€ x_a ๋ฅผ ๊ฐ€์ง€๊ณ  ๋…ธ์ด์ฆˆ๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ Xt+1 ์‹œ์ ์˜ ์ด๋ฏธ์ง€๋ฅผ ์˜ˆ์ธกํ•˜๋ฉฐ, ์‹ค์ œ Xt+1 ์ด๋ฏธ์ง€์˜ ground truth ์ธ x_b ์™€์˜ ๋น„๊ต๋ฅผ ํ†ตํ•ด ์ตœ์ข… ์†์‹ค์„ ๊ตฌํ•ด๋‚ด๊ฒŒ ๋œ๋‹ค.

Model.train_on_batch(x, y=None, sample_weight=None, class_weight=None, return_dict=False)

 

 

def train(R=50):
    bar = trange(R)
    total = 100
    for i in bar:
        for j in range(total):
            # Sample BATCH_SIZE images from X_train at random indices
            x_img = X_train[np.random.randint(len(X_train), size=BATCH_SIZE)]
            loss = train_one(x_img)
            pg = (j / total) * 100
            if j % 5 == 0:
                bar.set_description(f'loss: {loss:.5f}, p: {pg:.2f}%')

 

- The train() function runs R epochs (50 by default).

- trange(R): creates a progress bar that displays training progress.

- Each epoch performs total iterations, which corresponds to processing the training data as a series of random batches.

- Each iteration calls train_one with the current batch of images x_img to train the model and get back the loss value.

 

for _ in range(10):
    train()
    model.optimizer.learning_rate = max(0.000001, float(model.optimizer.learning_rate) * 0.9)  # decay the learning rate with a floor of 1e-6

    predict()
    predict_step()
    plt.show()

Results

 

Ref

 

https://www.tensorflow.org/tutorials/generative/generate_images_with_stable_diffusion