Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
c855f44
add xai temperature
Kipsora Nov 3, 2025
cedf0c0
adjust to the new grok model
Kipsora Nov 3, 2025
38e6d68
fix bugs
Kipsora Nov 3, 2025
02af531
fix bugs
Kipsora Nov 3, 2025
e9d3571
fix bugs
Kipsora Nov 3, 2025
cfcbeeb
fix bugs
Kipsora Nov 3, 2025
626dbae
fix bugs
Kipsora Nov 3, 2025
fdbaf1b
fix bugs
Kipsora Nov 3, 2025
af65b21
workable version with one layer and log
Bob-Chen222 Nov 6, 2025
ad7c898
delete redundant print in jax compiled codes
Bob-Chen222 Nov 6, 2025
acf3cb7
reproduce OOM error on tpuv6e-32 and huggingface downloading stuck
Bob-Chen222 Nov 7, 2025
c07a136
grok-oom
Prayer3th Nov 7, 2025
6a53f8b
Merge pull request #1 from Prayer3th/feat/grok-rebase-bob
Bob-Chen222 Nov 7, 2025
5904079
grok support weight loading debug WIP
Bob-Chen222 Nov 10, 2025
fd500f6
debug
JamesBrianD Nov 10, 2025
1c1f8e6
fix: moe
JamesBrianD Nov 10, 2025
071a8d5
Merge pull request #2 from primatrix/debug/grok2
Bob-Chen222 Nov 10, 2025
7a593c6
hack layers
JamesBrianD Nov 11, 2025
77ba335
hack layers
JamesBrianD Nov 11, 2025
4b4e059
hack layers
JamesBrianD Nov 11, 2025
02eebb0
hack layers
JamesBrianD Nov 11, 2025
ce667a5
hack layers
JamesBrianD Nov 11, 2025
9c64cad
hack layers
JamesBrianD Nov 11, 2025
61c7d2f
hack layers
JamesBrianD Nov 11, 2025
62b12fb
hack layers
JamesBrianD Nov 11, 2025
39dc0ea
hack layers
JamesBrianD Nov 11, 2025
3eb6972
hack layers
JamesBrianD Nov 11, 2025
71825fa
hack layers
JamesBrianD Nov 11, 2025
d8b379d
hack layers
JamesBrianD Nov 11, 2025
6362b06
io callback
JamesBrianD Nov 11, 2025
a87a3e7
io callback
JamesBrianD Nov 11, 2025
c33d9d3
Merge pull request #3 from primatrix/hack-layers
Bob-Chen222 Nov 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions job.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
run: |
cd sglang-jax && source .venv/bin/activate
# cd python/sgl_jax && \
# python3 bench_one_batch.py \
# --model-path xai-org/grok-2 \
# --tokenizer-path Xenova/grok-1-tokenizer \
# --correct \
# --tp-size 32 \
# --mem-fraction-static 0.4 \
# --download-dir /mnt \
# --load-format dummy

# JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache uv run python \
# -u -m sgl_jax.launch_server --model-path Qwen/Qwen-7B-Chat --trust-remote-code \
# --dist-init-addr=0.0.0.0:10011 --nnodes=1 --tp-size=1 --device=tpu --random-seed=3 \
# --node-rank=0 --mem-fraction-static=0.8 --max-prefill-tokens=8192 --download-dir=/tmp \
# --dtype=bfloat16 --skip-server-warmup --host 0.0.0.0 --port 30000

# cd python/sgl_jax && python3 bench_one_batch.py --model-path xai-org/grok-2 \
# --load-format auto --tokenizer-path Xenova/grok-1-tokenizer --correct --tp-size 8 --download-dir /models/

# rm -rf ~/.cache/ && \
cd python/sgl_jax && python3 -u bench_one_batch.py --model-path /models/xai-org-backup/ --load-format auto --tokenizer-path Xenova/grok-1-tokenizer --correctness-test --tp-size 4 \
--mem-fraction-static 0.9
# --nnodes 8 --dist-init-addr=0.0.0.0:10011
44 changes: 30 additions & 14 deletions python/sgl_jax/bench_one_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,14 @@ def from_cli_args(cls, args: argparse.Namespace):

def load_model(server_args, port_args, tp_rank):
# TODO: pass in tp_size
# server_args.tp_size = 1
# server_args.tp_size = 16
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
# moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)

model_config = ModelConfig.from_server_args(server_args)
# logging.info("load_model num_hidden_layers: %s", model_config.num_hidden_layers)
# model_config.num_hidden_layers = 1 #for debugging


