K-Means++代码实现
数据集（下载地址）：
https://download.csdn.net/download/qq_43629083/87246495
import pandas as pd
import numpy as np
import random
import math
%matplotlib inline
from matplotlib import pyplot as plt# 按文件名读取整个文件
# Read the whole CSV file (path relative to the working directory) into a DataFrame.
data = pd.read_csv(filepath_or_buffer='data.csv')
class MyKmeansPlusPlus:
    """K-Means clustering with K-Means++ seeding.

    Args:
        k: number of clusters.
        max_iter: maximum number of Lloyd iterations run by ``fit``.
    """

    def __init__(self, k, max_iter=10):
        self.k = k
        # Maximum number of iterations (inclusive cap).
        self.max_iter = max_iter
        # Training samples as a 2-D NumPy array (set by fit()).
        self.data_set = None
        # Cluster index of every sample (set by fit()).
        self.labels = None
        # Final cluster centers, shape (k, n_features) (set by fit()).
        self.centers = None

    def euler_distance(self, point1, point2):
        """Return the Euclidean distance between two points.

        The method name is kept for backward compatibility; "euler" means
        Euclidean here.
        """
        return math.sqrt(sum(math.pow(a - b, 2) for a, b in zip(point1, point2)))

    def nearest_distance(self, point, cluster_centers):
        """Return the distance from ``point`` to its closest center.

        Returns ``math.inf`` when ``cluster_centers`` is empty.
        """
        return min(
            (self.euler_distance(point, center) for center in cluster_centers),
            default=math.inf,
        )

    def get_centers(self):
        """Pick ``self.k`` initial centers with K-Means++ seeding.

        The first center is a uniformly random sample. Each further center is
        drawn with probability proportional to D(x)**2, the squared distance
        from sample x to its nearest already-chosen center. Samples equal to a
        chosen center therefore have weight 0 and are never re-picked, unless
        every sample coincides with a chosen center (then the draw is uniform).

        Returns:
            ndarray of shape (k, n_features) with the initial centers.
        """
        n_samples, n_features = np.shape(self.data_set)
        cluster_centers = np.zeros((self.k, n_features))
        # First center: a uniformly random sample. All draws use np.random so
        # a single np.random.seed() makes the seeding reproducible.
        cluster_centers[0] = self.data_set[np.random.randint(0, n_samples)]
        for i in range(1, self.k):
            # D(x)**2 of every sample against the centers chosen so far.
            weights = np.array([
                self.nearest_distance(point, cluster_centers[:i]) ** 2
                for point in self.data_set
            ])
            cumulative = np.cumsum(weights)
            total = cumulative[-1]
            if total == 0:
                index = np.random.randint(0, n_samples)
            else:
                # Roulette-wheel draw: the first index whose cumulative weight
                # exceeds a uniform threshold in [0, total). Because
                # threshold < cumulative[-1], the index is always in range and
                # its weight is strictly positive.
                threshold = total * np.random.random()
                index = int(np.searchsorted(cumulative, threshold, side='right'))
            cluster_centers[i] = self.data_set[index]
        return cluster_centers

    def get_closest_index(self, point, centers):
        """Return the index of the center nearest to ``point``.

        Ties go to the lowest index; returns -1 if ``centers`` is empty.
        """
        best_index, best_dist = -1, math.inf
        for i, center in enumerate(centers):
            dist = self.euler_distance(center, point)
            if dist < best_dist:
                best_index, best_dist = i, dist
        return best_index

    def update_centers(self, old_centers=None):
        """Return a list with the mean of the samples assigned to each cluster.

        An empty cluster has no mean. It keeps its center from ``old_centers``
        when given; otherwise it gets a NaN vector with the right length.
        """
        data = np.asarray(self.data_set)
        labels = np.asarray(self.labels)
        n_features = data.shape[1]
        centers = []
        for i in range(self.k):
            members = data[labels == i]
            if len(members):
                centers.append(np.mean(members, axis=0))
            elif old_centers is not None:
                centers.append(np.array(old_centers[i], dtype=float))
            else:
                centers.append(np.full(n_features, np.nan))
        return centers

    def stop_iter(self, old_centers, centers, step):
        """Return True when the centers stopped moving or ``step`` reached ``max_iter``."""
        if step >= self.max_iter:
            return True
        return np.array_equal(old_centers, centers)

    def fit(self, data_set, plot=True):
        """Cluster ``data_set`` with K-Means++ seeding followed by Lloyd iterations.

        Args:
            data_set: pandas DataFrame (or 2-D array-like) with one sample per
                row. A ``'labels'`` column, if present, is dropped before
                clustering.
            plot: when True (default), scatter-plot the first two feature
                columns coloured by cluster, with the final centers as triangles.

        After fitting, ``self.labels`` holds each sample's cluster index and
        ``self.centers`` the final centers.

        Raises:
            ValueError: if ``self.k`` < 1 or ``data_set`` has no rows.
        """
        if isinstance(data_set, pd.DataFrame):
            # 'labels' is not a feature; drop it when present.
            data_set = data_set.drop(columns=['labels'], errors='ignore')
        self.data_set = np.asarray(data_set, dtype=float)
        n_samples = len(self.data_set)
        if self.k < 1:
            raise ValueError("k must be at least 1")
        if n_samples == 0:
            raise ValueError("data_set must contain at least one sample")
        self.labels = np.zeros(n_samples, dtype=int)

        centers = self.get_centers()
        step = 0
        while True:
            old_centers = np.copy(centers)
            step += 1
            print("current iteration: ", step)
            print("current centers: ", old_centers)
            # Assignment step: each sample joins its nearest current center.
            for i, point in enumerate(self.data_set):
                self.labels[i] = self.get_closest_index(point, old_centers)
            # Update step: move each center to the mean of its members.
            centers = np.array(self.update_centers(old_centers))
            if self.stop_iter(old_centers, centers, step):
                break

        self.centers = centers
        if plot:
            self._plot(centers)

    def _plot(self, centers):
        """Scatter-plot the first two features coloured by cluster and mark ``centers``."""
        plt.figure()
        handles = []
        for i in range(self.k):
            mask = self.labels == i
            handles.append(plt.scatter(self.data_set[mask, 0], self.data_set[mask, 1]))
        plt.scatter(centers[:, 0], centers[:, 1], marker='^', edgecolor='black', s=128)
        plt.title('labeled data')
        plt.xlabel('V1')
        plt.ylabel('V2')
        plt.legend(handles, ['label%d' % i for i in range(self.k)])
        plt.show()
# Cluster the loaded data into 3 groups; fit() prints each iteration's
# centers and plots the clustered samples.
myKmeansPP = MyKmeansPlusPlus(k=3)
myKmeansPP.fit(data_set=data)
current iteration: 1
current centers:
[[55.97659 75.71833 ]
[43.75808 67.45812 ]
[71.72321 -7.872746]]
current iteration: 2
current centers:
[[55.83404759 70.21560931]
[30.35261288 47.71518861]
[50.15798861 -5.34769581]]
current iteration: 3
current centers:
[[47.66230967 65.1238036 ]
[22.93488 39.05383154]
[52.52023009 -6.18734425]]
current iteration: 4
current centers:
[[42.96329079 61.70702396]
[12.28521822 20.36196405]
[63.73622886 -9.02914858]]
current iteration: 5
current centers:
[[ 40.8388755 59.95703427]
[ 9.62033389 11.15366963]
[ 69.77599323 -10.09654797]]