diff --git a/python/sgl_jax/srt/layers/sampler.py b/python/sgl_jax/srt/layers/sampler.py index 36656cd11..c687b0a8f 100644 --- a/python/sgl_jax/srt/layers/sampler.py +++ b/python/sgl_jax/srt/layers/sampler.py @@ -4,8 +4,7 @@ from jax import lax from jax import numpy as jnp from jax import random -from jax.sharding import Mesh, NamedSharding -from jax.sharding import PartitionSpec as P +from jax.sharding import Mesh from sgl_jax.srt.layers.binary_search import topk_mask, topp_mask from sgl_jax.srt.layers.logits_processor import LogitsProcessorOutput @@ -30,8 +29,6 @@ def _regular_sampling(self, operands): """Regular sampling branch""" logits, sampling_metadata, positions, rng, mesh, use_sort_for_toppk_minp = operands - logits = lax.with_sharding_constraint(logits, NamedSharding(mesh, P(None, None))) - # Validate broadcast compatibility for temperature division logits_batch_size = logits.shape[0] temperatures_shape = sampling_metadata.temperatures.shape