# Create a mesh that includes both 'data' and 'tensor' axes.
# Use a size-1 'data' axis and shard across the 'tensor' axis per tp_size.
Expand All @@ -137,6 +140,7 @@ def load_model(server_args, port_args, tp_rank):
mesh=mesh,
)
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")

tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
Expand All @@ -146,7 +150,7 @@ def load_model(server_args, port_args, tp_rank):
try:
jax_mh.sync_global_devices("load_model")
except Exception as err:
logging.info("Could not sync global devices (expected in single-host): %s", err)
logging.info("Could not sync global devices (expected in single-host): %s", err)
return model_runner, tokenizer


Expand Down Expand Up @@ -238,13 +242,14 @@ def extend(reqs, model_runner):
return next_token_ids, next_token_logits, batch


def decode(input_token_ids, batch, model_runner):
def decode(input_token_ids, batch: ScheduleBatch, model_runner):
batch.output_ids = input_token_ids
batch.prepare_for_decode()
_maybe_prepare_mlp_sync_batch(batch, model_runner)
# For decode, the token dimension equals current batch size
bs_needed = len(batch.seq_lens)
next_token_ids, next_token_logits = _run_forward_and_sample(model_runner, batch, bs_needed)
logging.info("next token logits: %s", next_token_logits)
return next_token_ids, next_token_logits


Expand All @@ -269,7 +274,7 @@ def _run_forward_and_sample(model_runner, batch: ScheduleBatch, token_first_arg:
)
)

