You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
230 lines
8.1 KiB
230 lines
8.1 KiB
import { isEqual, uniq } from '@antv/util';
|
|
import { getAllProperties } from './utils/node-properties';
|
|
import { oneHot, getDistance } from './utils/data-preprocessing';
|
|
import Vector from './utils/vector';
|
|
import { DistanceType } from './types';
|
|
/**
 * Look up the centroid vector for a node.
 *
 * @param distanceType distance type in use
 * @param allPropertiesWeight feature vectors, indexed like the node list
 * @param index position of the node whose vector becomes the centroid
 * @returns `allPropertiesWeight[index]` for the Euclidean distance type,
 *          otherwise a fresh empty array
 */
var getCentroid = function getCentroid(distanceType, allPropertiesWeight, index) {
  // Only the Euclidean distance type has a vector representation here.
  if (distanceType === DistanceType.EuclideanDistance) {
    return allPropertiesWeight[index];
  }
  return [];
};
|
|
/**
 * K-means clustering: groups the nodes of a graph into at most `k` clusters
 * according to the distance between their property feature vectors.
 *
 * Side effects: every node in `data.nodes` is annotated in place with
 * `originIndex` (its position in `data.nodes`) and `clusterId` (the cluster
 * index as a string). Any `clusterId` already present on a node is discarded.
 *
 * @param data graph data; `data.nodes` and `data.edges` each default to `[]`
 * @param k requested number of centroids (cluster centers), default 3; lowered
 *          to the node count / number of distinct feature vectors when larger
 * @param propertyKey name of the node field that holds the properties
 * @param involvedKeys keys that take part in the computation (passed to `oneHot`)
 * @param uninvolvedKeys keys excluded from the computation, default `['id']`
 * @param distanceType distance type, default Euclidean distance of node properties
 * @returns `{ clusters, clusterEdges }`. On the normal path `clusters` is an array
 *          of node arrays (index = cluster id) and `clusterEdges` holds one
 *          `{ source, target, count }` entry per distinct cluster pair. When
 *          clustering is not possible (a node lacks `propertyKey`, no feature
 *          vectors, or a non-positive `k`) the fallback
 *          `{ clusters: [{ id: "0", nodes }], clusterEdges: [] }` is returned.
 */
