EM算法的思想利用到了极大似然法,首先必须对极大似然有所了解。
极大似然估计法
极大似然估计,简单来说就是通过抽取一部分样本,反推个体分布规律中的参数。比如从一个班抽取一部分同学,统计其身高,反推实际的高斯分布中的参数如均值$\mu$ 和标准差$\theta$。一般步骤就是:
(1)写出似然函数;
(2)对似然函数取对数,并整理;
(3)求导数,令导数为0,得到似然方程;
(4)解似然方程,得到的参数即为所求;
详细介绍如下:
EM算法
期望最大算法是一种从不完全数据或有数据丢失的数据集(存在隐含变量)中求解概率模型参数的最大似然估计方法。
EM的算法流程:
初始化分布参数θ;
重复以下步骤直到收敛:
E步骤:根据参数初始值或上一次迭代的模型参数来计算出隐性变量的后验概率,其实就是隐性变量的期望。作为隐藏变量的现估计值:
M步骤:将似然函数最大化以获得新的参数值:
这个不断的迭代,就可以得到使似然函数L(θ)最大化的参数θ了。那就得回答刚才的第二个问题了,它会收敛吗?
感性的说,因为下界不断提高,所以极大似然估计单调增加,那么最终我们会到达最大似然估计的最大值。理性分析的话,就会得到下面的东西:
具体如何证明的,看推导过程参考:Andrew Ng《The EM algorithm》
http://www.cnblogs.com/jerrylead/archive/2011/04/06/2006936.html
EM算法用途
EM的应用
EM算法有很多的应用,最广泛的就是GMM混合高斯模型、聚类、HMM等等。具体可以参考JerryLead的cnblog中的Machine Learning专栏:
(EM算法)The EM Algorithm
混合高斯模型(Mixtures of Gaussians)和EM算法
K-means聚类算法
使用 EM工具包
直接使用sklearn中的工具包,获得GMM模型,先fit拟合,然后predict输出结果。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
| import pandas as pd import csv import matplotlib.pyplot as plt import seaborn as sns from sklearn.mixture import GaussianMixture from sklearn.preprocessing import StandardScaler
data_ori = pd.read_csv('./heros.csv', encoding = 'gb18030') features = [u'最大生命',u'生命成长',u'初始生命',u'最大法力', u'法力成长',u'初始法力',u'最高物攻',u'物攻成长',u'初始物攻',u'最大物防',u'物防成长',u'初始物防', u'最大每5秒回血', u'每5秒回血成长', u'初始每5秒回血', u'最大每5秒回蓝', u'每5秒回蓝成长', u'初始每5秒回蓝', u'最大攻速', u'攻击范围'] data = data_ori[features]
plt.rcParams['font.sans-serif']=['SimHei'] plt.rcParams['axes.unicode_minus']=False
corr = data[features].corr() plt.figure(figsize=(14,14))
sns.heatmap(corr, annot=True) plt.show()
features_remain = [u'最大生命', u'初始生命', u'最大法力', u'最高物攻', u'初始物攻', u'初始物攻', u'最大物防', u'初始物防', u'最大每5秒回血', u'最大每5秒回蓝', u'初始每5秒回蓝', u'最大攻速', u'攻击范围'] data = data_ori[features_remain] data[u'最大攻速'] = data[u'最大攻速'].apply(lambda x: float(x.strip('%'))/100) data[u'攻击范围']=data[u'攻击范围'].map({'远程':1,'近战':0})
ss = StandardScaler() data = ss.fit_transform(data)
gmm = GaussianMixture(n_components=5, covariance_type='full') gmm.fit(data)
prediction = gmm.predict(data) print(prediction)
data_ori.insert(0, '分组', prediction) data_ori.to_csv('./hero_out.csv', index=False, sep=',')
|