Skip to content

Commit b8d3757

Browse files
[Pallas/interpreter] Refactor re-useable functionality out of the main source file for the TPU kernel interpreter.
PiperOrigin-RevId: 845731033
1 parent e63d2a4 commit b8d3757

File tree

8 files changed

+527
-374
lines changed

8 files changed

+527
-374
lines changed

jax/_src/pallas/mosaic/interpret/BUILD

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ py_library(
3333
deps = [
3434
":race_detection_state",
3535
":shared_memory",
36+
":thread_map",
37+
":utils",
3638
":vector_clock",
3739
"//jax",
3840
"//jax/_src:api",
@@ -79,3 +81,23 @@ pytype_strict_library(
7981
"//jax/_src:source_info_util",
8082
],
8183
)
84+
85+
pytype_strict_library(
86+
name = "thread_map",
87+
srcs = ["thread_map.py"],
88+
deps = [
89+
"//jax",
90+
"//jax/_src:callback",
91+
],
92+
)
93+
94+
pytype_strict_library(
95+
name = "utils",
96+
srcs = ["utils.py"],
97+
deps = [
98+
"//jax",
99+
"//jax/_src:core",
100+
"//jax/_src:util",
101+
"//jax/_src/pallas",
102+
] + py_deps("numpy"),
103+
)

0 commit comments

Comments
 (0)