Skip to content

Commit f281038

Browse files
committed
msm bug fixes
1 parent 65ba1bf commit f281038

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

aeon/distances/elastic/_msm.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -267,19 +267,19 @@ def _independent_cost_matrix(
267267
for i in range(1, x_size):
268268
if bounding_matrix[i, 0]:
269269
cost = _cost_independent(x[i], x[i - 1], y[0], c)
270-
cost_matrix[i][0] = cost_matrix[i - 1][0] + cost
270+
cost_matrix[i, 0] = cost_matrix[i - 1, 0] + cost
271271

272-
for i in range(1, y_size):
273-
if bounding_matrix[0, i]:
274-
cost = _cost_independent(y[i], x[0], y[i - 1], c)
275-
cost_matrix[0][i] = cost_matrix[0][i - 1] + cost
272+
for j in range(1, y_size):
273+
if bounding_matrix[0, j]:
274+
cost = _cost_independent(y[j], x[0], y[j - 1], c)
275+
cost_matrix[0, j] = cost_matrix[0, j - 1] + cost
276276

277277
for i in range(1, x_size):
278278
for j in range(1, y_size):
279279
if bounding_matrix[i, j]:
280280
d1 = cost_matrix[i - 1][j - 1] + np.abs(x[i] - y[j])
281-
d2 = cost_matrix[i - 1][j] + _cost_independent(x[i], x[i - 1], y[j], c)
282-
d3 = cost_matrix[i][j - 1] + _cost_independent(y[j], x[i], y[j - 1], c)
281+
d2 = cost_matrix[i - 1, j] + _cost_independent(x[i], x[i - 1], y[j], c)
282+
d3 = cost_matrix[i, j - 1] + _cost_independent(y[j], x[i], y[j - 1], c)
283283

284284
cost_matrix[i, j] = min(d1, d2, d3)
285285

@@ -292,40 +292,44 @@ def _msm_dependent_cost_matrix(
292292
) -> np.ndarray:
293293
x_size = x.shape[1]
294294
y_size = y.shape[1]
295+
295296
cost_matrix = np.full((x_size, y_size), np.inf)
296-
cost_matrix[0, 0] = np.sum(np.abs(x[:, 0] - y[:, 0]))
297+
cost_matrix[0, 0] = _univariate_squared_distance(x[:, 0], y[:, 0])
297298

298299
for i in range(1, x_size):
299300
if bounding_matrix[i, 0]:
300301
cost = _cost_dependent(x[:, i], x[:, i - 1], y[:, 0], c)
301-
cost_matrix[i][0] = cost_matrix[i - 1][0] + cost
302-
for i in range(1, y_size):
303-
if bounding_matrix[0, i]:
304-
cost = _cost_dependent(y[:, i], x[:, 0], y[:, i - 1], c)
305-
cost_matrix[0][i] = cost_matrix[0][i - 1] + cost
302+
cost_matrix[i, 0] = cost_matrix[i - 1, 0] + cost
303+
304+
for j in range(1, y_size):
305+
if bounding_matrix[0, j]:
306+
cost = _cost_dependent(y[:, j], x[:, 0], y[:, j - 1], c)
307+
cost_matrix[0, j] = cost_matrix[0, j - 1] + cost
306308

307309
for i in range(1, x_size):
308310
for j in range(1, y_size):
309311
if bounding_matrix[i, j]:
310-
d1 = cost_matrix[i - 1][j - 1] + np.sum(np.abs(x[:, i] - y[:, j]))
311-
d2 = cost_matrix[i - 1][j] + _cost_dependent(
312+
d1 = cost_matrix[i - 1, j - 1] + _univariate_squared_distance(
313+
x[:, i], y[:, j]
314+
)
315+
d2 = cost_matrix[i - 1, j] + _cost_dependent(
312316
x[:, i], x[:, i - 1], y[:, j], c
313317
)
314-
d3 = cost_matrix[i][j - 1] + _cost_dependent(
318+
d3 = cost_matrix[i, j - 1] + _cost_dependent(
315319
y[:, j], x[:, i], y[:, j - 1], c
316320
)
317-
318321
cost_matrix[i, j] = min(d1, d2, d3)
322+
319323
return cost_matrix
320324

321325

322326
@njit(cache=True, fastmath=True)
323327
def _cost_dependent(x: np.ndarray, y: np.ndarray, z: np.ndarray, c: float) -> float:
324328
diameter = _univariate_squared_distance(y, z)
325-
mid = (y + z) / 2
329+
mid = (y + z) / 2.0
326330
distance_to_mid = _univariate_squared_distance(mid, x)
327331

328-
if distance_to_mid <= (diameter / 2):
332+
if distance_to_mid <= (diameter / 4.0):
329333
return c
330334
else:
331335
dist_to_q_prev = _univariate_squared_distance(y, x)

0 commit comments

Comments
 (0)