
Commit 36445b2

[ENH] Added results, scripts and notebook for KASBA (#440)
* added results, scripts and notebook for KASBA
* updated name
* KASBA
* type fixes
* fixed type check
1 parent a2049d6 commit 36445b2

30 files changed, with 3365 additions and 4 deletions.
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
"""Files for clustering publications."""
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
# 📘 KASBA: k-means Accelerated Stochastic Subgradient Barycentre Averaging
**Official Repository for the KASBA Time Series Clustering Paper**

This repository accompanies the paper:

> **Rock the KASBA: Blazingly Fast and Accurate Time Series Clustering**
>
> https://arxiv.org/abs/2411.17838

KASBA is a $k$-means clustering algorithm that uses the Move-Split-Merge (MSM) elastic distance at every stage of clustering. It applies a randomised stochastic subgradient descent to find barycentre centroids, links each stage of clustering to accelerate convergence, and exploits the metric property of the MSM distance to avoid a large proportion of distance calculations. It is a versatile and scalable clusterer designed for real-world time series clustering (TSCL) applications, and it lets practitioners balance runtime against clustering performance when similarity is best measured by an elastic distance.
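
For readers unfamiliar with MSM, here is a minimal pure-Python sketch of the standard dynamic-programming formulation of the univariate MSM distance. It is for illustration only and is not the aeon implementation. The function names `msm_cost` and `msm_distance` are hypothetical, and the `c` argument is assumed to play the role of the `c` entry in `distance_params`.

```python
def msm_cost(new_point, x, y, c):
    """Cost of a split or merge that produces new_point next to x and y."""
    # If the new point lies between its neighbours, only the fixed cost applies.
    if x <= new_point <= y or y <= new_point <= x:
        return c
    return c + min(abs(new_point - x), abs(new_point - y))


def msm_distance(a, b, c=1.0):
    """Univariate MSM distance between two non-empty sequences of floats."""
    n, m = len(a), len(b)
    cost = [[0.0] * m for _ in range(n)]
    cost[0][0] = abs(a[0] - b[0])
    # First column and first row can only be reached by split/merge steps.
    for i in range(1, n):
        cost[i][0] = cost[i - 1][0] + msm_cost(a[i], a[i - 1], b[0], c)
    for j in range(1, m):
        cost[0][j] = cost[0][j - 1] + msm_cost(b[j], a[0], b[j - 1], c)
    for i in range(1, n):
        for j in range(1, m):
            cost[i][j] = min(
                cost[i - 1][j - 1] + abs(a[i] - b[j]),               # move
                cost[i - 1][j] + msm_cost(a[i], a[i - 1], b[j], c),  # split/merge in a
                cost[i][j - 1] + msm_cost(b[j], a[i], b[j - 1], c),  # split/merge in b
            )
    return cost[n - 1][m - 1]
```

Because MSM is a metric, values computed this way satisfy the triangle inequality, which is the property that lets a clusterer skip some distance calculations.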

KASBA delivers state-of-the-art clustering performance while achieving speedups of 1–3 orders of magnitude over existing elastic-distance-based $k$-means algorithms.
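
Part of the saving in distance calculations described above comes from the triangle inequality, which holds for any metric. The sketch below shows one well-known form of this idea (Elkan-style pruning in the assignment step). It is an assumption-laden illustration, not KASBA's actual bounds, which are described in the paper. The function name `assign_with_pruning` is hypothetical, and `dist` can be any metric.

```python
def assign_with_pruning(points, centroids, dist):
    """Assign each point to its nearest centroid, skipping distances that
    the triangle inequality proves cannot improve the current best."""
    k = len(centroids)
    # Centroid-to-centroid distances, computed once per assignment pass.
    cc = [[dist(centroids[i], centroids[j]) for j in range(k)] for i in range(k)]
    labels = []
    n_computed = 0  # counts point-to-centroid distances only
    for p in points:
        best, best_d = 0, dist(p, centroids[0])
        n_computed += 1
        for j in range(1, k):
            # d(p, c_j) >= d(c_best, c_j) - d(p, c_best), so if
            # d(c_best, c_j) >= 2 * d(p, c_best) then c_j cannot be closer.
            if cc[best][j] >= 2.0 * best_d:
                continue
            d = dist(p, centroids[j])
            n_computed += 1
            if d < best_d:
                best, best_d = j, d
        labels.append(best)
    return labels, n_computed
```

The labels are identical to a brute-force assignment. Only the number of evaluated point-to-centroid distances changes, which matters most when each distance is an expensive elastic distance.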

This repository contains the exact model configurations, experiment scripts, and visualisation tools used to produce the results in the paper.

---

## 📁 Repository Structure

    kasba/
    ├── README.md                       # This file
    ├── __init__.py
    ├── _utils.py                       # Internal utilities used across the project
    ├── _model_configuration.py         # Definitions of all models and configurations used in experiments
    ├── _experiment_script.py           # Script used to run experiments on datasets
    ├── kasba.ipynb                     # Notebook demonstrating how to run KASBA
    ├── result_visualisation.ipynb      # Notebook for generating CD diagrams, MCM plots, etc.
    └── results/                        # Raw CSV result files used in the paper
        ├── combined/                   # Combined results
        │   ├── k-shape-compare/        # Results for Section 5.4
        │   └── section-5.1/            # Results for Section 5.1
        └── train-test/                 # Train and test results
            ├── section-5.1/            # Results for Section 5.1
            ├── section-5.2/            # Results for Section 5.2
            └── section-5.3/            # Results for Section 5.3

## 🚀 Getting Started

### Install dependencies

Create and activate a virtual environment from the root of the tsml-eval repository, then install the package in editable mode:

    python3 -m venv venv
    source venv/bin/activate
    pip install -e .

Until the next aeon release is available, you need to install a specific branch of aeon. Run the following commands:

    pip uninstall aeon
    pip install git+https://github.com/aeon-toolkit/aeon@kasba-results#egg=aeon

Note: The project uses aeon, numpy, matplotlib, and other standard scientific Python packages.

---

## 🧪 Running KASBA

Minimal example from the `kasba.ipynb` notebook:

    from kasba import KASBA
    from aeon.datasets import load_dataset

    X, y = load_dataset("GunPoint")

    model = KASBA(
        n_clusters=2,
        distance="msm",
        distance_params={
            "c": 1.0
        },
    )

    labels = model.fit_predict(X)

The notebook demonstrates:

- How to use KASBA with different elastic distances
- How to cluster multivariate or unequal-length time series
- How to run multiple initialisations
- How to inspect convergence behaviour
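
As a rough illustration of the "multiple initialisations" item in the list above, the sketch below runs a clusterer from several random seeds and keeps the run with the lowest inertia. A toy 1-D k-means stands in for KASBA so that the example is self-contained. The helper names `kmeans_1d` and `best_of_n_inits` are hypothetical and are not taken from the notebook.

```python
import random


def kmeans_1d(values, k, seed, n_iter=20):
    """Toy Lloyd's k-means on scalars; returns (labels, inertia)."""
    rng = random.Random(seed)
    centres = rng.sample(values, k)  # random initial centres drawn from the data
    for _ in range(n_iter):
        clusters = [[] for _ in range(k)]
        for v in values:
            nearest = min(range(k), key=lambda j: abs(v - centres[j]))
            clusters[nearest].append(v)
        # Keep the old centre if a cluster ends up empty.
        centres = [
            sum(c) / len(c) if c else centres[j] for j, c in enumerate(clusters)
        ]
    labels = [min(range(k), key=lambda j: abs(v - centres[j])) for v in values]
    inertia = sum((v - centres[lab]) ** 2 for v, lab in zip(values, labels))
    return labels, inertia


def best_of_n_inits(values, k, seeds):
    """Run one fit per seed and return the (labels, inertia) with lowest inertia."""
    runs = [kmeans_1d(values, k, seed) for seed in seeds]
    return min(runs, key=lambda run: run[1])
```

The same pattern applies to any clusterer that exposes a seed and a final objective value: fit once per seed, then compare the objective.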

---

## 📊 Reproducing Figures (CD & MCM)

Use the `result_visualisation.ipynb` notebook to generate:

- Critical Difference diagrams
- Model Comparison Matrices
- Ranking curves and statistical tests
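
Critical Difference diagrams are built on the average rank of each model across datasets. The sketch below shows that ranking step only, with ties sharing the mean of the positions they occupy. It is a simplified illustration and not the notebook's code. The function name `average_ranks` and the nested-dictionary input layout are assumptions, and every dataset is assumed to have a score for every model.

```python
def average_ranks(scores, higher_is_better=True):
    """Map {dataset: {model: score}} to {model: mean rank across datasets}."""
    totals = {}
    for per_model in scores.values():
        ordered = sorted(
            per_model.items(), key=lambda kv: kv[1], reverse=higher_is_better
        )
        i = 0
        while i < len(ordered):
            # Find the block of tied scores starting at position i.
            j = i
            while j + 1 < len(ordered) and ordered[j + 1][1] == ordered[i][1]:
                j += 1
            tied_rank = (i + j) / 2 + 1  # mean of the 1-based positions i+1..j+1
            for model, _ in ordered[i : j + 1]:
                totals[model] = totals.get(model, 0.0) + tied_rank
            i = j + 1
    n_datasets = len(scores)
    return {model: total / n_datasets for model, total in totals.items()}
```

A lower average rank means a model was more often among the best performers. The statistical tests that group models in a CD diagram are a separate step and are not shown here.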

---

## 📜 Citation

If you use KASBA in academic work, please cite the paper:

C. Holder, A. Bagnall, Rock the KASBA: Blazingly fast and accurate time
series clustering, arXiv preprint arXiv:2411.17838 (2024).

(A full BibTeX entry will be added once the paper is published.)

---

## 🤝 Contact

For questions, please open an issue on the tsml-eval repository.
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
"""Files for Rock the KASBA."""
Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
import sys

import numpy as np

from tsml_eval.experiments import (
    run_clustering_experiment as tsml_clustering_experiment,
)
from tsml_eval.publications.clustering.kasba._model_configuration import (
    EXPERIMENT_MODELS,
)
from tsml_eval.publications.clustering.kasba._utils import (
    _parse_command_line_bool,
    check_experiment_results_exist,
    load_dataset_from_file,
)


def run_threaded_clustering_experiment(
    dataset: str,
    clusterer_name: str,
    dataset_path: str,
    results_path: str,
    combine_test_train: bool,
    resample_id: int,
):
    """Run a clustering experiment for one dataset and one clusterer.

    Parameters
    ----------
    dataset : str
        Dataset name.
    clusterer_name : str
        Key into ``EXPERIMENT_MODELS`` selecting the clusterer to build.
    dataset_path : str
        Path to the dataset.
    results_path : str
        Path to the results.
    combine_test_train : bool
        Boolean indicating if data should be combined for test and train.
    resample_id : int
        Integer indicating the resample id. Also passed to the clusterer
        as ``random_state``.

    Returns
    -------
    str or None
        A skip message if results already exist, otherwise None.
    """
    if clusterer_name not in EXPERIMENT_MODELS:
        raise ValueError(f"Unknown clusterer_name '{clusterer_name}'")

    # Skip if results already exist
    if check_experiment_results_exist(
        model_name=clusterer_name,
        dataset=dataset,
        combine_test_train=combine_test_train,
        path_to_results=results_path,
        resample_id=resample_id,
    ):
        return (
            f"[SKIP] {clusterer_name} (resample {resample_id}): "
            "results already exist."
        )

    # NOTE: the data is always loaded with resample_id=0 here; the
    # resample_id argument only seeds the clusterer and labels the results.
    X_train, y_train, X_test, y_test = load_dataset_from_file(
        dataset,
        dataset_path,
        normalize=True,
        combine_test_train=combine_test_train,
        resample_id=0,
    )
    # The number of clusters is taken from the number of distinct class labels.
    n_clusters = np.unique(y_train).size

    factory = EXPERIMENT_MODELS[clusterer_name]
    clusterer = factory(
        n_clusters=n_clusters,
        random_state=resample_id,
        n_jobs=1,
    )

    tsml_clustering_experiment(
        X_train=X_train,
        y_train=y_train,
        clusterer=clusterer,
        results_path=results_path,
        X_test=X_test,
        y_test=y_test,
        n_clusters=n_clusters,
        clusterer_name=clusterer_name,
        dataset_name=dataset,
        resample_id=resample_id,
        data_transforms=None,
        build_test_file=not combine_test_train,
        build_train_file=True,
        benchmark_time=True,
    )
    print(f"[DONE] {clusterer_name} (resample {resample_id})")


# Boolean to toggle if running locally or via command line.
RUN_LOCALLY = True

if __name__ == "__main__":
    # NOTE: To run with command line arguments, set RUN_LOCALLY to False.
    if RUN_LOCALLY:
        print("RUNNING WITH TEST CONFIG")

        dataset = "GunPoint"
        clusterer_name = "KASBA"
        combine_test_train = True

        dataset_path = (
            "/Users/chrisholder/Documents/Research/datasets/UCR/Univariate_ts"
        )
        results_path = "/Users/chrisholder/projects/kasba-experiments/full_results"
        run_threaded_clustering_experiment(
            dataset=dataset,
            clusterer_name=clusterer_name,
            dataset_path=dataset_path,
            results_path=results_path,
            combine_test_train=combine_test_train,
            resample_id=0,
        )

    else:
        # Five arguments are expected after the script name.
        if len(sys.argv) != 6:
            print(
                "Usage: python _experiment_script.py "
                "<dataset> <clusterer_name> <dataset_path> <result_path> "
                "<combine_test_train>"
            )
            sys.exit(1)

        dataset = str(sys.argv[1])
        clusterer_name = str(sys.argv[2])
        dataset_path = str(sys.argv[3])
        results_path = str(sys.argv[4])
        combine_test_train = _parse_command_line_bool(sys.argv[5])

        # resample_id is fixed here; it is not read from the command line.
        run_threaded_clustering_experiment(
            dataset=dataset,
            clusterer_name=clusterer_name,
            dataset_path=dataset_path,
            results_path=results_path,
            combine_test_train=combine_test_train,
            resample_id=1,
        )
