k-means聚类算法
聚类分析算法的一个java代码,我的项目中应用了这个代码。
package com.methol.util;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
//K-means算法实现
public class KMeans {
// 聚类的数目
public final static int ClassCount = 4;
// 样本数目(测试集)
public static int InstanceNumber = 100;
// 样本属性数目(测试)
public final static int FieldCount = 1;
// 设置异常点阈值参数(每一类初始的最小数目为InstanceNumber/ClassCount^t)
public final static double t = 2.0;
// 存放数据的矩阵
public static double[][] data;
// 每个类的均值中心
public static double[][] classData;
// 噪声集合索引
public static ArrayList<Integer> noises;
// 存放每次变换结果的矩阵
public static ArrayList<ArrayList<Integer>> result;
//存放每个属性的最大值
public static double[] classmax = new double[FieldCount];
// 构造函数,初始化
public KMeans() {
// 最后一位用来储存结果
data = new double[InstanceNumber][FieldCount + 1];
classData = new double[ClassCount][FieldCount];
result = new ArrayList<ArrayList<Integer>>(ClassCount);
noises = new ArrayList<Integer>();
}
/**
* 主函数入口 测试集的文件名称为“测试集.data”,其中有1000*57大小的数据 每一行为一个样本,有57个属性 主要分为两个步骤 1.读取数据
* 2.进行聚类 最后统计运行时间和消耗的内存
*/
public static void main(String[] args) {
// TODO Auto-generated method stub
long startTime = System.currentTimeMillis();
KMeans cluster = new KMeans();
cluster.InstanceNumber = 100;
// 读取数据
//cluster.readData("D:/test.txt");
//随机产生数据
for (int i = 0; i < InstanceNumber; i++) {
data[i][0] = (double) Math.random();
data[i][0] = data[i][0] * 100;
System.out.println(data[i][0]);
}
// 聚类过程
cluster.cluster();
// 输出结果
cluster.printResult("clusterResult.data");
long endTime = System.currentTimeMillis();
System.out.println("Total Time:" + (endTime - startTime) + "ms");
System.out.println("Memory Consuming:"
+ (double) (Runtime.getRuntime().totalMemory() - Runtime
.getRuntime().freeMemory()) / 1000000 + "MB");
System.out.println("聚类中心:");
for (int i = 0; i < ClassCount; i++) {
System.out.println(classData[i][0] * classmax[0]);
//data[i][0] = (double) (Math.random()*100);
}
for (ArrayList<Integer> i : result) {
for (Integer integer : i) {
System.out.print(integer + "\t");
}
System.out.println("数目:" + i.size());
}
for (ArrayList<Integer> i : result) {
for (Integer integer : i) {
System.out.print(data[integer][0] + "\t");
}
System.out.println("数目:" + i.size());
}
System.out.println("noises:");
for (Integer noise : noises) {
System.out.println(noise);
}
}
/**
* 读取测试集的数据
*
* @param trainingFileName 测试集文件名
*/
public void readData(String trainingFileName) {
try {
FileReader fr = new FileReader(trainingFileName);
BufferedReader br = new BufferedReader(fr);
// 存放数据的临时变量
String lineData = null;
String[] splitData = null;
int line = 0;
// 按行读取
while (br.ready()) {
// 得到原始的字符串
lineData = br.readLine();
splitData = lineData.split(",");
// 转化为数据
// System.out.println("length:"+splitData.length);
if (splitData.length > 1) {
for (int i = 0; i < splitData.length; i++) {
// System.out.println(splitData[i]);
// System.out.println(splitData[i].getClass());
if (splitData[i].startsWith("Iris-setosa")) {
data[line][i] = (double) 1.0;
} else if (splitData[i].startsWith("Iris-versicolor")) {
data[line][i] = (double) 2.0;
} else if (splitData[i].startsWith("Iris-virginica")) {
data[line][i] = (double) 3.0;
} else { // 将数据截取之后放进数组
data[line][i] = Double.parseDouble(splitData[i]);
}
}
line++;
}
}
System.out.println(line);
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* 聚类过程,主要分为两步 1.循环找初始点 2.不断调整直到分类不再发生变化
*/
public void cluster() {
// 数据标准化处理
normalize();
// 标记是否需要重新找初始点
boolean needUpdataInitials = true;
// 找初始点的迭代次数
int times = 1;
// 找初始点
while (needUpdataInitials) {
needUpdataInitials = false;
result.clear();
//System.out.println("Find Initials Iteration" + (times++) + "time(s)");
// 一次找初始点的尝试和根据初始点的分类
findInitials();
firstClassify();
// 如果某个分类的数目小于特定的阈值,则认为这个分类中的所有样本都是噪声点
// 需要重新找初始点
for (int i = 0; i < result.size(); i++) {
if (result.get(i).size() < InstanceNumber
/ Math.pow(ClassCount, t)) {
needUpdataInitials = true;
noises.addAll(result.get(i));
}
}
}
// 找到合适的初始点后
// 不断的调整均值中心和分类,直到不再发生任何变化
Adjust();
// //把结果存入数组answer中
// for (int i = 0; i < ClassCount; i++) {
// KMeans.answer[i] = classData[i][0];
// }
}
/**
* 对数据进行归一化 1.找每一个属性的最大值 2.对某个样本的每个属性除以其最大值
*/
public void normalize() {
// 找最大值
for (int i = 0; i < InstanceNumber; i++) {
for (int j = 0; j < FieldCount; j++) {
if (data[i][j] > classmax[j])
classmax[j] = data[i][j];
}
}
// 归一化
for (int i = 0; i < InstanceNumber; i++) {
for (int j = 0; j < FieldCount; j++) {
data[i][j] = data[i][j] / classmax[j];
}
}
}
// 关于初始向量的一次找寻尝试
public void findInitials() {
// a,b为标志距离最远的两个向量的索引
int i, j, a, b;
i = j = a = b = 0;
// 最远距离
double maxDis = 0;
// 已经找到的初始点个数
int alreadyCls = 2;
// 存放已经标记为初始点的向量索引
ArrayList<Integer> initials = new ArrayList<Integer>();
// 从两个开始
for (; i < InstanceNumber; i++) {
// 噪声点
if (noises.contains(i))
continue;
// long startTime = System.currentTimeMillis();
j = i + 1;
for (; j < InstanceNumber; j++) {
// 噪声点
if (noises.contains(j))
continue;
// 找出最大的距离并记录下来
double newDis = calDis(data[i], data[j]);
if (maxDis < newDis) {
a = i;
b = j;
maxDis = newDis;
}
}
// long endTime = System.currentTimeMillis();
// System.out.println(i +
// "Vector Caculation Time:"+(endTime-startTime)+"ms");
}
// 将前两个初始点记录下来
initials.add(a);
initials.add(b);
classData[0] = data[a];
classData[1] = data[b];
// 在结果中新建存放某样本索引的对象,并把初始点添加进去
ArrayList<Integer> resultOne = new ArrayList<Integer>();
ArrayList<Integer> resultTwo = new ArrayList<Integer>();
resultOne.add(a);
resultTwo.add(b);
result.add(resultOne);
result.add(resultTwo);
// 找到剩余的几个初始点
while (alreadyCls < ClassCount) {
i = j = 0;
double maxMin = 0;
int newClass = -1;
// 找最小值中的最大值
for (; i < InstanceNumber; i++) {
double min = 0;
double newMin = 0;
// 找和已有类的最小值
if (initials.contains(i))
continue;
// 噪声点去除
if (noises.contains(i))
continue;
for (j = 0; j < alreadyCls; j++) {
newMin = calDis(data[i], classData[j]);
if (min == 0 || newMin < min)
min = newMin;
}
// 新最小距离较大
if (min > maxMin) {
maxMin = min;
newClass = i;
}
}
// 添加到均值集合和结果集合中
// System.out.println("NewClass"+newClass);
initials.add(newClass);
//System.err.println("newClass:"+newClass);
//System.err.println("alreadyCls:"+alreadyCls);
classData[alreadyCls++] = data[newClass];
ArrayList<Integer> rslt = new ArrayList<Integer>();
rslt.add(newClass);
result.add(rslt);
}
}
// 第一次分类
public void firstClassify() {
// 根据初始向量分类
for (int i = 0; i < InstanceNumber; i++) {
double min = 0f;
int clsId = -1;
for (int j = 0; j < classData.length; j++) {
// 欧式距离
double newMin = calDis(classData[j], data[i]);
if (clsId == -1 || newMin < min) {
clsId = j;
min = newMin;
}
}
// 本身不再添加
if (!result.get(clsId).contains(i))
result.get(clsId).add(i);
}
}
// 迭代分类,直到各个类的数据不再变化
public void Adjust() {
// 记录是否发生变化
boolean change = true;
// 循环的次数
int times = 1;
while (change) {
// 复位
change = false;
//System.out.println("Adjust Iteration" + (times++) + "time(s)");
// 重新计算每个类的均值
for (int i = 0; i < ClassCount; i++) {
// 原有的数据
ArrayList<Integer> cls = result.get(i);
// 新的均值
double[] newMean = new double[FieldCount];
// 计算均值
for (Integer index : cls) {
for (int j = 0; j < FieldCount; j++)
newMean[j] += data[index][j];
}
for (int j = 0; j < FieldCount; j++)
newMean[j] /= cls.size();
if (!compareMean(newMean, classData[i])) {
classData[i] = newMean;
change = true;
}
}
// 清空之前的数据
for (ArrayList<Integer> cls : result)
cls.clear();
// 重新分配
for (int i = 0; i < InstanceNumber; i++) {
double min = 0f;
int clsId = -1;
for (int j = 0; j < classData.length; j++) {
double newMin = calDis(classData[j], data[i]);
if (clsId == -1 || newMin < min) {
clsId = j;
min = newMin;
}
}
data[i][FieldCount] = clsId;
result.get(clsId).add(i);
}
// 测试聚类效果(训练集)
// for(int i = 0;i < ClassCount;i++){
// int positives = 0;
// int negatives = 0;
// ArrayList<Integer> cls = result.get(i);
// for(Integer instance:cls)
// if (data[instance][FieldCount - 1] == 1f)
// positives ++;
// else
// negatives ++;
// System.out.println(" " + i + " Positive: " + positives +
// " Negatives: " + negatives);
// }
// System.out.println();
}
}
/**
* 计算a样本和b样本的欧式距离作为不相似度
*
* @param aVector 样本a
* @param bVector 样本b
* @return 欧式距离长度
*/
private double calDis(double[] aVector, double[] bVector) {
double dis = 0;
int i = 0;
/* 最后一个数据在训练集中为结果,所以不考虑 */
for (; i < aVector.length; i++)
dis += Math.pow(bVector[i] - aVector[i], 2);
dis = Math.pow(dis, 0.5);
return (double) dis;
}
/**
* 判断两个均值向量是否相等
*
* @param a 向量a
* @param b 向量b
* @return 相等返回true
*/
private boolean compareMean(double[] a, double[] b) {
if (a.length != b.length)
return false;
for (int i = 0; i < a.length; i++) {
if (a[i] > 0 && b[i] > 0 && a[i] != b[i]) {
return false;
}
}
return true;
}
/**
* 将结果输出到一个文件中
*
* @param fileName 文件名
*/
public void printResult(String fileName) {
FileWriter fw = null;
BufferedWriter bw = null;
try {
fw = new FileWriter(fileName);
bw = new BufferedWriter(fw);
// 写入文件
for (int i = 0; i < InstanceNumber; i++) {
bw.write(String.valueOf(data[i][FieldCount]).substring(0, 1));
bw.newLine();
}
// 统计每类的数目,打印到控制台
for (int i = 0; i < ClassCount; i++) {
System.out.println("第" + (i + 1) + "类数目: "
+ result.get(i).size());
}
} catch (IOException e) {
e.printStackTrace();
} finally {
// 关闭资源
if (bw != null)
try {
bw.close();
} catch (IOException e) {
e.printStackTrace();
}
if (fw != null)
try {
fw.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}