Skip to content

Commit 67b1b43

Browse files
authored
[BUG] MSM bug fixes (#3121)
* minor msm bug fix * msm bug fixes * updated tests * remove redundant distance call
1 parent e2e64ce commit 67b1b43

File tree

3 files changed

+26
-24
lines changed

3 files changed

+26
-24
lines changed

aeon/distances/elastic/_msm.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -246,12 +246,10 @@ def _msm_independent_cost_matrix(
246246
x_size = x.shape[1]
247247
y_size = y.shape[1]
248248
cost_matrix = np.zeros((x_size, y_size))
249-
distance = 0
250249
min_instances = min(x.shape[0], y.shape[0])
251250
for i in range(min_instances):
252251
curr_cost_matrix = _independent_cost_matrix(x[i], y[i], bounding_matrix, c)
253252
cost_matrix = np.add(cost_matrix, curr_cost_matrix)
254-
distance += curr_cost_matrix[-1, -1]
255253
return cost_matrix
256254

257255

@@ -267,19 +265,19 @@ def _independent_cost_matrix(
267265
for i in range(1, x_size):
268266
if bounding_matrix[i, 0]:
269267
cost = _cost_independent(x[i], x[i - 1], y[0], c)
270-
cost_matrix[i][0] = cost_matrix[i - 1][0] + cost
268+
cost_matrix[i, 0] = cost_matrix[i - 1, 0] + cost
271269

272-
for i in range(1, y_size):
273-
if bounding_matrix[0, i]:
274-
cost = _cost_independent(y[i], x[0], y[i - 1], c)
275-
cost_matrix[0][i] = cost_matrix[0][i - 1] + cost
270+
for j in range(1, y_size):
271+
if bounding_matrix[0, j]:
272+
cost = _cost_independent(y[j], x[0], y[j - 1], c)
273+
cost_matrix[0, j] = cost_matrix[0, j - 1] + cost
276274

277275
for i in range(1, x_size):
278276
for j in range(1, y_size):
279277
if bounding_matrix[i, j]:
280278
d1 = cost_matrix[i - 1][j - 1] + np.abs(x[i] - y[j])
281-
d2 = cost_matrix[i - 1][j] + _cost_independent(x[i], x[i - 1], y[j], c)
282-
d3 = cost_matrix[i][j - 1] + _cost_independent(y[j], x[i], y[j - 1], c)
279+
d2 = cost_matrix[i - 1, j] + _cost_independent(x[i], x[i - 1], y[j], c)
280+
d3 = cost_matrix[i, j - 1] + _cost_independent(y[j], x[i], y[j - 1], c)
283281

284282
cost_matrix[i, j] = min(d1, d2, d3)
285283

@@ -292,40 +290,44 @@ def _msm_dependent_cost_matrix(
292290
) -> np.ndarray:
293291
x_size = x.shape[1]
294292
y_size = y.shape[1]
293+
295294
cost_matrix = np.full((x_size, y_size), np.inf)
296-
cost_matrix[0, 0] = np.sum(np.abs(x[:, 0] - y[:, 0]))
295+
cost_matrix[0, 0] = _univariate_squared_distance(x[:, 0], y[:, 0])
297296

298297
for i in range(1, x_size):
299298
if bounding_matrix[i, 0]:
300299
cost = _cost_dependent(x[:, i], x[:, i - 1], y[:, 0], c)
301-
cost_matrix[i][0] = cost_matrix[i - 1][0] + cost
302-
for i in range(1, y_size):
303-
if bounding_matrix[0, i]:
304-
cost = _cost_dependent(y[:, i], x[:, 0], y[:, i - 1], c)
305-
cost_matrix[0][i] = cost_matrix[0][i - 1] + cost
300+
cost_matrix[i, 0] = cost_matrix[i - 1, 0] + cost
301+
302+
for j in range(1, y_size):
303+
if bounding_matrix[0, j]:
304+
cost = _cost_dependent(y[:, j], x[:, 0], y[:, j - 1], c)
305+
cost_matrix[0, j] = cost_matrix[0, j - 1] + cost
306306

307307
for i in range(1, x_size):
308308
for j in range(1, y_size):
309309
if bounding_matrix[i, j]:
310-
d1 = cost_matrix[i - 1][j - 1] + np.sum(np.abs(x[:, i] - y[:, j]))
311-
d2 = cost_matrix[i - 1][j] + _cost_dependent(
310+
d1 = cost_matrix[i - 1, j - 1] + _univariate_squared_distance(
311+
x[:, i], y[:, j]
312+
)
313+
d2 = cost_matrix[i - 1, j] + _cost_dependent(
312314
x[:, i], x[:, i - 1], y[:, j], c
313315
)
314-
d3 = cost_matrix[i][j - 1] + _cost_dependent(
316+
d3 = cost_matrix[i, j - 1] + _cost_dependent(
315317
y[:, j], x[:, i], y[:, j - 1], c
316318
)
317-
318319
cost_matrix[i, j] = min(d1, d2, d3)
320+
319321
return cost_matrix
320322

321323

322324
@njit(cache=True, fastmath=True)
323325
def _cost_dependent(x: np.ndarray, y: np.ndarray, z: np.ndarray, c: float) -> float:
324326
diameter = _univariate_squared_distance(y, z)
325-
mid = (y + z) / 2
327+
mid = (y + z) / 2.0
326328
distance_to_mid = _univariate_squared_distance(mid, x)
327329

328-
if distance_to_mid <= (diameter / 2):
330+
if distance_to_mid <= (diameter / 4.0):
329331
return c
330332
else:
331333
dist_to_q_prev = _univariate_squared_distance(y, x)

aeon/distances/elastic/tests/test_distance_correctness.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
"wddtw": [38144.53125, 19121.4927, 1.34957],
6161
"twe": [4536.0, 3192.0220, 3030.036000000001],
6262
"msm_ind": [1515.0, 1517.8000000000004, 1557.0], # msm with independent distance
63-
"msm_dep": [1897.0, 1898.6000000000001, 1921.0], # msm with dependent distance
63+
"msm_dep": [190547.0, 190549.800000000020, 190589.0], # msm with dependent distance
6464
}
6565
basic_motions_distances = {
6666
"euclidean": 27.51835240,
@@ -78,7 +78,7 @@
7878
# msm with independent distance
7979
"msm_ind": [84.36021099999999, 140.13788899999997, 262.6939920000001],
8080
# msm with dependent distance
81-
"msm_dep": [33.06825, 71.1408, 190.7397],
81+
"msm_dep": [192.24477562339214, 205.82382238128477, 277.7315058567359],
8282
}
8383

8484

aeon/testing/expected_results/expected_distance_results.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@
166166
"msm": [
167167
[2.5687181410313746, 28.97170437295012],
168168
[2.5687181410313746, 28.97170437295012],
169-
[2.5687181410313746, 28.657118461088324],
169+
[1.0924085990342982, 13.072037194954508],
170170
[1.8756413986565008, 22.362537814430787],
171171
],
172172
"adtw": [

0 commit comments

Comments
 (0)