var kMeans = function kMeans(data, k, propertyKey, involvedKeys, uninvolvedKeys, distanceType) {
  if (k === void 0) {
    k = 3;
  }
  if (involvedKeys === void 0) {
    involvedKeys = [];
  }
  if (uninvolvedKeys === void 0) {
    uninvolvedKeys = ['id'];
  }
  if (distanceType === void 0) {
    distanceType = DistanceType.EuclideanDistance;
  }
  var nodes = data.nodes === void 0 ? [] : data.nodes;
  var edges = data.edges === void 0 ? [] : data.edges;

  // Fallback result: one cluster containing every node and no cluster edges.
  var defaultClusterInfo = {
    clusters: [{
      id: "0",
      nodes: nodes
    }],
    clusterEdges: []
  };
  var isEuclidean = distanceType === DistanceType.EuclideanDistance;

  // Euclidean distance needs the property field on every node; bail out otherwise.
  // hasOwnProperty is borrowed from Object.prototype so nodes that shadow or lack
  // the method do not break the check.
  if (isEuclidean && !nodes.every(function (node) {
    return Object.prototype.hasOwnProperty.call(node, propertyKey);
  })) {
    return defaultClusterInfo;
  }

  // One-hot feature vector per node, aligned with `nodes` by index.
  var allPropertiesWeight = [];
  if (isEuclidean) {
    allPropertiesWeight = oneHot(getAllProperties(nodes, propertyKey), involvedKeys, uninvolvedKeys);
  }
  if (!allPropertiesWeight.length) {
    return defaultClusterInfo;
  }

  // k cannot exceed the number of nodes or the number of distinct feature vectors.
  var distinctVectorCount = uniq(allPropertiesWeight.map(function (item) {
    return item.join('');
  })).length;
  var finalK = Math.min(k, nodes.length, distinctVectorCount);
  // A non-positive or NaN k would leave no cluster to assign nodes to.
  if (!(finalK > 0)) {
    return defaultClusterInfo;
  }

  // Record each node's index (it matches allPropertiesWeight) and drop any
  // clusterId left over from an earlier run: a stale id would otherwise be used
  // below to index `clusters`.
  for (var n = 0; n < nodes.length; n++) {
    nodes[n].originIndex = n;
    nodes[n].clusterId = undefined;
  }

  // Distance from a node's feature vector to a centroid (0 for unknown distance types).
  var distanceToCentroid = function distanceToCentroid(nodeIndex, centroid) {
    switch (distanceType) {
      case DistanceType.EuclideanDistance:
        return getDistance(allPropertiesWeight[nodeIndex], centroid, distanceType);
      default:
        return 0;
    }
  };

  var centroids = [];
  var centroidIndexList = [];
  var clusters = [];

  // Make the node at `nodeIndex` the initial centroid and first member of a cluster.
  var seedCluster = function seedCluster(clusterIndex, nodeIndex) {
    centroids[clusterIndex] = getCentroid(distanceType, allPropertiesWeight, nodeIndex);
    centroidIndexList.push(nodeIndex);
    clusters[clusterIndex] = [nodes[nodeIndex]];
    nodes[nodeIndex].clusterId = String(clusterIndex);
  };

  // Index of the non-seed node with the largest average distance to the existing
  // centroids whose vector is not already used as a centroid (0 if none qualifies).
  var pickFarthestNodeIndex = function pickFarthestNodeIndex() {
    var farthestDistance = -Infinity;
    var farthestIndex = 0;
    var _loop = function _loop(m) {
      if (centroidIndexList.includes(m)) {
        return;
      }
      var totalDistance = 0;
      for (var j = 0; j < centroids.length; j++) {
        totalDistance += distanceToCentroid(m, centroids[j]);
      }
      var avgDistance = totalDistance / centroids.length;
      if (avgDistance > farthestDistance && !centroids.some(function (centroid) {
        return isEqual(centroid, getCentroid(distanceType, allPropertiesWeight, m));
      })) {
        farthestDistance = avgDistance;
        farthestIndex = m;
      }
    };
    for (var m = 0; m < nodes.length; m++) {
      _loop(m);
    }
    return farthestIndex;
  };

  // Initialise centroids: the first is a random node, each further one is the node
  // farthest (on average) from the centroids chosen so far.
  for (var c = 0; c < finalK; c++) {
    seedCluster(c, c === 0 ? Math.floor(Math.random() * nodes.length) : pickFarthestNodeIndex());
  }

  var MAX_ITERATIONS = 1000;
  for (var iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
    // Assignment step: put every node into the cluster of its nearest centroid.
    for (var i = 0; i < nodes.length; i++) {
      // On the first pass the seed nodes already sit in their own cluster.
      if (iteration === 0 && centroidIndexList.includes(i)) {
        continue;
      }
      var nearestIndex = 0;
      var minDistance = Infinity;
      for (var j = 0; j < centroids.length; j++) {
        var distance = distanceToCentroid(i, centroids[j]);
        if (distance < minDistance) {
          minDistance = distance;
          nearestIndex = j;
        }
      }
      var node = nodes[i];
      // Remove the node from the cluster it was in before.
      if (node.clusterId !== undefined) {
        var previousMembers = clusters[Number(node.clusterId)];
        for (var p = previousMembers.length - 1; p >= 0; p--) {
          if (previousMembers[p].id === node.id) {
            previousMembers.splice(p, 1);
          }
        }
      }
      node.clusterId = String(nearestIndex);
      clusters[nearestIndex].push(node);
    }

    // Update step: move each centroid to the mean vector of its members.
    var centroidMoved = false;
    for (var ci = 0; ci < clusters.length; ci++) {
      var members = clusters[ci];
      // A cluster that lost all members keeps its centroid; averaging zero
      // vectors would not give a usable centroid.
      if (!members.length) {
        continue;
      }
      var totalVector = new Vector([]);
      for (var mi = 0; mi < members.length; mi++) {
        totalVector = totalVector.add(new Vector(allPropertiesWeight[members[mi].originIndex]));
      }
      var avgVector = totalVector.avg(members.length);
      if (!avgVector.equal(new Vector(centroids[ci]))) {
        centroidMoved = true;
        centroids[ci] = avgVector.getArr();
      }
    }

    // Converged once no centroid moved. (Every node has a cluster after the first
    // pass, so that needs no separate check.) The loop bound caps the work at
    // MAX_ITERATIONS passes. The previous code broke out when a centroid HAD
    // moved, i.e. after a single pass, contrary to its own comment.
    if (!centroidMoved) {
      break;
    }
  }

  // Collapse graph edges into edges between clusters, counting duplicates.
  // The lookup keeps the first node seen for an id, like Array.prototype.find did.
  var clusterIdByNodeId = new Map();
  nodes.forEach(function (node) {
    if (!clusterIdByNodeId.has(node.id)) {
      clusterIdByNodeId.set(node.id, node.clusterId);
    }
  });
  var clusterEdges = [];
  var clusterEdgeMap = {};
  edges.forEach(function (edge) {
    // Endpoints that match no node resolve to undefined, as before.
    var sourceClusterId = clusterIdByNodeId.get(edge.source);
    var targetClusterId = clusterIdByNodeId.get(edge.target);
    var newEdgeId = "".concat(sourceClusterId, "---").concat(targetClusterId);
    if (clusterEdgeMap[newEdgeId]) {
      clusterEdgeMap[newEdgeId].count++;
    } else {
      var newEdge = {
        source: sourceClusterId,
        target: targetClusterId,
        count: 1
      };
      clusterEdgeMap[newEdgeId] = newEdge;
      clusterEdges.push(newEdge);
    }
  });
  return {
    clusters: clusters,
    clusterEdges: clusterEdges
  };
};
|
|
export default kMeans;
|