Skip to content

Commit 8689924

Browse files
Fixed triton kernel update to support latest triton versions (#2588)
* Update triton kernel using _unsafe_update_src * support old triton versions * refactored changes to update triton kernel only once * Update triton_ops.py --------- Co-authored-by: Jong Wook Kim <[email protected]> Co-authored-by: Jong Wook Kim <[email protected]>
1 parent 5dff4db commit 8689924

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

whisper/triton_ops.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def kernel(
6060
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
6161

6262
kernel = triton.JITFunction(kernel.fn)
63-
kernel.src = kernel.src.replace(
63+
new_kernel = kernel.src.replace(
6464
" LOAD_ALL_ROWS_HERE",
6565
"\n".join(
6666
[
@@ -69,7 +69,8 @@ def kernel(
6969
]
7070
),
7171
)
72-
kernel.src = kernel.src.replace(
72+
73+
new_kernel = new_kernel.replace(
7374
" BUBBLESORT_HERE",
7475
"\n\n".join(
7576
[
@@ -90,7 +91,14 @@ def kernel(
9091
]
9192
),
9293
)
93-
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
94+
95+
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
96+
97+
if hasattr(kernel, "_unsafe_update_src") is True:
98+
kernel._unsafe_update_src(new_kernel)
99+
kernel.hash = None
100+
else:
101+
kernel.src = new_kernel
94102

95103
return kernel
96104

0 commit comments

Comments
 (0)