Generalized K-Means Clustering

Original link: https://github.com/derrickburns/generalized-kmeans-clustering

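## Massive Data Science Clusterer: Summary

This project provides a scalable library for K-Means clustering and its variants, generalized to multiple Bregman divergences. Version 0.6.0 introduces a modern, DataFrame-native API. It is integrated with Spark ML and recommended for all new projects. A legacy RDD API is maintained for backward compatibility but is no longer actively developed.

The library supports divergences such as squared Euclidean, KL, and L1 (K-Medians). It offers variants including Bisecting, X-Means, Soft/Fuzzy, Streaming, K-Medians, and K-Medoids. Key features include:

- model persistence across Spark/Scala versions;
- comprehensive testing (740+ tests);
- deterministic behavior;
- detailed diagnostics for performance tuning.

Scalability is addressed through configurable assignment strategies (auto, crossJoin, broadcastUDF, chunked) and broadcast diagnostics that help prevent out-of-memory errors. Input data can be transformed with `log1p` or `epsilonShift` to satisfy divergence domain requirements, and inputs are validated automatically at fit time.

The project is compatible with Spark 3.4.x, 3.5.x, and 4.0.x, and with Scala 2.12 and 2.13 (Spark 4.0.x requires Scala 2.13). A PySpark wrapper is also provided. The project follows security best practices, including vulnerability reporting and automated dependency updates.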

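## Generalized K-Means Clustering for Spark

A new, production-ready K-Means library for Apache Spark has been released to address limitations of Spark's built-in MLlib implementation. MLlib's K-Means supports only Euclidean distance (plus cosine distance since Spark 2.4). That is often mathematically inappropriate for data such as probability distributions, audio data, and count data. The new library supports Bregman divergences such as KL divergence and Itakura-Saito divergence, as well as L1/Manhattan distance, to provide more accurate and meaningful clustering.

The library includes six algorithms (GeneralizedKMeans, BisectingKMeans, XMeans, SoftKMeans, StreamingKMeans, KMedoids) and provides a drop-in API compatible with MLlib. It features:

- extensive tests;
- cross-version persistence;
- automatic scalability optimizations;
- both Python and Scala APIs.

Reported performance is roughly 870-3,400 points per second, with scaling to billions of data points. The author stresses the importance of choosing the right distance measure ("divergence"). He also notes that determining the optimal number of clusters (K) remains a challenging, domain-specific problem. The project is available on GitHub: [https://github.com/derrickburns/generalized-kmeans-clustering](https://github.com/derrickburns/generalized-kmeans-clustering).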

Original README


Security: This project follows security best practices. See SECURITY.md for vulnerability reporting and dependabot.yml for automated dependency updates.

🆕 DataFrame API (Spark ML) is the default. Version 0.6.0 introduces a modern, RDD-free DataFrame-native API with Spark ML integration. See DataFrame API Examples for end-to-end usage.

This project generalizes K-Means to multiple Bregman divergences and advanced variants (Bisecting, X-Means, Soft/Fuzzy, Streaming, K-Medians, K-Medoids). It provides:

  • A DataFrame/ML API (recommended)
  • A legacy RDD API kept for backward compatibility (archived below)
  • Multiple divergences: Squared Euclidean, KL, Itakura–Saito, L1/Manhattan (K-Medians), Generalized-I, Logistic-loss
  • Variants: Bisecting, X-Means (BIC/AIC), Soft K-Means, Structured-Streaming K-Means, K-Medoids (PAM/CLARA)
  • Scale: Tested on tens of millions of points in 700+ dimensions
  • Tooling: Scala 2.13 (primary) / 2.12, Spark 4.0.x / 3.5.x / 3.4.x
    • Spark 4.0.x: Scala 2.13 only (Scala 2.12 support dropped in Spark 4.0)
    • Spark 3.x: Both Scala 2.13 and 2.12 supported

Quick Start (DataFrame API)

Recommended for all new projects. The DataFrame API follows the Spark ML Estimator/Model pattern.

import org.apache.spark.ml.linalg.Vectors
import com.massivedatascience.clusterer.ml.GeneralizedKMeans

val df = spark.createDataFrame(Seq(
  Tuple1(Vectors.dense(0.0, 0.0)),
  Tuple1(Vectors.dense(1.0, 1.0)),
  Tuple1(Vectors.dense(9.0, 8.0)),
  Tuple1(Vectors.dense(8.0, 9.0))
)).toDF("features")

val gkm = new GeneralizedKMeans()
  .setK(2)
  .setDivergence("squaredEuclidean") // also "kl", "itakuraSaito", "l1", "generalizedI", "logistic"; check Domain Requirements first (KL rejects the zeros in this data)
  .setAssignmentStrategy("auto")     // "auto" | "crossJoin" (SE fast path) | "broadcastUDF" (general Bregman) | "chunked"
  .setMaxIter(20)

val model = gkm.fit(df)
val pred  = model.transform(df)
pred.show(false)

More recipes: see DataFrame API Examples.


The CI pipeline checks quality across multiple dimensions. View live CI results: CI Workflow Runs


Feature Matrix (each entry links to its code, tests, and examples)

Divergences Available: Squared Euclidean, KL, Itakura-Saito, L1/Manhattan, Generalized-I, Logistic Loss

All DataFrame API algorithms include:

  • ✅ Model persistence (save/load across Spark 3.4↔3.5↔4.0, Scala 2.12↔2.13)
  • ✅ Comprehensive test coverage (740 tests, 100% passing)
  • ✅ Executable documentation with assertions (8 examples validate correctness in CI)
  • ✅ Deterministic behavior (same seed → identical results)
  • ✅ CI validation on every commit

Installation / Versions

  • Spark: 3.5.1 default (override via -Dspark.version), 3.4.x tested
  • Scala: 2.13.14 (primary), 2.12.18 (cross-compiled)
  • Java: 17

sbt dependency:

libraryDependencies += "com.massivedatascience" %% "massivedatascience-clusterer" % "0.6.0"

Highlights in 0.6.0:

  • Scala 2.13 primary; Spark 3.5.x default
  • DataFrame API implementations for: Bisecting, X-Means, Soft, Streaming, K-Medoids
  • K-Medians (L1) divergence support
  • PySpark wrapper + smoke test
  • Expanded examples & docs

Scaling & Assignment Strategy (important)

Different divergences require different assignment mechanics at scale:

  • Squared Euclidean (SE) fast path — expression/codegen route:
    1. Cross-join points with centers
    2. Compute squared distance column
    3. Prefer groupBy(rowId).min(distance) → join to pick argmin (scales better than window sorts)
    4. Requires a stable rowId; we provide a RowIdProvider.
  • General Bregman — broadcast + UDF route:
    • Broadcast the centers; compute argmin via a tight JVM UDF.
    • Broadcast ceiling: you'll hit executor/memory limits if k × dim is too large to broadcast.
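The two routes can be illustrated without Spark. Below is a minimal plain-Scala sketch in which local collections stand in for DataFrames. The names (`AssignmentSketch`, `assignSE`, `assignBroadcast`) are hypothetical and are not the library's API.

The SE function mirrors the cross-join → distance column → groupBy(rowId).min → join-back steps. The broadcast function takes any divergence (here Itakura-Saito, sum of x/c − log(x/c) − 1) and returns the argmin index per point, as a UDF would.

```scala
// Plain-Scala sketch of the two assignment routes; local collections stand in for
// DataFrames. All names here are hypothetical, not the library's API.
object AssignmentSketch {
  type Vec = Array[Double]

  def squaredEuclidean(x: Vec, c: Vec): Double =
    x.indices.map { i => val d = x(i) - c(i); d * d }.sum

  // Itakura-Saito divergence, defined for strictly positive vectors.
  def itakuraSaito(x: Vec, c: Vec): Double =
    x.indices.map { i => val r = x(i) / c(i); r - math.log(r) - 1.0 }.sum

  // SE fast path: cross-join, distance column, groupBy(rowId).min, join back for argmin.
  def assignSE(points: Seq[(Long, Vec)], centers: Seq[(Int, Vec)]): Map[Long, Int] = {
    // 1-2. cross-join points with centers and compute a distance "column"
    val crossed: Seq[(Long, Int, Double)] =
      for {
        (rowId, x) <- points
        (cid, c)   <- centers
      } yield (rowId, cid, squaredEuclidean(x, c))
    // 3a. groupBy(rowId).min(distance)
    val minDist: Map[Long, Double] =
      crossed.groupBy(_._1).map { case (rowId, rows) => rowId -> rows.map(_._3).min }
    // 3b. join back on (rowId, minDistance) to recover the argmin; ties go to the lowest center id
    crossed
      .filter { case (rowId, _, d) => d == minDist(rowId) }
      .groupBy(_._1)
      .map { case (rowId, rows) => rowId -> rows.map(_._2).min }
  }

  // General Bregman route: centers are "broadcast" (a local array) and a per-point
  // function, playing the role of the UDF, returns the argmin center index.
  def assignBroadcast(points: Seq[(Long, Vec)], centers: IndexedSeq[Vec],
                      divergence: (Vec, Vec) => Double): Map[Long, Int] =
    points.map { case (rowId, x) =>
      rowId -> centers.indices.minBy(j => divergence(x, centers(j)))
    }.toMap
}
```

On the Quick Start points with centers (0.5, 0.5) and (8.5, 8.5), both routes return the same assignment. The broadcast route also accepts Itakura-Saito once the data is shifted into the positive domain. The groupBy/min/join shape avoids sorting every candidate pair, which is why the text prefers it over window sorts.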

Parameters

  • assignmentStrategy: StringParam = auto | crossJoin | broadcastUDF | chunked
    • auto (recommended): Chooses SE fast path when divergence == SE; otherwise selects between broadcastUDF and chunked based on k×dim size
    • crossJoin: Forces SE expression-based path (only works with Squared Euclidean)
    • broadcastUDF: Forces broadcast + UDF (works with any divergence, but may OOM on large k×dim)
    • chunked: Processes centers in chunks to avoid OOM (multiple data scans, but safe for large k×dim)
  • broadcastThreshold: IntParam (elements, not bytes)
    • Default: 200,000 elements (~1.5MB)
    • Heuristic ceiling for k × dim. If exceeded for non-SE divergences, AutoAssignment switches to chunked broadcast.
  • chunkSize: IntParam (for chunked strategy)
    • Default: 100 clusters per chunk
    • Controls how many centers are processed in each scan when using chunked broadcast
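The `auto` rule described above can be read as a small decision function. The sketch below is a hypothetical reconstruction from the parameter descriptions, not the library's code. `AutoStrategySketch` and its case classes are illustrative names, and whether the threshold comparison is `<` or `<=` is an assumption.

```scala
// Hypothetical reconstruction of assignmentStrategy = "auto"; not the library's code.
object AutoStrategySketch {
  sealed trait Strategy
  case object CrossJoin extends Strategy
  case object BroadcastUDF extends Strategy
  final case class Chunked(chunkSize: Int, scans: Int) extends Strategy

  def choose(divergence: String, k: Int, dim: Int,
             broadcastThreshold: Int = 200000, chunkSize: Int = 100): Strategy = {
    val elements = k.toLong * dim // k × dim, counted in elements, not bytes
    if (divergence == "squaredEuclidean") CrossJoin        // SE fast path
    else if (elements <= broadcastThreshold) BroadcastUDF  // assumed inclusive bound
    else Chunked(chunkSize, scans = (k + chunkSize - 1) / chunkSize) // ceil(k / chunkSize) data scans
  }
}
```

With the defaults, `choose("kl", 500, 1000)` yields `Chunked(100, 5)`: 500,000 elements exceed the 200,000 threshold, and 500 centers in chunks of 100 take 5 scans. The same call with `"squaredEuclidean"` takes the cross-join path regardless of size.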

Broadcast Diagnostics

The library provides detailed diagnostics to help you tune performance and avoid OOM errors:

// Example: Large cluster configuration
val gkm = new GeneralizedKMeans()
  .setK(500)          // 500 clusters
  .setDivergence("kl") // Non-SE divergence
  // If your data has dim=1000, then k×dim = 500,000 elements

// AutoAssignment will log:
// [WARN] AutoAssignment: Broadcast size exceeds threshold
//   Current: k=500 × dim=1000 = 500000 elements ≈ 3.8MB
//   Threshold: 200000 elements ≈ 1.5MB
//   Overage: +150%
//
//   Using ChunkedBroadcast (chunkSize=100) to avoid OOM.
//   This will scan the data 5 times.
//
//   To avoid chunking overhead, consider:
//     1. Reduce k (number of clusters)
//     2. Reduce dimensionality (current: 1000 dimensions)
//     3. Increase broadcastThreshold (suggested: k=500 would need ~500000 elements)
//     4. Use Squared Euclidean divergence if appropriate (enables fast SE path)

When you see these warnings:

  • Chunked broadcast selected: Your configuration will work but may be slower due to multiple data scans. Follow the suggestions to improve performance.
  • Large broadcast warning (>100MB): Risk of executor OOM errors. Consider reducing k or dimensionality, or increasing executor memory.
  • No warning: Your configuration is well-sized for broadcasting.
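The sizes in the warning can be checked by hand. This sketch assumes dense centers stored as 8-byte doubles and ignores serialization overhead. `BroadcastArithmetic` is a hypothetical helper, not part of the library.

```scala
// Back-of-the-envelope arithmetic behind the broadcast diagnostics, assuming dense
// centers of 8-byte doubles and ignoring serialization overhead.
object BroadcastArithmetic {
  val BytesPerDouble = 8L

  // number of doubles to broadcast: k × dim
  def elements(k: Int, dim: Int): Long = k.toLong * dim

  // size in MiB (1 MiB = 1024 * 1024 bytes)
  def sizeMiB(elems: Long): Double = (elems * BytesPerDouble).toDouble / (1024.0 * 1024.0)

  // how far the broadcast exceeds the threshold, in percent
  def overagePercent(elems: Long, threshold: Long): Double =
    (elems - threshold) * 100.0 / threshold
}
```

With k = 500 and dim = 1000 this gives 500,000 elements, about 3.81 MiB, and a 150% overage against the 200,000-element default, matching the log above. The threshold itself is 200,000 × 8 bytes ≈ 1.53 MiB. By the same arithmetic, the 100MB warning level corresponds to roughly 13 million doubles.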

Input Transforms & Interpretation

Some divergences (KL, IS) require positivity or benefit from stabilized domains.

  • inputTransform: StringParam = none | log1p | epsilonShift
  • shiftValue: DoubleParam (e.g., 1e-6) when epsilonShift is used.

Note: Cluster centers are learned in the transformed space. If you need original-space interpretation, apply the appropriate inverse (e.g., expm1) for reporting, understanding that this is an interpretive mapping, not a different optimum.
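The note about interpretive mapping can be made concrete. The helpers below are hypothetical stand-ins for `inputTransform = log1p | epsilonShift`. `epsilonShift` is assumed to add `shiftValue` to every coordinate, as the error message later in this section describes, and the library's own implementation may differ.

```scala
// Hypothetical helpers mirroring inputTransform = log1p | epsilonShift; the library's
// own implementation may differ (epsilonShift is assumed to add shiftValue everywhere).
object InputTransformSketch {
  def log1p(v: Array[Double]): Array[Double] = v.map(x => math.log1p(x)) // defined for x > -1
  def expm1(v: Array[Double]): Array[Double] = v.map(x => math.expm1(x)) // inverse of log1p
  def epsilonShift(v: Array[Double], shiftValue: Double = 1e-6): Array[Double] =
    v.map(_ + shiftValue)

  // coordinate-wise arithmetic mean, i.e. a center for a single cluster
  def mean(vs: Seq[Array[Double]]): Array[Double] =
    vs.head.indices.map(i => vs.map(_(i)).sum / vs.size).toArray
}
```

Take the points 0 and e² − 1. Their center in log1p space is 1. Mapped back with `expm1`, that center is e − 1 ≈ 1.72, while the mean of the original values is about 3.19. The back-mapped center is a faithful report of the transformed-space optimum, not the original-space mean.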


Domain Requirements & Validation

Automatic validation at fit time — Different divergences have different input domain requirements. The library automatically validates your data and provides actionable error messages if violations are found:

| Divergence | Domain requirement | Example fix |
|---|---|---|
| squaredEuclidean | Any finite values (x ∈ ℝ) | None needed |
| l1 / manhattan | Any finite values (x ∈ ℝ) | None needed |
| kl | Strictly positive (x > 0) | Use log1p or epsilonShift transform |
| itakuraSaito | Strictly positive (x > 0) | Use log1p or epsilonShift transform |
| generalizedI | Non-negative (x ≥ 0) | Take absolute values or shift data |
| logistic | Open interval (0 < x < 1) | Normalize to [0,1] then use epsilonShift |

What happens on validation failure:

When you call fit(), the library checks the first 1000 rows of your data (by default) against the divergence's domain requirements. If violations are found, you'll see an actionable error message with:

  • The specific invalid value and its location (feature index)
  • Suggested fixes with example code
  • Transform options to map your data into the valid domain
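The check described above can be sketched as follows. The domain predicates come straight from the table; everything else is an assumption. That includes `DomainValidationSketch`, `Violation`, and the message format, which only imitates the error text shown later in this section. It is not the library's validator.

```scala
// Hypothetical fit-time domain check; predicates follow the table above.
object DomainValidationSketch {
  val domains: Map[String, (String, Double => Boolean)] = Map(
    "squaredEuclidean" -> ("finite values", (_: Double) => true),
    "l1"               -> ("finite values", (_: Double) => true),
    "kl"               -> ("strictly positive values", (x: Double) => x > 0),
    "itakuraSaito"     -> ("strictly positive values", (x: Double) => x > 0),
    "generalizedI"     -> ("non-negative values", (x: Double) => x >= 0),
    "logistic"         -> ("values in (0, 1)", (x: Double) => x > 0 && x < 1)
  )

  final case class Violation(row: Int, featureIndex: Int, value: Double, message: String)

  // Inspect only the first `maxRows` rows; NaN and Infinity fail for every divergence.
  def firstViolation(divergence: String, rows: Seq[Array[Double]],
                     maxRows: Int = 1000): Option[Violation] = {
    val (requirement, inDomain) = domains(divergence)
    val hits = for {
      (row, r) <- rows.iterator.take(maxRows).zipWithIndex
      (x, j)   <- row.iterator.zipWithIndex
      if x.isNaN || x.isInfinite || !inDomain(x)
    } yield Violation(r, j, x, s"$divergence divergence requires $requirement, but found: $x")
    hits.nextOption() // stop at the first offending value
  }
}
```

For the two-row KL example that follows, the sketch reports row 0, feature index 1, value 0.0. Because only the first `maxRows` rows are inspected, a bad value further down the data can still slip through to training.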

Example validation error:

// This will fail for KL divergence (contains zero)
val df = spark.createDataFrame(Seq(
  Tuple1(Vectors.dense(1.0, 0.0)),  // Zero at index 1!
  Tuple1(Vectors.dense(2.0, 3.0))
)).toDF("features")

val kmeans = new GeneralizedKMeans()
  .setK(2)
  .setDivergence("kl")

kmeans.fit(df)  // ❌ Throws with actionable message

Error message you'll see:

kl divergence requires strictly positive values, but found: 0.0

The kl divergence is only defined for positive data.

Suggested fixes:
  - Use .setInputTransform("log1p") to transform data using log(1 + x), which maps [0, ∞) → [0, ∞)
  - Use .setInputTransform("epsilonShift") with .setShiftValue(1e-6) to add a small constant
  - Pre-process your data to ensure all values are positive
  - Consider using Squared Euclidean divergence (.setDivergence("squaredEuclidean")) which has no domain restrictions

Example:
  new GeneralizedKMeans()
    .setDivergence("kl")
    .setInputTransform("log1p")  // Transform to valid domain
    .setMaxIter(20)

How to fix domain violations:

  1. For KL/Itakura-Saito (requires x > 0):

    val kmeans = new GeneralizedKMeans()
      .setK(2)
      .setDivergence("kl")
      .setInputTransform("log1p")  // log(1+x) maps [0, ∞) → [0, ∞); note log1p(0) = 0, so exact zeros need epsilonShift
      .setMaxIter(20)
  2. For Logistic Loss (requires 0 < x < 1):

    // First normalize your data to [0, 1], then:
    val kmeans = new GeneralizedKMeans()
      .setK(2)
      .setDivergence("logistic")
      .setInputTransform("epsilonShift")
      .setShiftValue(1e-6)  // Shifts to (ε, 1-ε)
      .setMaxIter(20)
  3. For Generalized-I (requires x ≥ 0):

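    // Pre-process to ensure non-negative values
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.sql.functions.{col, udf}

    val absUdf = udf((v: Vector) => Vectors.dense(v.toArray.map(x => math.abs(x))))
    val df = originalDF.withColumn("features", absUdf(col("features")))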
    
    val kmeans = new GeneralizedKMeans()
      .setK(2)
      .setDivergence("generalizedI")
      .setMaxIter(20)
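For the logistic case, note that a plain additive shift would push 1.0 to 1 + ε, outside (0, 1). The sketch below shows one pre-processing option, which is an assumption and not the library's `epsilonShift`. It applies per-feature min-max scaling to [0, 1] and then squeezes the result into [ε, 1 − ε], so every value lands strictly inside the open interval. `OpenUnitIntervalSketch` is a hypothetical name.

```scala
// One way (an assumption, not the library's epsilonShift) to map features strictly into (0, 1).
object OpenUnitIntervalSketch {
  // Per-feature min-max scaling to [0, 1], then squeeze into [eps, 1 - eps].
  def toOpenUnitInterval(rows: Seq[Array[Double]], eps: Double = 1e-6): Seq[Array[Double]] = {
    require(rows.nonEmpty && eps > 0 && eps < 0.5)
    val dim  = rows.head.length
    val mins = Array.tabulate(dim)(j => rows.map(_(j)).min)
    val maxs = Array.tabulate(dim)(j => rows.map(_(j)).max)
    rows.map { row =>
      Array.tabulate(dim) { j =>
        val range = maxs(j) - mins(j)
        val unit  = if (range == 0.0) 0.5 else (row(j) - mins(j)) / range // constant feature -> 0.5
        eps + unit * (1 - 2 * eps)
      }
    }
  }
}
```

Scaling statistics computed on training data should be reused unchanged when transforming new data. Otherwise the same raw value maps to different positions.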

Validation scope:

  • Validates first 1000 rows by default (configurable in code)
  • Checks for NaN/Infinity in all divergences
  • Provides early failure with clear guidance before expensive computation
  • All DataFrame API estimators include validation: GeneralizedKMeans, BisectingKMeans, XMeans, SoftKMeans, CoresetKMeans

Bisecting K-Means — efficiency note

The driver maintains a cluster_id column. For each split:

  1. Filter only the target cluster: df.where(col("cluster_id") === id)
  2. Run the base learner on that subset (k=2)
  3. Join back predictions to update only the touched rows

This avoids reshuffling the full dataset at every split.
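A single split step can be sketched with local collections. This is a minimal illustration under stated assumptions: the "base learner" is a tiny squared-Euclidean 2-means with deterministic seeding (first point, then the farthest point from it). `BisectingStepSketch`, `twoMeans`, and `split` are hypothetical names, not the library's implementation.

```scala
// Hypothetical sketch of one bisecting split; local collections stand in for the DataFrame.
object BisectingStepSketch {
  type Vec = Array[Double]

  private def sqDist(a: Vec, b: Vec): Double =
    a.indices.map { i => val d = a(i) - b(i); d * d }.sum

  private def mean(ps: Seq[Vec]): Vec =
    ps.head.indices.map(i => ps.map(_(i)).sum / ps.size).toArray

  // Stand-in base learner: Lloyd's 2-means with squared Euclidean distance.
  def twoMeans(ps: IndexedSeq[Vec], iters: Int = 10): IndexedSeq[Int] = {
    require(ps.nonEmpty)
    var c0 = ps.head
    var c1 = ps.maxBy(p => sqDist(p, c0)) // farthest point from the first seed
    var labels = IndexedSeq.fill(ps.size)(0)
    for (_ <- 1 to iters) {
      labels = ps.map(p => if (sqDist(p, c0) <= sqDist(p, c1)) 0 else 1)
      val g0 = ps.indices.filter(labels(_) == 0).map(ps)
      val g1 = ps.indices.filter(labels(_) == 1).map(ps)
      if (g0.nonEmpty) c0 = mean(g0)
      if (g1.nonEmpty) c1 = mean(g1)
    }
    labels
  }

  // One split: filter the target cluster, run the base learner on that subset only,
  // then write the new id back for the touched rows; all other rows keep their id.
  def split(points: IndexedSeq[Vec], clusterId: IndexedSeq[Int],
            target: Int, newId: Int): IndexedSeq[Int] = {
    val touched = points.indices.filter(clusterId(_) == target) // where(cluster_id === target)
    if (touched.isEmpty) clusterId
    else {
      val sub = twoMeans(touched.map(points))                   // base learner, k = 2
      val updates = touched.zip(sub).collect { case (row, 1) => row -> newId }.toMap // "join back"
      points.indices.map(i => updates.getOrElse(i, clusterId(i)))
    }
  }
}
```

Splitting cluster 0 of the points (0,0), (0,1), (10,10), (10,11) relabels the last two as the new cluster. A point already in another cluster is never read by the base learner, which is the property that avoids a full reshuffle.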


Structured Streaming K-Means

Estimator/Model for micro-batch streams using the same core update logic.

  • initStrategy = pretrained | randomFirstBatch
  • pretrained: provide setInitialModel / setInitialCenters
  • randomFirstBatch: seed from the first micro-batch
  • State & snapshots: Each micro-batch writes centers to ${checkpointDir}/centers/latest.parquet for batch reuse.
  • StreamingGeneralizedKMeansModel.read(path) reconstructs a batch model from snapshots.
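The README does not spell out the per-batch update. A common choice, and the assumption in this sketch, is a weighted running mean. Each center keeps a weight (the count of points absorbed so far, with no decay), and each micro-batch folds its assigned points into that mean. For Bregman divergences the arithmetic mean minimizes total divergence to a cluster's points, so this fits those divergences. It would not apply to L1/K-Medians, whose centers are medians. `StreamingUpdateSketch` and `Center` are hypothetical names, and `Vector` here is Scala's immutable `Vector`.

```scala
// Hypothetical per-micro-batch center update: weighted running mean, no decay.
object StreamingUpdateSketch {
  final case class Center(vector: Vector[Double], weight: Double)

  // batch: points already assigned to a center index
  def update(centers: Vector[Center], batch: Seq[(Int, Vector[Double])]): Vector[Center] =
    centers.zipWithIndex.map { case (c, j) =>
      val pts = batch.collect { case (`j`, p) => p } // points assigned to center j
      if (pts.isEmpty) c                              // untouched centers are kept as-is
      else {
        val n = pts.size.toDouble
        val batchSum = pts.reduce((a, b) => a.zip(b).map { case (x, y) => x + y })
        val total = c.weight + n
        val v = c.vector.zip(batchSum).map { case (old, s) => (old * c.weight + s) / total }
        Center(v, total)
      }
    }
}
```

For example, a center at (0, 0) with weight 2 that receives the single point (3, 3) moves to (1, 1) with weight 3. A decay factor, if one were wanted, would scale `c.weight` down before the update.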

Persistence (Spark ML)

Models implement DefaultParamsWritable/Readable.

Layout

<path>/
  ├─ metadata/params.json
  ├─ centers/*.parquet          # (center_id, vector[, weight])
  └─ summary/*.json             # events, metrics (optional)

Compatibility

  • Save/Load verified across Spark 3.4.x ↔ 3.5.x in CI.
  • New params default safely on older loads; unknown params are ignored.

Python (PySpark) wrapper

  • Package exposes GeneralizedKMeans, BisectingGeneralizedKMeans, SoftGeneralizedKMeans, StreamingGeneralizedKMeans, KMedoids, etc.
  • CI runs a spark-submit smoke test on local[*] with a non-SE divergence.

Legacy RDD API (Archived)

Status: Kept for backward compatibility. New development should use the DataFrame API. The material below documents the original RDD interfaces and helper objects. Some snippets show API signatures (placeholders) rather than runnable examples.

Quick Start (Legacy RDD API)

import com.massivedatascience.clusterer.KMeans
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Array(
  Vectors.dense(0.0, 0.0),
  Vectors.dense(1.0, 1.0),
  Vectors.dense(9.0, 8.0),
  Vectors.dense(8.0, 9.0)
))

val model = KMeans.train(
  data,
  runs = 1,
  k = 2,
  maxIterations = 20
)

The remainder of this section is an archived reference for the RDD API.

It includes: Bregman divergences, BregmanPoint/BregmanCenter, KMeansModel, clusterers, seeding, embeddings, iterative training, coreset helpers, and helper object builders. Code blocks that include ??? indicate signatures in the original design.



  • Generalized K-Means Clustering
  • Quick Start (DataFrame API)
  • Feature Matrix
  • Installation / Versions
  • Scaling & Assignment Strategy
  • Input Transforms & Interpretation
  • Bisecting K-Means — efficiency note
  • Structured Streaming K-Means
  • Persistence (Spark ML)
  • Python (PySpark) wrapper
  • Legacy RDD API (Archived)

Contributing

  • Please prefer PRs that target the DataFrame/ML path.
  • Add tests (including property-based where sensible) and update examples.
  • Follow Conventional Commits (feat:, fix:, docs:, refactor:, test:).

License

Apache 2.0


Notes for maintainers (can be removed later)

  • As you land more DF features, consider extracting the RDD material into LEGACY_RDD.md to keep the README short.
  • Keep the “Scaling & Assignment Strategy” section up-to-date when adding SE accelerations (Hamerly/Elkan/Yinyang) or ANN-assisted paths—mark SE-only and exact/approximate as appropriate.