在统计计算中,最大期望(EM)算法是在概率(probabilistic)模型中寻找参数最大似然估计或者最大后验估计的算法,其中概率模型依赖于无法观测的隐藏变量(Latent Variable)。最大期望经常用在机器学习和计算机视觉的数据聚类(Data Clustering)领域。最大期望算法经过两个步骤交替进行计算,第一步是计算期望(E),利用对隐藏变量的现有估计值,计算其最大似然估计值;第二步是最大化(M),最大化在 E 步上求得的最大似然值来计算参数的值。M 步上找到的参数估计值被用于下一个 E 步计算中,这个过程不断交替进行。
defreInit(self, K=2, init='random'): indices = np.random.choice(self.dataSet.shape[0], K, replace=False) self.centerPos = self.dataSet[indices] self.K = K
deftrain(self, n=1000, plt=None): for i inrange(n): # 对每个点,找到距离它最近的中心归类 classes = [[] for k inrange(self.K)] for data in self.dataSet: distanceWithCenter = np.linalg.norm(data - self.centerPos, ord=2, axis=1, keepdims=False) nearCenterI = np.argmin(distanceWithCenter) classes[nearCenterI].append(data) classes = [np.array(c) for c in classes] classes = np.array(classes)
# 对每个类,找出其新中心 self.centerPos = [c.mean(axis=0) for c in classes]
print(i)
if plt isnotNone: plt.subplot(1, 2, 2, frameon=False) plt.cla() plt.scatter(classes[0].T[0], classes[0].T[1], marker='o', c='r', alpha=0.7, label='red') plt.scatter(classes[1].T[0], classes[1].T[1], marker='o', c='b', alpha=0.7, label='blue') plt.scatter(classes[2].T[0], classes[2].T[1], marker='o', c='g', alpha=0.7, label='green') for center in self.centerPos: plt.scatter(center[0], center[1], marker='o', c='black', alpha=0.7, label='center') plt.legend() plt.pause(0.01)
deftrain(self, n=1000, plt=None): N = self.dataSet.shape[0] for i inrange(n): # E步,根据当前分布计算每个点属于各个分布的概率 gammas = [] for data in self.dataSet: sumGamma = 0 gamma = [] # 对每个类别,计算这个点的pdf m = 0 for mean, sigma in self.centerNormal: g = self.PiK[m] * multivariate_normal.pdf(data, mean=mean, cov=sigma) sumGamma += g gamma.append(g) m += 1 # 归一化概率 gamma = np.array(gamma) gamma /= sumGamma gammas.append(gamma) gammas = np.array(gammas) self.PiK = gammas.sum(axis=0) / gammas.sum()
# M步根据每个点的分布情况更新参数 mu_new = gammas.T.dot(self.dataSet) / np.array([self.PiK, self.PiK]).T / N sigmas_new = [] for j inrange(self.K): m = 0 sigma_new = np.array([[0.0, 0.0], [0.0, 0.0]]) for data in self.dataSet: sigma_new += (data - mu_new[j])[None].T.dot((data - mu_new[j])[None]) * gammas[m][j] m += 1 sigmas_new.append(sigma_new) sigmas_new /= self.PiK[:, None, None] * N
self.centerNormal = [(mu_new[m], sigmas_new[m]) for m inrange(self.K)]
# print(mu_new) # print(sigmas_new)
print('\r trun:%d/%d' % (i + 1, n), end='')
if plt isnotNone: m = 0 plt.subplot(1, 2, 2, frameon=False) plt.cla() for data in self.dataSet: plt.scatter(data[0], data[1], marker='o', color=gammas[m], alpha=0.7) m += 1 for k inrange(self.K): plt.scatter(self.centerNormal[k][0][0], self.centerNormal[k][0][1], marker='o', c='black', alpha=0.7) plt.pause(0.01)
:return data: shape(M, N), M 个 N 维服从高斯分布的样本 :return Gaussian: 高斯分布概率密度函数 """ mean = np.array(m) # or np.zeros(N) # 均值矩阵,每个维度的均值都为 m cov = np.array(sigma) # or np.eye(N) # 协方差矩阵,每个维度的方差都为 sigma
# 产生 N 维高斯分布数据 data = np.random.multivariate_normal(mean, cov, M,) # N 维数据高斯分布概率密度函数 Gaussian = multivariate_normal(mean=mean, cov=cov)