๋ชฉํ
TensorFlow Keras library ๋ฅผ ์ฌ์ฉํด์ CIFAR-10 dataset ๊ฐ์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๊ฐ๋จํ diffusion ๋ชจ๋ธ์ ๋ง๋ค์ด๋ณผ ๊ฒ์ด๋ค.
๋ค์์ diffusion ๋ชจ๋ธ์ด ๋์ํ๋ ๋ฐฉ์์ด๊ณ , fully noised ์ด๋ฏธ์ง์์๋ถํฐ ์ ์ง์ ์ผ๋ก ์ด๋ฏธ์ง๋ฅผ ์์ฑํด๋๊ฐ๋ค.
์์์ ๋ดค๋ฏ์ด, ์ด ๋ชจ๋ธ์ ์ ์ง์ ์ผ๋ก ๋ ธ์ด์ฆ๋ฅผ ์ค์ฌ๊ฐ๋ฉด์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ค. ๋ฐ๋ผ์ ๋ณธ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์๊ฒ noise ๋ฅผ ์ค์ด๋ ๋ฐฉ๋ฒ์ ํ์ต์ํฌ ๊ฒ์ด๋ฉฐ ์ด๋ฅผ ์ํด ๋๊ฐ์ง ์ ๋ ฅ์ด ํ์ํ๋ค.
- ์ ๋ ฅ : ์ฒ๋ฆฌ๋์ด์ผํ ์ด๊ธฐ Noise
- timestamp : noise status
Importing necessary libraries
import numpy as np
from tqdm.auto import trange, tqdm
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
Prepare the dataset
CIFAR-10 ๋ฐ์ดํฐ ์ ์ผ๋ก๋ถํฐ ํ์ต ๋ฐ์ดํฐ๋ฅผ ๋ถ๋ฌ์จ๋ค.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train[y_train.squeeze() == 1]
X_train = (X_train / 127.5) - 1.0
Define variables
IMG_SIZE = 32 # input image size, CIFAR-10 is 32x32
BATCH_SIZE = 128
timesteps = 16 # how many steps for a noisy image into clear
time_bar = 1 - np.linspace(0, 1.0, timesteps + 1) # linspace for timesteps
time_bar๋ ํ์ฐ ๋ชจ๋ธ(Diffusion Model)์์ ์ฌ์ฉ๋๋ ์๊ฐ ์ค์ผ์ผ๋ง(time scaling)์ ๊ด๋ จ๋ ๊ฐ์ด๋ค.
์ด ๊ณผ์ ์์ ์ฌ์ฉ๋๋ ํต์ฌ ๊ฐ๋ ์ค ํ๋๊ฐ ๋ ธ์ด์ฆ์ ์ค์ผ์ผ์ ์๊ฐ์ ๋ฐ๋ผ ์กฐ์ ํ๋ ๊ฒ์ธ๋ฐ,
time_bar๋ 0์์ 1๊น์ง์ ๊ฐ์ ๊ฐ์ง๋ ๋ฒกํฐ๋ก, ํด๋น ๊ฐ์ ๋ ธ์ด์ฆ ์ค์ผ์ผ์ ๋ํ๋ธ๋ค. ๋ณดํต 0์์ 1๋ก ๋ณํ๋ time_bar๋ ๊ฐ ๋จ๊ณ์์์ ๋ ธ์ด์ฆ์ ๊ฐ๋๋ฅผ ์กฐ์ ํ๋ ๋ฐ ์ฌ์ฉ๋๋ค. time_bar[0]์ ์ด๊ธฐ ๋จ๊ณ์์์ ๋ ธ์ด์ฆ ๊ฐ๋๋ฅผ ๋ํ๋ด๊ณ , time_bar[timesteps]๋ ๋ง์ง๋ง ๋จ๊ณ์์์ ๋ ธ์ด์ฆ ๊ฐ๋๋ฅผ ๋ํ๋ธ๋ค. ๋ ธ์ด์ฆ์ ๊ฐ๋๊ฐ ์ ์ฐจ์ ์ผ๋ก ๊ฐ์ํ๋๋ก ์ค๊ณ๋์ด ์์ด, ์ด๊ธฐ ๋จ๊ณ์์๋ ๋์ ๋ ธ์ด์ฆ๊ฐ ์ถ๊ฐ๋๊ณ ๋์ค์๋ ์ ๋ ๋ฒจ์ ๋ ธ์ด์ฆ๋ง ๋จ๊ฒ ๋๋ค.
Some utility functions for preview data
์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ฅผ ์ ๊ทํํ๊ณ ์๊ฐํํ๋ ํจ์๋ค์ ํฌํจํ๋ค.
def cvtImg(img):
img = img - img.min() # Subtract the minimum value of the image array to make it zero-centered.
img = (img / img.max()) # Normalize the image to be between 0 and 1.
return img.astype(np.float32) # Return the normalized image as float32 type.
def show_examples(x):
plt.figure(figsize=(10, 10)) # Create a new plot window of size 10x10 inches.
for i in range(25): # Generate 25 subplots and plot each image on each subplot.
plt.subplot(5, 5, i+1)
img = cvtImg(x[i]) # Normalize the image using the cvtImg function.
plt.imshow(img) # Plot the normalized image.
plt.axis('off') # Remove axes to display the image cleanly.
show_examples(X_train)
1. cvtImg(img) ํจ์ : ์ด๋ฏธ์ง๋ฅผ ์ ๊ทํํ๋ค.
2. show_examples(X_train) : X_train์ ์๋ ์ฒ์ 25๊ฐ์ ์ด๋ฏธ์ง๋ฅผ ์ ๊ทํํ๊ณ ์๊ฐ์ ์ผ๋ก ํ์
์ด ๋, ํ๋์ ํ๋กฏ์ ์ผ๋ฐ์ ์ผ๋ก ํ๋์ ๋ํ๋ฅผ ๋งํ๋ฉฐ ์๋ธ ํ๋กฏ์ด๋
ํ๋์ ํ๋กฏ ์์ญ ๋ด์์ ์ฌ๋ฌ ๊ฐ์ ์์ ๋ํ๋ฅผ ๋ฐฐ์ด ํํ๋ก ๋ฐฐ์นํ๋ ๊ฒ์ ๋งํ๋ค.
Forward process
def forward_noise(x, t):
a = time_bar[t] # base on t
b = time_bar[t + 1] # image for t + 1
# Make a Gaussian noise to be applied on x
noise = np.random.randn(*x.shape)
a = a.reshape((-1, 1, 1, 1))
b = b.reshape((-1, 1, 1, 1))
img_a = x * (1 - a) + noise * a # computes the image at time t
img_b = x * (1 - b) + noise * b # computes the image at time t + 1
return img_a, img_b
def generate_ts(num):
return np.random.randint(0, timesteps, size=num)
t = generate_ts(25) # random for training data
a, b = forward_noise(X_train[:25], t)
show_examples(a)
- a,b : ๊ฐ๊ฐ ์๊ฐ t์ t + 1์์ ์ถ๊ฐํ ๋ ธ์ด์ฆ์ ์์ ๊ฒฐ์
- random.randn() : ํ๊ท ์ด 0์ด๊ณ ํ์คํธ์ฐจ๊ฐ 1์ธ ์ ๊ท ๋ถํฌ(๊ฐ์ฐ์์ ๋ถํฌ)๋ฅผ ๋ฐ๋ฅด๋ ๋์๋ฅผ ์์ฑ. *์ ์ฌ์ฉํ๋ ์ด์ ๋ x.shape์ ํํ์ unpackํ์ฌ np.random.randn ํจ์์ ์ ๋ฌํ๊ธฐ ์ํจ์ด๋ค. ์๋ฅผ ๋ค์ด x.shape๊ฐ (32, 32, 3) ์ธ ๊ฒฝ์ฐ, *x.shape๋ 32, 32, 3์ผ๋ก ํ๋ฆฐ๋ค. np.random.randn(*x.shape)๋ np.random.randn(32, 32, 3)๊ณผ ๊ฐ์ด ํธ์ถ๋๋ ์ญํ ์ ํ๋ค.
Q. ์ค์นผ๋ผ ๊ฐ์ธ a,b ๋ฅผ reshape((-1, 1, 1, 1))ํ๋ ์ด์ ?
A. a๊ฐ ์ค์นผ๋ผ์ด๊ธฐ ๋๋ฌธ์ ์ถ๊ฐ์ ์ธ ์ฐจ์์ ๋ถ์ฌํ๋๋ผ๋ ์ค์ ๋ก๋ ํํ์ ํฐ ๋ณํ๊ฐ ์์ง๋ง, ์ฝ๋์ ์ผ๊ด์ฑ์ ์ ์งํ๊ณ ๋ค์ฐจ์ ๋ฐฐ์ด์ ๋ค๋ฃจ๊ธฐ ์ฝ๊ฒ ํ๋ ๋ฐ ๋์์ด ๋๋ค. ์ธํ ๋ฐ์ดํฐ์ธ x ์์ ์ฐ์ฐ์ ํธ๋ฆฌํ๊ฒ ์ํํ ์ ์๋ค.
Building a block
๊ฐ๊ฐ์ ๋ธ๋ญ์ Time Parameter ์ ํจ๊ป ๋๊ฐ์ convolutional networks๋ฅผ ํฌํจํ๋ฉฐ, ์ด๊ฒ์ information์ ๋ฐ๋ผ ํ์ฌ time step ๊ณผ ๊ฒฐ๊ณผ๋ฅผ ๊ฒฐ์ ํ ์ ์๋๋ก ํ๋ค. ์๋์ ๊ฐ์ ํ๋ก์ฐ ์ฐจํธ ์ด๋ฏธ์ง๋ฅผ ํ์ธํ ์ ์๊ณ , x_img ๋ noizy ํ input image ์ด๊ณ x_ts ๋ time step ๋ณ ์ ๋ ฅ์ด๋ค. ์ฆ, ๋ณธ block ์ ์ ๋ ฅ ์ด๋ฏธ์ง (x_img)์์ CNN์ ์ฌ์ฉํ์ฌ ํน์ง์ ์ถ์ถํ๊ณ , ์ด๋ฅผ ์๊ฐ์ ์ธ ํ๋ผ๋ฏธํฐ (x_ts)์ ๊ฒฐํฉํ์ฌ ์ต์ข ์ ์ธ ์ถ๋ ฅ (x_out)์ ์์ฑํ๋ ์ญํ ์ ํ๋ค. ์ด๋ฌํ ๊ณผ์ ์ ํตํด diffusion ๋ชจ๋ธ์ ์ด๋ฏธ์ง์ ํน์ง์ ์๊ฐ์ ์ธ ์์์ ํจ๊ป ๊ณ ๋ คํ์ฌ ์ฒ๋ฆฌํ ์ ์๊ฒ ๋๋ค.
- Conv2D: 128 channels, 3x3 conv, padding same, ReLU activation (input: x_img, output: x_parameter)
- Time Parameter: dense layer with 128 hidden units (input: x_ts, output: time_parameter)
def block(x_img, x_ts):
x_parameter = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x_img)
x_out = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x_img)
# Time parameter
time_parameter = layers.Dense(128)(x_ts)
# New x_parameter
x_parameter = layers.Multiply()([x_parameter, time_parameter])
# New x_out
x_out = layers.Add()([x_out, x_parameter])
x_out = layers.LayerNormalization()(x_out)
x_out = layers.Activation('relu')(x_out)
return x_out
Building an U-Net
Diffusion ๋ชจ๋ธ์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๊ณผ์ ์์ ์๋ณธ ์ด๋ฏธ์ง์ ๊ตฌ์กฐ๋ฅผ ๋ณด์กดํ๊ณ , ๊ฐ์ฒด์ ์ผ๋ถ๋ถ์ ๋ณํํ๊ฑฐ๋ ์ถ๊ฐํ ์ ์์ด์ผ ํ๋ค.
์ด๋ฅผ ์ํด U-Net์ ๊ฐ์ฒด๋ฅผ ์ ํํ๊ฒ ๋ถํ ํ์ฌ ๊ฐ ๊ฐ์ฒด์ ๊ฒฝ๊ณ๋ฅผ ๋ช ํํ ๊ตฌ๋ถํ ์ ์๋ค. ์ด๋ฏธ์ง ์์ฑ ๊ณผ์ ์์ ๊ฐ์ฒด์ ์ ํํ ์ธ์๊ณผ ๋ฐฐ์น๋ฅผ ๋ณด์ฅํ๋ฉฐ, ์ต์ข ์ ์ผ๋ก ๋์ ํ์ง์ ์์ฑ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๋ฐ ๊ธฐ์ฌํ๋ค.
U-Net ์ด๋?
์ด๋ฏธ์ง ๋ถ์์์ ์ฌ์ฉ๋๋ ๋ฅ๋ฌ๋ ์ํคํ ์ฒ๋ก, ์ด๋ฏธ์ง ์ธ๊ทธ๋ฉํ ์ด์ (์ด๋ฏธ์ง ๋ด์ ํน์ ๊ฐ์ฒด๋ฅผ ํฝ์ ์์ค์์ ์๋ณ ๋ฐ ๋ถํ ํ๋ ์์ )์ ํนํ๋ ๋คํธ์ํฌ์ด๋ค. ์ธ์ฝ๋-๋์ฝ๋(Encoder-Decoder) ๊ตฌ์กฐ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ๋ฉฐ, ํนํ Fully Convolutional Network (FCN)์ ๋ณํ์ผ๋ก ๋ณผ ์ ์๋ค. ์ธ์ฝ๋ ๋ถ๋ถ์์๋ ์ด๋ฏธ์ง์ ๊ณต๊ฐ์ ์ธ ์ ๋ณด๋ฅผ ๋จ๊ณ์ ์ผ๋ก ์ถ์ํด๊ฐ๋ฉฐ ์ถ์ถํ๊ณ , ๋์ฝ๋ ๋ถ๋ถ์์๋ ์ด๋ฅผ ์ ์ํ๋งํ์ฌ ์๋ณธ ์ด๋ฏธ์ง์ ํฌ๊ธฐ๋ก ๋ณต์ํ๋ฉด์ ์ธ๊ทธ๋ฉํ ์ด์ ์ ์ํํ๋ค.
1) Contracting Path (์ธ์ฝ๋) : ์ ๋ ฅ ์ด๋ฏธ์ง๋ฅผ ์ ์ ์ค์ฌ๊ฐ๋ฉฐ ํน์ง์ ์ถ์ถํ๋ ๋จ๊ณ์ด๋ค. ์ฃผ๋ก Convolutional layer์ Pooling layer๋ฅผ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง์ ๊ณต๊ฐ์ ์ธ ํด์๋๋ฅผ ์ถ์ํ๋ค.
2) Expanding Path (๋์ฝ๋) : Contracting Path์์ ์ป์ ํน์ง์ ๋ฐํ์ผ๋ก ์ ๋ ฅ ์ด๋ฏธ์ง์ ํฌ๊ธฐ๋ก ์ ์ํ๋งํ๋ฉด์ ์ธ๊ทธ๋ฉํ ์ด์ ๋งต์ ์์ฑํ๋ค.
3) Skip Connections : ์ธ์ฝ๋์ ๊ฐ ๋จ๊ณ์์ ํด๋น ์คํ ์ ์ถ๋ ฅ์ ๋์ฝ๋์ ์ฐ๊ฒฐํ์ฌ ์ธ๋ฐํ ํน์ง ์ ๋ณด๊ฐ ์ธ๊ทธ๋ฉํ ์ด์ ๊ฒฐ๊ณผ์ ๋ฐ์๋๋๋ก ๋๋๋ค. ์ด๋ ๋คํธ์ํฌ๊ฐ ์ ํํ ์์น์ ๋ํ ์ธ๊ทธ๋ฉํ ์ด์ ์ ์ํํ ์ ์๊ฒ ํ๋ค.
4) Final Layer : ๋์ฝ๋์ ๋ง์ง๋ง ์ธต์์๋ ์ธ๊ทธ๋ฉํ ์ด์ ํด๋์ค ์์ ๋ง๋ ์ถ๋ ฅ์ ์์ฑํ๋ค. ์ผ๋ฐ์ ์ผ๋ก๋ ๊ฐ ํฝ์ ์ด ํด๋น ํด๋์ค์ ์ํ ํ๋ฅ ์ ๋ํ๋ด๋ softmax activation function์ ์ฌ์ฉํ๋ค.
def make_model():
x = x_input = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name='x_input')
x_ts = x_ts_input = layers.Input(shape=(1,), name='x_ts_input')
x_ts = layers.Dense(192)(x_ts)
x_ts = layers.LayerNormalization()(x_ts)
x_ts = layers.Activation('relu')(x_ts)
# ----- left ( down ) -----
x = x32 = block(x, x_ts)
x = layers.MaxPool2D(2)(x)
x = x16 = block(x, x_ts)
x = layers.MaxPool2D(2)(x)
x = x8 = block(x, x_ts)
x = layers.MaxPool2D(2)(x)
x = x4 = block(x, x_ts)
# ----- MLP -----
x = layers.Flatten()(x)
x = layers.Concatenate()([x, x_ts])
x = layers.Dense(128)(x)
x = layers.LayerNormalization()(x)
x = layers.Activation('relu')(x)
x = layers.Dense(4 * 4 * 32)(x)
x = layers.LayerNormalization()(x)
x = layers.Activation('relu')(x)
x = layers.Reshape((4, 4, 32))(x)
# ----- right ( up ) -----
x = layers.Concatenate()([x, x4])
x = block(x, x_ts)
x = layers.UpSampling2D(2)(x)
x = layers.Concatenate()([x, x8])
x = block(x, x_ts)
x = layers.UpSampling2D(2)(x)
x = layers.Concatenate()([x, x16])
x = block(x, x_ts)
x = layers.UpSampling2D(2)(x)
x = layers.Concatenate()([x, x32])
x = block(x, x_ts)
# ----- output -----
x = layers.Conv2D(3, kernel_size=1, padding='same')(x)
model = tf.keras.models.Model([x_input, x_ts_input], x)
return model
model = make_model()
- tf.keras.models.Model([x_input, x_ts_input], x) : ์ ๋ ฅ์ผ๋ก x_input ๊ณผ x_ts_input ๋๊ฐ์ ์ ๋ ฅ์ ๋ฐ๊ณ , ์ถ๋ ฅ์ผ๋ก x๋ฅผ ๋ฐํํ๋ ๋ชจ๋ธ์ ์ ์ํ๋ค.
def mse_loss(y_true, y_pred):
return tf.reduce_mean(tf.square(y_true - y_pred))
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0008)
loss_func = mse_loss
model.compile(loss=loss_func, optimizer=optimizer)
์ผ๋ฐ์ ์ผ๋ก ์ด๋ฏธ์ง ์์ฑ ์์ ์ ์ ํฉํ ์์ค ํจ์๋ก๋ ํ๊ท ์ ๊ณฑ ์ค์ฐจ(Mean Squared Error, MSE) ๋ก์,
์์ฑ๋ ์ด๋ฏธ์ง์ ๋ชฉํ ์ด๋ฏธ์ง ์ฌ์ด์ ํฝ์ ๋ณ ์ฐจ์ด์ ์ ๊ณฑ์ ํ๊ท ํ์ฌ ๊ณ์ฐํ๋ค. ์ด๋, ํ๋ จ ๊ณผ์ ์์ y_true์ y_pred๊ฐ ์์คํจ์๋ก ์๋์ผ๋ก ์ ๋ฌ๋๋ค.
Predict the result
๋ ธ์ด์ฆ ์ด๋ฏธ์ง๋ก๋ถํฐ ์ค์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๊ณผ์ ์ ๋ํ๋ด๊ณ ์๋ค.
def predict(x_idx=None):
x = np.random.normal(size=(32, IMG_SIZE, IMG_SIZE, 3))
for i in trange(timesteps):
t = i
x = model.predict([x, np.full((32), t)], verbose=0)
show_examples(x)
predict()
- ๋ชจ๋ธ์ ์ ๋ ฅ๋ ์ด๊ธฐ ์ ๋ ฅ ์ด๋ฏธ์ง x๋ฅผ ์ ๊ท ๋ถํฌ๋ฅผ ๋ฐ๋ฅด๋ ๋์๋ก ์์ฑํ๋ค.
- timesteps ์ ์์ธกํ๋ ๋์ ์งํํ ์๊ฐ ๋จ๊ณ์ด๋ค.
- ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ ๋ ฅ ์ด๋ฏธ์ง x์ ํ์ฌ ์๊ฐ t๋ฅผ ์ด์ฉํด ๋ค์ ์๊ฐ ๋จ๊ณ์ ์ด๋ฏธ์ง๋ฅผ ์์ธกํ๋ค.
- ์ด ๋, np.full((32), t) ๋ ํฌ๊ธฐ๊ฐ 32์ด๊ณ ๋ชจ๋ ์์๊ฐ t๋ก ์ฑ์์ง ๋ฐฐ์ด์ด๋ค. ๋ชจ๋ ๋ฐฐ์น ์ํ์ ๋์ผํ ์๊ฐ ๋จ๊ณ t๋ฅผ ์ ์ฉํ๊ธฐ ์ํด ์ฌ์ฉ๋๋ค.
- ๊ฒฐ๋ก ์ ์ผ๋ก ์ด๊ธฐ ๋๋ค์ด๋ฏธ์ง์์ t๋ฒ ๋ ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํ ๊ฒฐ๊ณผ๊ฐ ์ถ๋ ฅ๋๋ค.
Split the dataset into train and test sets
์ด๋ฏธ์ง ์์ฑ ๊ณผ์ ์์ ์์ธก ์คํ ์ ์ํํ๊ณ , ๋งค ๋ ๋ฒ์งธ ์๊ฐ ๋จ๊ณ๋ง๋ค ์์ฑ๋ ์ด๋ฏธ์ง๋ฅผ ์๊ฐํํ๋ ํจ์์ด๋ค.
def predict_step():
xs = []
x = np.random.normal(size=(8, IMG_SIZE, IMG_SIZE, 3))
for i in trange(timesteps):
t = i
x = model.predict([x, np.full((8), t)], verbose=0)
if i % 2 == 0:
xs.append(x[0])
plt.figure(figsize=(20, 2))
for i in range(len(xs)):
plt.subplot(1, len(xs), i+1)
plt.imshow(cvtImg(xs[i]))
plt.title(f'{i}')
plt.axis('off')
predict_step()
Training model
def train_one(x_img):
t = generate_ts(BATCH_SIZE) # Generate random timesteps for the batch
x_a, x_b = forward_noise(x_img, t) # Generate noisy images
x_ts = np.array(t).reshape(-1, 1) # Reshape timesteps for model input
loss = model.train_on_batch([x_a, x_ts], x_b) # single gradient update on a single batch of data
return loss
- ๋ชจ๋ธ์ ํ ๋ฐฐ์น(batch)๋งํผ ํ์ต์ํค๊ณ ๊ทธ ์์ค ๊ฐ์ ๋ฐํํ๋ ํจ์์ด๋ค.
- ์ด๋ฅผ ์ํด์ forward_noise ํจ์๋ก๋ถํฐ ๋ ธ์ด์ฆ๊ฐ ์ถ๊ฐ๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๊ณ , ํด๋น ์ด๋ฏธ์ง์ ํ์์คํ ์ ์ ๋ ฅ์ผ๋ก ํ์ฌ ๋ชจ๋ธ์ ํ์ต์ํจ๋ค.
- ์ฆ ๋ชจ๋ธ์ Xt ์์ ์ ์ด๋ฏธ์ง x_a ๋ฅผ ๊ฐ์ง๊ณ ๋ ธ์ด์ฆ๋ฅผ ์ถ๊ฐํ์ฌ Xt+1 ์์ ์ ์ด๋ฏธ์ง๋ฅผ ์์ธกํ๋ฉฐ, ์ค์ Xt+1 ์ด๋ฏธ์ง์ ground truth ์ธ x_b ์์ ๋น๊ต๋ฅผ ํตํด ์ต์ข ์์ค์ ๊ตฌํด๋ด๊ฒ ๋๋ค.
Model.train_on_batch( x, y=None, sample_weight=None, class_weight=None, return_dict=False )
def train(R=50):
bar = trange(R)
total = 100
for i in bar:
for j in range(total):
# X_train ๋ฐฐ์ด์์ ์์์ ์ธ๋ฑ์ค๋ฅผ ์ ํํ์ฌ BATCH_SIZE ๋งํผ์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ฅผ ๊ฐ์ ธ์ด
x_img = X_train[np.random.randint(len(X_train), size=BATCH_SIZE)]
loss = train_one(x_img)
pg = (j / total) * 100
if j % 5 == 0:
bar.set_description(f'loss: {loss:.5f}, p: {pg:.2f}%')
- train() ํจ์๋ ํ๋ฒ์ ์ํฌํฌ๋ฅผ ์คํ.
- trange(R) : ์งํ ์ํฉ์ ํ์ํ๋ ์งํ ๋ง๋๋ฅผ ์์ฑ
- ๊ฐ ์ํฌํฌ๋ง๋ค total ๋ฒ์ ๋ฐ๋ณต์ ์ํ. ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ์ฌ๋ฌ ๋ฐฐ์น๋ก ๋๋์ด ์ฒ๋ฆฌํ๋ ๊ณผ์ ์ ์๋ฏธ.
- train_one ํจ์๋ฅผ ํธ์ถํ์ฌ ํ์ฌ ๋ฐฐ์น์ ์ ๋ ฅ ์ด๋ฏธ์ง x_img๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ๋ จํ๊ณ ์์ค๊ฐ์ ๋ฐํ๋ฐ์
for _ in range(10):
train()
model.optimizer.learning_rate = max(0.000001, model.optimizer.learning_rate * 0.9)
predict()
predict_step()
plt.show()
๊ฒฐ๊ณผ
Ref
https://www.tensorflow.org/tutorials/generative/generate_images_with_stable_diffusion
'CS > ์ธ๊ณต์ง๋ฅ' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
SVM ์ ํ์ฉํ ์คํธ ๋ถ๋ฅ๊ธฐ ( Spam Classification via SVM ) (1) | 2024.10.17 |
---|---|
๋์ด๋ธ ๋ฒ ์ด์ฆ๋ฅผ ์ฌ์ฉํ ์คํธ ๋ฉ์ผ ๋ถ๋ฅ๊ธฐ (Spam Classification via Naรฏve Bayes) (3) | 2024.10.17 |
Simple Diffusion Image generate Model (0) | 2024.06.21 |
[TensorFlow Keras] ์๊ธ์จ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ GAN model ๋ง๋ค๊ธฐ (0) | 2024.05.29 |
Generative Adversarial Network (1) | 2024.05.28 |