
Commit 1fa64c2

[Refactor] Move shared files to common
Signed-off-by: Kyuyeun Kim <[email protected]>
1 parent: b832b02 · commit: 1fa64c2

23 files changed: +31 -31 lines
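In short, the modules shared across backends (the attention interface, the sharding utilities, and the binary-search helpers) move from the tpu_inference.layers.jax package to tpu_inference.layers.common, and every import site is updated to match. A minimal before/after sketch of the import change, using only paths that appear in the diffs below (the grouping itself is illustrative, not a file from the repository):

# Before this commit:
# from tpu_inference.layers.jax.attention_interface import attention
# from tpu_inference.layers.jax.sharding import ShardingAxisName
# from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask

# After this commit:
from tpu_inference.layers.common.attention_interface import attention
from tpu_inference.layers.common.sharding import ShardingAxisName
from tpu_inference.layers.common.binary_search import topk_mask, topp_mask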

tests/models/jax/test_attention_interface.py renamed to tests/layers/common/test_attention_interface.py

Lines changed: 3 additions & 3 deletions
@@ -6,8 +6,8 @@
 import pytest
 from jax.sharding import Mesh
 
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.runner.kv_cache import get_kv_cache_shape_with_mesh
 
 # ---- Test Configuration & Constants ----

@@ -79,12 +79,12 @@ def _test_attention(monkeypatch, mesh, head_dim, use_sinks=False):
 
     if head_dim == 64:
         monkeypatch.setattr(
-            "tpu_inference.layers.jax.attention_interface.ragged_paged_attention_hd64",
+            "tpu_inference.layers.common.attention_interface.ragged_paged_attention_hd64",
             mock_paged_attn_kernel,
         )
     else:
         monkeypatch.setattr(
-            "tpu_inference.layers.jax.attention_interface.ragged_paged_attention",
+            "tpu_inference.layers.common.attention_interface.ragged_paged_attention",
            mock_paged_attn_kernel,
        )
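Note that pytest's monkeypatch.setattr with a dotted-string target patches the attribute on the named module, so after the move the target string must name tpu_inference.layers.common.attention_interface, where the attention code now looks up the kernel. A minimal sketch of the updated pattern (fake_kernel and the test name below are hypothetical; only the target string comes from this diff):

def test_patches_relocated_kernel(monkeypatch):
    # Stand-in for the mock_paged_attn_kernel used by the real test.
    def fake_kernel(*args, **kwargs):
        return None

    # The string must name the module that now holds the reference,
    # i.e. layers.common rather than layers.jax.
    monkeypatch.setattr(
        "tpu_inference.layers.common.attention_interface.ragged_paged_attention",
        fake_kernel,
    )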

tests/layers/jax/attention/test_common_attention.py

Lines changed: 1 addition & 1 deletion
@@ -8,9 +8,9 @@
 from jax.sharding import Mesh
 from parameterized import parameterized
 
+from tpu_inference.layers.common.attention_interface import get_kv_cache_shape
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.jax.attention.attention import Attention
-from tpu_inference.layers.jax.attention_interface import get_kv_cache_shape
 
 KVCache = Tuple[jax.Array, jax.Array]

tests/layers/jax/attention/test_deepseek_v3_attention.py

Lines changed: 1 addition & 1 deletion
@@ -7,9 +7,9 @@
 from jax.sharding import Mesh
 from parameterized import parameterized
 
+from tpu_inference.layers.common.attention_interface import get_kv_cache_shape
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
-from tpu_inference.layers.jax.attention_interface import get_kv_cache_shape
 
 
 class TestMLA(unittest.TestCase):

tests/layers/jax/attention/test_llama4_attention.py

Lines changed: 1 addition & 1 deletion
@@ -13,9 +13,9 @@
 from jax.sharding import PartitionSpec as P
 
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import build_mesh
 from tpu_inference.layers.jax.attention.llama4_attention import (
     L2Norm, Llama4Attention)
-from tpu_inference.layers.jax.sharding import build_mesh
 
 
 @dataclass

tests/layers/jax/test_sharding.py

Lines changed: 3 additions & 3 deletions
@@ -3,9 +3,9 @@
 
 import jax
 
-from tpu_inference.layers.jax.sharding import (Sharding, ShardingConfig,
-                                               ShardingRulesConfig,
-                                               ShardingStrategy)
+from tpu_inference.layers.common.sharding import (Sharding, ShardingConfig,
+                                                   ShardingRulesConfig,
+                                                   ShardingStrategy)
 
 
 class TestSharding(unittest.TestCase):

tpu_inference/layers/jax/attention_interface.py renamed to tpu_inference/layers/common/attention_interface.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.kernels.flash_attention.kernel import flash_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.utils import get_megacore
 
 MAX_ALLOWED_PAGE_INDICES_N = (

File renamed without changes.

tpu_inference/layers/jax/attention/attention.py

Lines changed: 1 addition & 1 deletion
@@ -13,9 +13,9 @@
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
     ragged_paged_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope_interface import apply_rope
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 
 KVCache = Tuple[jax.Array, jax.Array]

tpu_inference/layers/jax/sample/rejection_sampler.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 import jax.numpy as jnp
 import numpy as np
 
-from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
