@@ -6,8 +6,8 @@
 import pytest
 from jax.sharding import Mesh

+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.runner.kv_cache import get_kv_cache_shape_with_mesh

 # ---- Test Configuration & Constants ----
@@ -79,12 +79,12 @@ def _test_attention(monkeypatch, mesh, head_dim, use_sinks=False):

     if head_dim == 64:
         monkeypatch.setattr(
-            "tpu_inference.layers.jax.attention_interface.ragged_paged_attention_hd64",
+            "tpu_inference.layers.common.attention_interface.ragged_paged_attention_hd64",
             mock_paged_attn_kernel,
         )
     else:
         monkeypatch.setattr(
-            "tpu_inference.layers.jax.attention_interface.ragged_paged_attention",
+            "tpu_inference.layers.common.attention_interface.ragged_paged_attention",
             mock_paged_attn_kernel,
         )

2 changes: 1 addition & 1 deletion tests/layers/jax/attention/test_common_attention.py
@@ -8,9 +8,9 @@
 from jax.sharding import Mesh
 from parameterized import parameterized

+from tpu_inference.layers.common.attention_interface import get_kv_cache_shape
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.jax.attention.attention import Attention
-from tpu_inference.layers.jax.attention_interface import get_kv_cache_shape

 KVCache = Tuple[jax.Array, jax.Array]

2 changes: 1 addition & 1 deletion tests/layers/jax/attention/test_deepseek_v3_attention.py
@@ -7,9 +7,9 @@
 from jax.sharding import Mesh
 from parameterized import parameterized

+from tpu_inference.layers.common.attention_interface import get_kv_cache_shape
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
-from tpu_inference.layers.jax.attention_interface import get_kv_cache_shape


 class TestMLA(unittest.TestCase):
2 changes: 1 addition & 1 deletion tests/layers/jax/attention/test_llama4_attention.py
@@ -13,9 +13,9 @@
 from jax.sharding import PartitionSpec as P

 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import build_mesh
 from tpu_inference.layers.jax.attention.llama4_attention import (
     L2Norm, Llama4Attention)
-from tpu_inference.layers.jax.sharding import build_mesh


 @dataclass
6 changes: 3 additions & 3 deletions tests/layers/jax/test_sharding.py
@@ -3,9 +3,9 @@

 import jax

-from tpu_inference.layers.jax.sharding import (Sharding, ShardingConfig,
-                                               ShardingRulesConfig,
-                                               ShardingStrategy)
+from tpu_inference.layers.common.sharding import (Sharding, ShardingConfig,
+                                                  ShardingRulesConfig,
+                                                  ShardingStrategy)


 class TestSharding(unittest.TestCase):
@@ -17,7 +17,7 @@
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.kernels.flash_attention.kernel import flash_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.utils import get_megacore

 MAX_ALLOWED_PAGE_INDICES_N = (
2 changes: 1 addition & 1 deletion tpu_inference/layers/jax/attention/attention.py
@@ -13,9 +13,9 @@
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
     ragged_paged_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope_interface import apply_rope
-from tpu_inference.layers.jax.sharding import ShardingAxisName

 KVCache = Tuple[jax.Array, jax.Array]

2 changes: 1 addition & 1 deletion tpu_inference/layers/jax/sample/rejection_sampler.py
@@ -12,7 +12,7 @@
 import jax.numpy as jnp
 import numpy as np

-from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata

4 changes: 2 additions & 2 deletions tpu_inference/layers/jax/sample/sampling.py
@@ -6,10 +6,10 @@
 from jax.sharding import PartitionSpec as P
 from vllm.v1.outputs import LogprobsTensors

-from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName

 _SAMPLING_EPS = 1e-5

2 changes: 1 addition & 1 deletion tpu_inference/layers/vllm/attention.py
@@ -13,8 +13,8 @@
                              AttentionLayer, AttentionType)

 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.logger import init_logger
 from tpu_inference.models.vllm.vllm_model_wrapper_context import \
     get_vllm_model_wrapper_context
2 changes: 1 addition & 1 deletion tpu_inference/models/common/model_loader.py
@@ -11,7 +11,7 @@
 from vllm.utils.func_utils import supports_kw

 from tpu_inference import envs
-from tpu_inference.layers.jax.sharding import ShardingAxisName
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
4 changes: 2 additions & 2 deletions tpu_inference/models/jax/llama3.py
@@ -8,10 +8,10 @@
 from vllm.config import VllmConfig

 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.rope_interface import apply_rope
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
2 changes: 1 addition & 1 deletion tpu_inference/models/jax/phi3.py
@@ -8,8 +8,8 @@
 from vllm.config import VllmConfig

 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_longrope, apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
2 changes: 1 addition & 1 deletion tpu_inference/models/jax/qwen2.py
@@ -8,8 +8,8 @@
 from vllm.config import VllmConfig

 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
4 changes: 2 additions & 2 deletions tpu_inference/models/jax/qwen2_5_vl.py
@@ -14,9 +14,9 @@
 from vllm.config import VllmConfig

 from tpu_inference import utils as utils
-from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import \
+from tpu_inference.layers.common.attention_interface import \
     sharded_flash_attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
 # from vllm.model_executor.models.interfaces import MultiModalEmbeddings
2 changes: 1 addition & 1 deletion tpu_inference/models/jax/qwen3.py
@@ -8,8 +8,8 @@
 from vllm.config import VllmConfig

 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
2 changes: 1 addition & 1 deletion tpu_inference/platforms/tpu_platform.py
@@ -12,7 +12,7 @@
 from vllm.sampling_params import SamplingParams, SamplingType

 from tpu_inference import envs
-from tpu_inference.layers.jax.sharding import ShardingConfigManager
+from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger

 if TYPE_CHECKING:
2 changes: 1 addition & 1 deletion tpu_inference/runner/compilation_manager.py
@@ -10,10 +10,10 @@

 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.sample.sampling import sample
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.utils import device_array

2 changes: 1 addition & 1 deletion tpu_inference/runner/kv_cache.py
@@ -9,7 +9,7 @@

 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
-from tpu_inference.layers.jax.sharding import ShardingAxisName
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)
8 changes: 4 additions & 4 deletions tpu_inference/runner/tpu_runner.py
@@ -37,15 +37,15 @@

 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
+                                                  MESH_AXIS_NAMES_2D,
+                                                  ShardingAxisName,
+                                                  ShardingConfigManager)
 from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
 from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
                                                       gather_logprobs, sample)
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import (MESH_AXIS_NAMES,
-                                               MESH_AXIS_NAMES_2D,
-                                               ShardingAxisName,
-                                               ShardingConfigManager)
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.models.jax.utils.weight_utils import (
2 changes: 1 addition & 1 deletion tpu_inference/worker/tpu_worker.py
@@ -25,7 +25,7 @@
 from tpu_inference import envs, utils
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
-from tpu_inference.layers.jax.sharding import ShardingConfigManager
+from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner