

Implementing the k-means algorithm (k-means clustering)


What is K-means clustering?

K-means is an unsupervised learning algorithm that partitions the given data into K groups. It works by repeatedly updating the cluster centers: each data point is assigned to its nearest center (centroid), and each centroid is then recomputed from the points assigned to it.
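The quantity this procedure reduces can be written explicitly. In a standard formulation (the notation is introduced here: C_k is the set of points in cluster k and mu_k is its centroid), K-means minimizes the within-cluster sum of squared distances:

$$
J = \sum_{k=1}^{K} \sum_{x_i \in C_k} \lVert x_i - \mu_k \rVert^2,
\qquad
\mu_k = \frac{1}{|C_k|} \sum_{x_i \in C_k} x_i
$$

For fixed centroids, assigning every point to its nearest centroid cannot increase J, and for fixed assignments, replacing each centroid by the mean of its points cannot increase J either. This is why the alternating procedure settles after a finite number of steps, at a local (not necessarily global) minimum.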

 

How it works

 

  1. Randomly choose K initial centroids.
  2. Assign each data point to its nearest centroid.
  3. Recompute the centroid of each cluster.
  4. Repeat steps 2 and 3 until the centroids no longer change or a specified number of iterations is reached.
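The four steps above can be sketched in a few lines of NumPy. This is a minimal, hypothetical standalone function (kmeans_sketch is not part of the assignment code later in this post). It uses random initialization as in step 1, whereas the assignment itself uses fixed initial centroids read from a file, and it computes all point-to-centroid distances at once with broadcasting instead of looping over points.

```python
import numpy as np


def kmeans_sketch(X, k, max_iter=100, tol=1e-5, seed=0):
    """Minimal vectorized k-means following steps 1-4 (random initialization)."""
    rng = np.random.default_rng(seed)
    # Step 1: pick k distinct data points at random as the initial centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for it in range(1, max_iter + 1):
        # Step 2: distance from every point to every centroid, shape (n, k)
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = np.argmin(dists, axis=1)
        # Step 3: recompute each centroid as the mean of its assigned points
        new_centroids = centroids.copy()
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:  # an empty cluster keeps its previous centroid
                new_centroids[j] = members.mean(axis=0)
        # Step 4: stop once no centroid moved by more than tol
        shift = np.linalg.norm(new_centroids - centroids, axis=1)
        centroids = new_centroids
        if np.all(shift <= tol):
            break
    return centroids, labels, it


X = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]])
centroids, labels, n_iter = kmeans_sketch(X, k=2)
```

With this toy data the two left points and the two right points end up in separate clusters whatever pair of points is drawn as the initial centroids, and the centroids settle at (0, 0.5) and (5, 5.5) in some order.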

 

 

Goal

Implement K-means clustering.

 

Constraints

- The algorithm is considered to have converged when, between two consecutive iterations, every centroid moves by at most 1 × 10^-5.

- Libraries that provide K-means clustering must not be used; the algorithm has to be implemented from scratch for this assignment.

- Plot every step until the algorithm converges. For example, if the algorithm converges at step 6, figures for steps 1 through 6 must be provided, and each cluster must be clearly distinguished by a different color.
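The convergence rule in the first constraint can be expressed as a small predicate. This is a sketch; the name has_converged is hypothetical, and it compares each centroid's own displacement with the threshold. The class later in this post instead compares the sum of all centroid displacements with tol, which is a slightly stricter test.

```python
import numpy as np


def has_converged(old_centroids, new_centroids, tol=1e-5):
    """Return True if every centroid moved by at most tol (Euclidean distance)."""
    old = np.asarray(old_centroids, dtype=float)
    new = np.asarray(new_centroids, dtype=float)
    shifts = np.linalg.norm(new - old, axis=1)  # one displacement per centroid
    return bool(np.all(shifts <= tol))


print(has_converged([[0, 0], [1, 1]], [[0, 0], [1, 1 + 1e-6]]))  # True
print(has_converged([[0, 0], [1, 1]], [[0, 0], [1, 1.1]]))       # False
```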

 

 

๋ฐ์ดํ„ฐ์…‹ 

 

- Two clusters (K=2) are used, and the 2D point data comes from the file data_2d.csv.

- The initial centroid positions for the two clusters are taken from the file init_centroids.csv.

 

init_centroids.csv

3.323451927212228152e-01,1.798198342198527866e-01
2.241161287440249505e-01,2.219986906322291287e-01

 

data_2d.csv

0.000000000000000000e+00,-7.687164597386728637e-01,4.608603078297135447e-01
0.000000000000000000e+00,2.687847555392556487e+00,2.366960661575847169e+00
0.000000000000000000e+00,-2.013793555022345139e-01,4.704299346653586511e-01
0.000000000000000000e+00,6.084956800449090597e-01,1.225400029138742575e+00
0.000000000000000000e+00,-8.228190446259109336e-02,1.137218118753473339e+00
0.000000000000000000e+00,2.083069297959621036e+00,2.694482088909215811e+00
0.000000000000000000e+00,1.503019851143946983e+00,1.074847268552238111e+00
0.000000000000000000e+00,3.916620013534907185e-01,-2.874971661363743269e-01
0.000000000000000000e+00,3.213771110785266227e-01,1.296743009602315366e+00
0.000000000000000000e+00,5.912482577647957260e-01,1.267164122169239793e-01
0.000000000000000000e+00,1.150577634973361407e+00,-2.664038442463685374e-01
0.000000000000000000e+00,9.425866685920466503e-01,8.676624226337872337e-01
0.000000000000000000e+00,1.357806126580951567e+00,1.805471547458144421e+00
0.000000000000000000e+00,1.162919909687994968e+00,2.622430134800965540e+00
0.000000000000000000e+00,-9.786851243616156992e-02,1.012305814636828893e+00
0.000000000000000000e+00,8.577741746560831881e-01,1.031965247701047028e+00
0.000000000000000000e+00,6.834367317155296551e-01,1.578139963641977950e-02
0.000000000000000000e+00,1.543771853980922426e+00,1.750230549650776624e+00

 

Implementation

 

๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

import numpy as np
import matplotlib.pyplot as plt

 

๋ฐ์ดํ„ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

๋ฐ์ดํ„ฐ ํฌ์ธํŠธ ํŒŒ์ผ ์ฝ๊ธฐ

# Read coordinates from the data file
def read_data(file_path):
    data = []
    labels = []

    with open(file_path, 'r') as file:
        for line in file:
            if not line.strip():  # skip blank lines
                continue
            values = line.strip().split(',')
            # Skip the first value (cluster ID) and use only the second and third values as coordinates
            data.append([float(values[1]), float(values[2])])
            labels.append(int(float(values[0])))  # initial cluster label
    return np.array(data), np.array(labels)

 

  • Read the data from hw2_data_2d.csv. Each row has the form: cluster label, x coordinate, y coordinate.
  • Only the x and y coordinates are used as data; the first value (the cluster label) is stored separately.
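For a file laid out like this, NumPy's np.loadtxt can read all three columns at once and the label column can then be split off. A minimal sketch: read_data_loadtxt is a hypothetical name, and an in-memory io.StringIO holding the first two rows of the data above stands in for the CSV file so the example runs on its own.

```python
import io

import numpy as np


def read_data_loadtxt(source):
    """Read rows of 'label,x,y' and return (points of shape (n, 2), integer labels)."""
    raw = np.loadtxt(source, delimiter=',', ndmin=2)  # shape (n, 3)
    points = raw[:, 1:3]            # x and y columns
    labels = raw[:, 0].astype(int)  # first column: cluster label
    return points, labels


sample = io.StringIO(
    "0.000000000000000000e+00,-7.687164597386728637e-01,4.608603078297135447e-01\n"
    "0.000000000000000000e+00,2.687847555392556487e+00,2.366960661575847169e+00\n"
)
X, labels = read_data_loadtxt(sample)
print(X.shape, labels)  # (2, 2) [0 0]
```

The ndmin=2 argument keeps the result two-dimensional even if the file contains only a single row.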

 

Reading the centroid coordinate file

The centroids are stored as a NumPy array of the form [(x1,y1),(x2,y2)].

 

# Returns a float array of shape (K, 2): [(x1, y1), (x2, y2)]
def read_centroid(file_path):
    return np.loadtxt(file_path, delimiter=',', dtype=float)

 

 

Kmeans ํด๋ž˜์Šค ์ƒ์„ฑ

Kmeans ์•Œ๊ณ ๋ฆฌ์ฆ˜ ๊ตฌํ˜„์„ ์œ„ํ•œ ํด๋ž˜์Šค๋ฅผ ๋งŒ๋“ ๋‹ค. ๋‹ค์Œ์€ ์ „์ฒด ํด๋ž˜์Šค ์ฝ”๋“œ์ด๋‹ค.

# Euclidean distance between two points
def euclidean_distance(a, b):
    return np.sqrt(np.sum((a - b) ** 2))


# K-means clustering class
class KMeans:
    def __init__(self, labels, n_clusters, max_iter=300, tol=1e-5, init_centroids=None):
        self.n_clusters = n_clusters  # number of clusters (K=2)
        self.max_iter = max_iter  # maximum number of iterations
        self.tol = tol  # convergence threshold: stop when the total centroid shift falls below this
        self.centroids = init_centroids  # initial cluster centers [(x1,y1),(x2,y2)]
        self.labels = labels  # cluster each point belongs to


    def fit(self, X):

        # Step 0: plot the initial state
        self.plot_step(X, 0)

        # Main loop: alternate the assignment and update steps
        for i in range(self.max_iter):

            # Assign each data point to its nearest cluster centroid
            self.labels = np.array([self.closest_centroid(x) for x in X])

            # Save the previous centroids (to compare after the update)
            old_centroids = self.centroids.copy()

            # For each cluster, the new centroid is the mean of the points assigned to it
            for k in range(self.n_clusters):
                cluster_points = X[self.labels == k]
                if len(cluster_points) > 0:  # an empty cluster keeps its previous centroid
                    self.centroids[k] = np.mean(cluster_points, axis=0)  # update the centroid of cluster k

            # Total distance moved by all centroids in this iteration
            centroid_shift = np.sum(np.linalg.norm(self.centroids - old_centroids, axis=1))

            # Visualization: plot the result of this step
            self.plot_step(X, i + 1)

            # Convergence check
            if centroid_shift < self.tol:
                print(f"Converged after {i+1} iterations")
                break

    # Find the cluster centroid closest to a given data point
    def closest_centroid(self, x):
        distances = [euclidean_distance(x, centroid) for centroid in self.centroids]
        return np.argmin(distances)  # index of the nearest cluster


    # Visualize the clustering result at each step
    def plot_step(self, X, step):

        cluster_color_set = ['blue', 'red']

        # Plot each cluster's centroid and member points in its own color
        for k, centroid in enumerate(self.centroids):
            plt.scatter(centroid[0], centroid[1], c=cluster_color_set[k], marker='x', s=200, label=f'Centroid {k}')
            plt.scatter(X[self.labels == k][:, 0], X[self.labels == k][:, 1], c=cluster_color_set[k])

        plt.title(f"K-means Clustering (Step {step})")
        plt.legend()
        plt.show()

 

 

 

Initialization

    def __init__(self, labels, n_clusters, max_iter=300, tol=1e-5, init_centroids=None):
        self.n_clusters = n_clusters  # number of clusters (K=2)
        self.max_iter = max_iter  # maximum number of iterations
        self.tol = tol  # convergence threshold: stop when the total centroid shift falls below this
        self.centroids = init_centroids  # initial cluster centers [(x1,y1),(x2,y2)]
        self.labels = labels  # cluster label of each data point

 

  • n_clusters: number of clusters (K); set to 2 here.
  • max_iter: maximum number of iterations; defaults to 300.
  • tol: convergence threshold (1e-5); the algorithm is considered converged when the centroids move by less than this value in one iteration.
  • init_centroids: coordinates of the initial centroids (loaded from the CSV file).
  • labels: the cluster label of each data point.

 

 

Training the model

This method applies the K-means algorithm to the given data X.

    def fit(self, X):

        # Step 0: plot the initial state
        self.plot_step(X, 0)

        # Main loop: alternate the assignment and update steps
        for i in range(self.max_iter):

            # Assign each data point to its nearest cluster centroid
            self.labels = np.array([self.closest_centroid(x) for x in X])

            # Save the previous centroids (to compare after the update)
            old_centroids = self.centroids.copy()

            # For each cluster, the new centroid is the mean of the points assigned to it
            for k in range(self.n_clusters):
                cluster_points = X[self.labels == k]
                if len(cluster_points) > 0:  # an empty cluster keeps its previous centroid
                    self.centroids[k] = np.mean(cluster_points, axis=0)  # update the centroid of cluster k

            # Total distance moved by all centroids in this iteration
            centroid_shift = np.sum(np.linalg.norm(self.centroids - old_centroids, axis=1))

            # Visualization: plot the result of this step
            self.plot_step(X, i + 1)

            # Convergence check
            if centroid_shift < self.tol:
                print(f"Converged after {i+1} iterations")
                break

 

 

์ค‘์‹ฌ์  ์ดˆ๊ธฐํ™” ํ›„ ๋ฐ˜๋ณต ๊ณผ์ •:

  • 0๋‹จ๊ณ„์—์„œ๋Š” ์ฃผ์–ด์ง„ ์ดˆ๊ธฐ ํด๋Ÿฌ์Šคํ„ฐ๋ฅผ ๊ธฐ์ค€์œผ๋กœ ์‹œ๊ฐํ™”ํ•œ๋‹ค.
  • ๊ฐ ๋ฐ์ดํ„ฐ ํฌ์ธํŠธ์— ๋Œ€ํ•ด ๊ฐ€์žฅ ๊ฐ€๊นŒ์šด ํด๋Ÿฌ์Šคํ„ฐ๋ฅผ ํ• ๋‹นํ•˜๊ณ , ํด๋Ÿฌ์Šคํ„ฐ๋ณ„๋กœ ์ค‘์‹ฌ์ ์„ ์—…๋ฐ์ดํŠธํ•œ๋‹ค.
  • ์—…๋ฐ์ดํŠธ ๋œ ์ดํ›„์— ๊ฐ ๋‹จ๊ณ„๋ณ„๋กœ plot ๊ฒฐ๊ณผ๋ฅผ ์ถœ๋ ฅํ•œ๋‹ค.
  • ์ˆ˜๋ ด ์—ฌ๋ถ€ ํ™•์ธ: ๊ฐ ๋ฐ˜๋ณต ๋‹จ๊ณ„์—์„œ ์ค‘์‹ฌ์ ์˜ ์ด๋™๋Ÿ‰์ด tol ๊ฐ’๋ณด๋‹ค ์ž‘์œผ๋ฉด ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ์ข…๋ฃŒํ•˜๊ณ  ์ˆ˜๋ ดํ–ˆ๋‹ค๊ณ  ํŒ๋‹จํ•œ๋‹ค.
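To see why the loop stops, here is a tiny hand-checkable trace. It is a self-contained sketch: the lloyd_step helper and the four-point dataset are made up for illustration, but the helper performs the same assign-then-update step as fit and reports the same summed centroid shift used as the stopping test. Once the assignments stop changing, each centroid is recomputed from exactly the same points, so the shift becomes exactly 0 and the tol test passes.

```python
import numpy as np


def lloyd_step(X, centroids):
    """One assignment + update step; returns (labels, new_centroids, total_shift)."""
    # Assignment: index of the nearest centroid for every point
    dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    labels = np.argmin(dists, axis=1)
    # Update: mean of each cluster's points (empty clusters keep their centroid)
    new_centroids = centroids.copy()
    for k in range(len(centroids)):
        members = X[labels == k]
        if len(members) > 0:
            new_centroids[k] = members.mean(axis=0)
    # Same stopping quantity as fit: sum of the per-centroid displacements
    total_shift = np.sum(np.linalg.norm(new_centroids - centroids, axis=1))
    return labels, new_centroids, total_shift


X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
centroids = np.array([[0.0, 0.0], [10.0, 0.0]])
for step in (1, 2):
    labels, centroids, shift = lloyd_step(X, centroids)
    print(step, labels, centroids.tolist(), shift)
# 1 [0 0 1 1] [[0.0, 0.5], [10.0, 0.5]] 1.0
# 2 [0 0 1 1] [[0.0, 0.5], [10.0, 0.5]] 0.0
```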

 

 

To find the nearest cluster for each data point, the following functions are used to compute distances.

closest_centroid computes the distance from the point to every centroid and returns the index of the nearest one.

# Euclidean distance between two points (module-level helper used by closest_centroid)
def euclidean_distance(a, b):
    return np.sqrt(np.sum((a - b) ** 2))

    # Find the cluster centroid closest to a given data point
    def closest_centroid(self, x):
        distances = [euclidean_distance(x, centroid) for centroid in self.centroids]
        return np.argmin(distances)  # index of the nearest cluster
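Calling closest_centroid once per point is easy to read but loops in Python. A possible alternative, sketched here as a standalone function with the hypothetical name assign_all, builds the whole point-to-centroid distance matrix at once with NumPy broadcasting and takes the argmin along the centroid axis. It should return the same indices as applying closest_centroid to every row, since np.argmin returns the first minimum in both cases. The example uses the initial centroids from init_centroids.csv and three points from data_2d.csv.

```python
import numpy as np


def assign_all(X, centroids):
    """Return, for each row of X, the index of the nearest centroid."""
    X = np.asarray(X, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    # (n, 1, d) - (1, k, d) broadcasts to (n, k, d); the norm over d gives (n, k)
    dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    return np.argmin(dists, axis=1)


init_centroids = np.array([[3.323451927212228152e-01, 1.798198342198527866e-01],
                           [2.241161287440249505e-01, 2.219986906322291287e-01]])
pts = np.array([[-7.687164597386728637e-01, 4.608603078297135447e-01],
                [2.687847555392556487e+00, 2.366960661575847169e+00],
                [3.916620013534907185e-01, -2.874971661363743269e-01]])
print(assign_all(pts, init_centroids))  # [1 0 0]
```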

 

Visualizing the clustering result

This function visualizes the clustering result at each iteration. Using matplotlib's pyplot, it colors the points by cluster so that it is easy to see which cluster each coordinate belongs to.

    # Visualize the clustering result at each step
    def plot_step(self, X, step):

        cluster_color_set = ['blue', 'red']

        # Plot each cluster's centroid and member points in its own color
        for k, centroid in enumerate(self.centroids):
            plt.scatter(centroid[0], centroid[1], c=cluster_color_set[k], marker='x', s=200, label=f'Centroid {k}')
            plt.scatter(X[self.labels == k][:, 0], X[self.labels == k][:, 1], c=cluster_color_set[k])

        plt.title(f"K-means Clustering (Step {step})")
        plt.legend()
        plt.show()

 

Running

 

๋ฐ์ดํ„ฐ ํŒŒ์ผ์„ ๋กœ๋“œํ•˜๊ณ , K-means ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ์‹คํ–‰ํ•œ๋‹ค.

def main():
    # Data file paths
    data_file = './data/hw2_data_2d.csv'
    centroids_file = './data/hw2_2d_init_centroids.csv'

    X, labels = read_data(data_file)  # load the data points
    init_centroids = read_centroid(centroids_file)  # load the initial centroids

    # K-means clustering
    kmeans = KMeans(labels=labels, n_clusters=2, init_centroids=init_centroids)
    kmeans.fit(X)

if __name__ == '__main__':
    main()

 

Result

 

์ด 7๋ฒˆ์˜ ๋ฐ˜๋ณต ํ›„์— ํด๋Ÿฌ์Šคํ„ฐ์˜ ์ค‘์‹ฌ์  ์ขŒํ‘œ๊ฐ€ tol(1e-5) ๋ฏธ๋งŒ์œผ๋กœ ์›€์ง์ด๊ฒŒ ๋˜๋ฏ€๋กœ ํด๋Ÿฌ์Šคํ„ฐ๋ง์ด ์™„๋ฃŒ๋˜์—ˆ๋‹ค๊ณ  ํŒ๋‹จํ•œ๋‹ค.