Skip to content

Commit 679ae1d

Browse files
nmharmon8jongwook
andauthored
Fix: Ensure DTW cost tensor is on the same device as input tensor (#2561)
Co-authored-by: Jong Wook Kim <[email protected]>
1 parent f50c4f2 commit 679ae1d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

whisper/timing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
117117
x_skew = x_skew.T.contiguous()
118118
cost = torch.ones(N + M + 2, M + 2) * np.inf
119119
cost[0, 0] = 0
120-
cost = cost.cuda()
120+
cost = cost.to(x.device)
121121
trace = torch.zeros_like(cost, dtype=torch.int32)
122122

123123
dtw_kernel[(1,)](

0 commit comments

Comments
 (0)