model_worker_batch = batch.get_model_worker_batch(
model_worker_batch : ModelWorkerBatch = batch.get_model_worker_batch(
[token_first_arg], [bs_needed], [cache_loc_needed], page_size
)

Expand All @@ -281,17 +286,19 @@ def _run_forward_and_sample(model_runner, batch: ScheduleBatch, token_first_arg:
logits_metadata = LogitsMetadata.from_model_worker_batch(
model_worker_batch, mesh=model_runner.mesh
)
positions = model_worker_batch.positions

logits_output, _ = model_runner.forward(forward_batch, logits_metadata=logits_metadata)

pad_size = len(model_worker_batch.seq_lens) - model_worker_batch.real_bs
sampling_metadata = SamplingMetadata.from_model_worker_batch(
model_worker_batch,
pad_size=pad_size,
mesh=model_runner.mesh,
vocab_size=model_runner.model_config.vocab_size,
)
next_token_ids = model_runner.sample(logits_output, sampling_metadata)
next_token_ids = model_runner.sample(logits_output, sampling_metadata, positions)
# NOTE(Bob): seems that now next_token_ids is a jax array, not a numpy array

return next_token_ids, logits_output.next_token_logits

Expand All @@ -316,22 +323,29 @@ def correctness_test(
if bench_args.cut_len > 0:
# Prefill
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
rank_print(f"prefill logits (first half): {next_token_logits} \n")
if tp_rank == 0:
gathered_logits = jax_mh.process_allgather(next_token_logits, tiled=True)
rank_print(f"prefill logits (first half): {gathered_logits} \n")

# Prepare extend inputs
reqs = prepare_extend_inputs_for_correctness_test(bench_args, input_ids, reqs, model_runner)

# Extend (prefill w/ KV cache)
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
rank_print(f"prefill logits (final): {next_token_logits} \n")
gathered_token_ids = jax_mh.process_allgather(next_token_ids, tiled=True)
next_token_ids_cpu = np.array(gathered_token_ids)
if tp_rank == 0:
gathered_logits = jax_mh.process_allgather(next_token_logits, tiled=True)
rank_print(f"prefill logits (final): {gathered_logits} \n")

# Decode
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
for _ in range(bench_args.output_len[0] - 1):
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
next_token_ids_list = next_token_ids.tolist()
output_ids = [input_ids[i] + [next_token_ids_cpu[i]] for i in range(len(input_ids))]
for step in range(bench_args.output_len[0] - 1):
next_token_ids, _ = decode(next_token_ids_cpu, batch, model_runner)
gathered_token_ids = jax_mh.process_allgather(next_token_ids, tiled=True)
next_token_ids_cpu = np.array(gathered_token_ids)
for i in range(len(reqs)):
output_ids[i].append(next_token_ids_list[i])
output_ids[i].append(next_token_ids_cpu[i])

# Print output texts
for i in range(len(reqs)):
Expand Down Expand Up @@ -549,7 +563,7 @@ def main(server_args, bench_args):
"Switching attention backend to 'native' for single TPU to reduce compile-time memory"
)

_set_envs_and_config()
_set_envs_and_config(server_args)

if server_args.model_path:
work_func = correctness_test if bench_args.correctness_test else latency_test
Expand All @@ -565,6 +579,8 @@ def main(server_args, bench_args):


if __name__ == "__main__":
jax.distributed.initialize()
logging.info("JAX distributed initialized")
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
BenchArgs.add_cli_args(parser)
Expand Down
17 changes: 17 additions & 0 deletions python/sgl_jax/srt/layers/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import jax
import jax.numpy as jnp
from flax import nnx


class GeluAndMul(nnx.Module):
    """Gated GELU activation computing ``gelu(gate) * up``.

    The input is split into two equal halves along its last axis: the first
    half is passed through GELU and then multiplied element-wise by the
    second half.
    """

    def __init__(self, approximate: str = "tanh"):
        # Only the exact string "tanh" selects the tanh approximation; every
        # other value results in ``jax.nn.gelu(..., approximate=False)``.
        self.approximate = approximate

    def __call__(self, gate_up: jax.Array):
        """Apply the gated activation to ``gate_up``.

        Args:
            gate_up: Array whose last axis holds the gate half followed by
                the up half; ``jnp.split`` requires that axis to be evenly
                divisible by two.

        Returns:
            A ``(output, None)`` tuple, where ``output`` is
            ``gelu(gate) * up``.
            NOTE(review): the trailing ``None`` presumably mirrors a
            two-value return convention used by other layers -- confirm
            against callers.
        """
        gate_half, up_half = jnp.split(gate_up, 2, axis=-1)
        use_tanh = self.approximate == "tanh"
        activated = jax.nn.gelu(gate_half, approximate=use_tanh)
        return activated * up_half, None
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def ref_ragged_paged_attention_fused(
mask_value: float | None = DEFAULT_MASK_VALUE,
k_scale: float | None = None,
v_scale: float | None = None,
xai_temperature_len: float | None = None,
):
if mask_value is None:
mask_value = DEFAULT_MASK_VALUE
Expand Down Expand Up @@ -86,6 +87,19 @@ def ref_ragged_paged_attention_fused(
mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
if soft_cap is not None:
attn = soft_cap * jnp.tanh(attn / soft_cap)

# xai_temperature_scale: ref implementation from sgl-project/sglang
# python/sglang/srt/layers/attention/triton_ops/decode_attention.py
if xai_temperature_len is not None:
# FIXED: Calculate `qidx` based on the sequence's own length, not global batch indices.
prefix_len = kv_len - q_len
qidx = jnp.arange(prefix_len, kv_len)

xai_temperature_scale = 1.0 / jnp.log2(float(xai_temperature_len))
_qtemp = jnp.log2(qidx.astype(jnp.float32)) * xai_temperature_scale
xai_temperature_reg = jnp.where(qidx > xai_temperature_len, _qtemp, 1.0)
attn = attn * xai_temperature_reg[:, None]

attn += jnp.where(mask, mask_value, 0.0)
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
Expand All @@ -109,6 +123,7 @@ def ref_ragged_paged_attention(
mask_value: float | None = DEFAULT_MASK_VALUE,
k_scale: float | None = None,
v_scale: float | None = None,
xai_temperature_len: float | None = None,
):
if mask_value is None:
mask_value = DEFAULT_MASK_VALUE
Expand Down Expand Up @@ -143,6 +158,16 @@ def ref_ragged_paged_attention(
mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
if soft_cap is not None:
attn = soft_cap * jnp.tanh(attn / soft_cap)

# xai_temperature_scale: ref implementation from sgl-project/sglang
# python/sglang/srt/layers/attention/triton_ops/decode_attention.py
if xai_temperature_len is not None:
xai_temperature_scale = 1.0 / jnp.log2(float(xai_temperature_len))
qidx = jnp.arange(q_start, q_end) - 1
_qtemp = jnp.log2(qidx.astype(jnp.float32)) * xai_temperature_scale
xai_temperature_reg = jnp.where(qidx > xai_temperature_len, _qtemp, 1.0)
attn = attn * xai_temperature_reg[:, None]

attn += jnp.where(mask, mask_value, 0.0)
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
Expand Down Expand Up @@ -260,6 +285,7 @@ def _ragged_paged_attention_kernel(
q_scale: float | None = None,
k_scale: float | None = None,
v_scale: float | None = None,
xai_temperature_len: float | None = None,
chunk_prefill_size: int | None = None,
bkv_p,
bq_sz,
Expand Down Expand Up @@ -652,12 +678,24 @@ def batch_prepare_queries():
for head_idx in range(actual_num_kv_heads):
bq = load_bq(bq_sem_idx, head_idx, actual_bq_sz=actual_bq_sz)
q_heads.append(bq)
q_batch = jnp.stack(q_heads, axis=0)

return jnp.stack(q_heads, axis=0)
if xai_temperature_len is not None:
# Correctly calculate sequence-relative position
prefix_len = kv_len - q_len
# `bq_idx * bq_sz` is the offset within the new queries for this sequence
local_q_offset = bq_idx * bq_sz + lax.iota(jnp.int32, q_batch.shape[1])
# `base_qidx` is the absolute sequence position
base_qidx = prefix_len + local_q_offset
# Tile for all KV heads, as the position is the same for each head group.
offs_qidx_batch = jnp.tile(base_qidx, (q_batch.shape[0], 1))
return q_batch, offs_qidx_batch

return q_batch, None

# Load batched data
k_batch, v_batch = batch_load_all_heads_kv()
q_batch = batch_prepare_queries()
q_batch, offs_qidx_batch = batch_prepare_queries()

def flash_attention(q_batch, k_batch, v_batch):
q_batch_f32 = q_batch.astype(jnp.float32)
Expand Down Expand Up @@ -697,6 +735,19 @@ def flash_attention(q_batch, k_batch, v_batch):
if soft_cap is not None:
s = soft_cap * jnp.tanh(s / soft_cap)

# xai_temperature_scale: ref implementation from sgl-project/sglang
# python/sglang/srt/layers/attention/triton_ops/decode_attention.py
if xai_temperature_len is not None:
xai_temperature_scale = 1.0 / jnp.log2(float(xai_temperature_len))
_qtemp = (
jnp.log2(offs_qidx_batch.astype(jnp.float32)) * xai_temperature_scale
)
xai_temperature_reg = jnp.where(
offs_qidx_batch > xai_temperature_len, _qtemp, 1.0
)

s = s * xai_temperature_reg[:, :, None]

s += jnp.where(mask, mask_value, 0.0)

for head_idx in range(actual_num_kv_heads):
Expand Down Expand Up @@ -950,6 +1001,7 @@ def static_validate_inputs_fused(
q_scale: float | None = None,
k_scale: float | None = None,
v_scale: float | None = None,
xai_temperature_len: float | None = None,
# Kernel optimization params.
chunk_prefill_size: int | None = None,
# Kernel tuning params.
Expand Down Expand Up @@ -1030,6 +1082,8 @@ def static_validate_inputs_fused(
raise ValueError(f"{sliding_window=} must be positive.")
if soft_cap is not None and soft_cap == 0.0:
raise ValueError(f"{soft_cap=} must not be 0.0.")
if xai_temperature_len is not None and xai_temperature_len <= 0:
raise ValueError(f"{xai_temperature_len=} must be positive.")
if chunk_prefill_size is not None and chunk_prefill_size <= 0:
raise ValueError(f"{chunk_prefill_size=} must be positive.")
if num_kv_pages_per_block is not None and num_kv_pages_per_block <= 0:
Expand All @@ -1044,6 +1098,7 @@ def static_validate_inputs_fused(
del q_scale
del k_scale
del v_scale
del xai_temperature_len


@functools.partial(
Expand All @@ -1056,6 +1111,7 @@ def static_validate_inputs_fused(
"q_scale",
"k_scale",
"v_scale",
"xai_temperature_len",
"chunk_prefill_size",
"num_kv_pages_per_block",
"num_queries_per_block",
Expand All @@ -1081,6 +1137,7 @@ def ragged_paged_attention(
q_scale: float | None = None,
k_scale: float | None = None,
v_scale: float | None = None,
xai_temperature_len: float | None = None,
# Kernel optimization params.
chunk_prefill_size: int | None = None,
# Kernel tuning params.
Expand Down Expand Up @@ -1110,6 +1167,8 @@ def ragged_paged_attention(
mask_value: mask value for causal mask.
k_scale: the scale for the key cache.
v_scale: the scale for the value cache.
xai_temperature_len: the length-based temperature term used by xai grok.
reference: sgl-project/sglang: python/sglang/srt/layers/attention/triton_ops/decode_attention.py
num_kv_pages_per_block: number of kv pages to be processed in one flash
attention block in the pallas kernel.
num_queries_per_block: number of queries to be processed in one flash
Expand Down Expand Up @@ -1270,6 +1329,7 @@ def ragged_paged_attention(
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale,
xai_temperature_len=xai_temperature_len,
chunk_prefill_size=chunk_prefill_size,
bq_sz=bq_sz,
bkv_p=bkv_p,
Expand Down
Loading