2 changes: 1 addition & 1 deletion docs/basic_usage/features/radix_cache.md
@@ -49,7 +49,7 @@ Key components:
- Uses min-heap to select leaf nodes by `last_access_time`
- Only evicts nodes with `lock_ref == 0`

The radix tree is stored on CPU while KV data uses GPU memory pools.
The radix tree is stored on CPU while KV data uses TPU memory pools.

## Usage

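For context on the eviction policy referenced in this hunk, here is a minimal sketch of LRU eviction over leaf nodes using a min-heap and a `lock_ref` guard. The `Node` fields and the `evict` signature are illustrative assumptions, not the actual `RadixCache` code.

```python
import heapq
from dataclasses import dataclass, field

# Hypothetical node layout; the real RadixCache stores more metadata per node.
@dataclass(order=True)
class Node:
    last_access_time: float
    key: str = field(compare=False)
    lock_ref: int = field(compare=False, default=0)

def evict(leaves: list[Node], num_to_evict: int) -> list[Node]:
    """Pop least-recently-used leaves, skipping nodes that are still referenced."""
    heapq.heapify(leaves)          # min-heap keyed on last_access_time
    evicted = []
    while leaves and len(evicted) < num_to_evict:
        node = heapq.heappop(leaves)
        if node.lock_ref == 0:     # only evict nodes no running request holds
            evicted.append(node)   # the caller would free the node's KV slots here
    return evicted
```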
2 changes: 1 addition & 1 deletion docs/features/chunked_prefill.md
@@ -16,7 +16,7 @@ The primary objectives of chunked prefill are:

Traditional LLM inference processes entire prefill requests as single units, which can cause issues with very long sequences:
- **Prefill phase**: Processes entire input prompt in parallel, but long sequences can exceed memory limits
- **Memory constraints**: Large prefill requests may not fit in available GPU memory or batch size limits
- **Memory constraints**: Large prefill requests may not fit in available TPU memory or batch size limits

Chunked prefill addresses this by breaking large prefill requests into smaller chunks that can be processed sequentially, allowing the system to handle much longer input sequences without running out of memory.

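The chunking idea described in this hunk can be sketched as below; the `forward(chunk, past_kv=...)` interface and the chunk size are assumptions for illustration, not the scheduler's real API.

```python
# Illustrative only: a long prompt is prefilled in fixed-size slices while the
# KV cache accumulates, so peak memory is bounded by one chunk at a time.
def chunked_prefill(model, input_ids, chunk_size=2048):
    kv_cache = None
    logits = None
    for start in range(0, len(input_ids), chunk_size):
        chunk = input_ids[start:start + chunk_size]
        # Each chunk attends to all previously cached tokens.
        logits, kv_cache = model.forward(chunk, past_kv=kv_cache)
    return logits, kv_cache  # logits of the last chunk feed the first decode step
```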
1 change: 1 addition & 0 deletions python/pyproject.toml
@@ -15,6 +15,7 @@ dependencies = [
"flax~=0.12.0",
"huggingface-hub~=0.34.3",
"jinja2~=3.1.6",
"llguidance~=1.3.0",
"modelscope~=1.28.2",
"msgpack-python~=0.5.6",
"numpy~=2.2.6",
2 changes: 1 addition & 1 deletion python/sgl_jax/bench_one_batch.py
@@ -14,7 +14,7 @@
# Usage (correctness test):
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

## Reference output (of the correctness test above, can be gpu dependent):
## Reference output (of the correctness test above, can be tpu dependent):
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]

prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
2 changes: 1 addition & 1 deletion python/sgl_jax/srt/configs/model_config.py
@@ -220,7 +220,7 @@ def get_total_num_kv_heads(self) -> int:
return self.hf_text_config.num_attention_heads

def get_num_kv_heads(self, tensor_parallel_size) -> int:
"""Returns the number of KV heads per GPU."""
"""Returns the number of KV heads per TP size."""
from sgl_jax.srt.utils.jax_utils import get_num_kv_heads_by_tp

total_num_kv_heads = self.get_total_num_kv_heads()
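For reference, a typical way to derive per-device KV heads from the total count and the tensor-parallel size is sketched below; this is an assumption about the usual GQA/MQA split, not necessarily what `get_num_kv_heads_by_tp` implements.

```python
def num_kv_heads_per_device(total_num_kv_heads: int, tensor_parallel_size: int) -> int:
    """Typical TP split: divide KV heads across devices, and fall back to one
    (replicated) head per device when there are fewer KV heads than devices."""
    return max(1, total_num_kv_heads // tensor_parallel_size)

# e.g. 8 KV heads on a 4-way tensor-parallel mesh -> 2 heads per device
assert num_kv_heads_per_device(8, 4) == 2
```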
14 changes: 14 additions & 0 deletions python/sgl_jax/srt/constrained/__init__.py
@@ -0,0 +1,14 @@
"""Constrained decoding with grammar backends."""

from sgl_jax.srt.constrained.base_grammar_backend import (
BaseGrammarBackend,
BaseGrammarObject,
)
from sgl_jax.srt.constrained.llguidance_backend import GuidanceBackend, GuidanceGrammar

__all__ = [
"BaseGrammarBackend",
"BaseGrammarObject",
"GuidanceBackend",
"GuidanceGrammar",
]
179 changes: 179 additions & 0 deletions python/sgl_jax/srt/constrained/base_grammar_backend.py
@@ -0,0 +1,179 @@
"""Base classes for grammar-constrained decoding backends."""

import concurrent.futures as futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import numpy as np

from sgl_jax.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)


class BaseGrammarObject:
"""Base class for grammar objects that maintain state during generation."""

def __init__(self):
self.finished = False

def accept_token(self, token: int):
raise NotImplementedError()

def allocate_vocab_mask(self, vocab_size: int, batch_size: int):
raise NotImplementedError()

def fill_vocab_mask(self, vocab_mask: np.ndarray, idx: int):
raise NotImplementedError()

def is_terminated(self) -> bool:
raise NotImplementedError()

def copy(self):
raise NotImplementedError()


class BaseGrammarBackend:
"""Base class for grammar backends with async compilation support."""

def __init__(self, num_threads: int = 4):
"""Initialize the grammar backend.

Args:
num_threads: Number of threads for async grammar compilation.
"""
self.executor = ThreadPoolExecutor(max_workers=num_threads)
self.cache: dict[tuple[str, str], Any] = {}

def get_cached_or_future_value(self, key: tuple[str, str]) -> tuple[Any, bool]:
"""Get a cached grammar object or submit async compilation.

Args:
key: Tuple of (constraint_type, constraint_string)
e.g., ("json", schema_str) or ("regex", pattern)

Returns:
Tuple of (grammar_object or Future, cache_hit: bool)
"""
if key in self.cache:
value = self.cache[key]
# Check if it's a completed grammar or still a Future
if isinstance(value, futures.Future):
return value, False # Still compiling
else:
return value, True # Cache hit

# Not in cache, submit async compilation
key_type, key_string = key
future = self.executor.submit(self._dispatch, key_type, key_string)
self.cache[key] = future
return future, False

def set_cache(self, key: tuple[str, str], value: BaseGrammarObject):
"""Store a compiled grammar in the cache.

Args:
key: Cache key
value: Compiled grammar object
"""
self.cache[key] = value

def _dispatch(self, key_type: str, key_string: str) -> BaseGrammarObject:
"""Dispatch grammar creation based on type.

Args:
key_type: Type of constraint ("json", "regex", "ebnf", "structural_tag")
key_string: Constraint string (JSON schema, regex pattern, etc.)

Returns:
Compiled grammar object
"""
if key_type == "json":
return self.dispatch_json(key_string)
elif key_type == "regex":
return self.dispatch_regex(key_string)
elif key_type == "ebnf":
return self.dispatch_ebnf(key_string)
elif key_type == "structural_tag":
return self.dispatch_structural_tag(key_string)
else:
raise ValueError(f"Unknown constraint type: {key_type}")

def dispatch_json(self, key_string: str) -> BaseGrammarObject:
"""Create a grammar from JSON schema.

Args:
key_string: JSON schema string

Returns:
Grammar object
"""
raise NotImplementedError()

def dispatch_regex(self, key_string: str) -> BaseGrammarObject:
"""Create a grammar from regex pattern.

Args:
key_string: Regex pattern string

Returns:
Grammar object
"""
raise NotImplementedError()

def dispatch_ebnf(self, key_string: str) -> BaseGrammarObject:
"""Create a grammar from EBNF definition.

Args:
key_string: EBNF grammar string

Returns:
Grammar object
"""
raise NotImplementedError()

def dispatch_structural_tag(self, key_string: str) -> BaseGrammarObject:
"""Create a grammar from structural tag configuration.

Args:
key_string: JSON string of structural tag config

Returns:
Grammar object
"""
raise NotImplementedError()

def shutdown(self):
"""Shutdown the thread pool executor."""
self.executor.shutdown(wait=False)


# Sentinel object for invalid/failed grammar compilation
INVALID_GRAMMAR_OBJ = BaseGrammarObject()


def create_grammar_backend(
server_args: ServerArgs,
tokenizer,
vocab_size: int,
eos_token_ids: set | None = None,
) -> BaseGrammarBackend | None:
name = server_args.grammar_backend

if name == "llguidance":
from sgl_jax.srt.constrained.llguidance_backend import GuidanceBackend

grammar_backend = GuidanceBackend(
tokenizer=tokenizer,
num_threads=4,
n_vocab=vocab_size,
any_whitespace=not server_args.constrained_json_disable_any_whitespace,
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
elif name == "none":
return None
else:
raise ValueError(f"Invalid grammar backend: {name}")

return grammar_backend
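A sketch of how a caller might use the cache/future protocol above with a concrete backend such as `GuidanceBackend`; the blocking `result()` call and the per-request `copy()` are assumptions about the integration, not code from this PR.

```python
import concurrent.futures as futures

from sgl_jax.srt.constrained import BaseGrammarBackend

def get_grammar(backend: BaseGrammarBackend, key: tuple[str, str]):
    """Resolve a grammar for one request, reusing compiled grammars when possible."""
    value, cache_hit = backend.get_cached_or_future_value(key)
    if cache_hit:
        return value.copy()            # reuse the already-compiled grammar
    if isinstance(value, futures.Future):
        grammar = value.result()       # block until async compilation finishes
        backend.set_cache(key, grammar)  # replace the Future with the result
        return grammar.copy()
    return value

# e.g. get_grammar(backend, ("json", '{"type": "object"}'))
```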
106 changes: 106 additions & 0 deletions python/sgl_jax/srt/constrained/bitmask_ops.py
@@ -0,0 +1,106 @@
"""JAX-based bitmask operations for vocabulary masking on TPU."""

import jax
import jax.numpy as jnp
import numpy as np
from llguidance import LLInterpreter


def allocate_token_bitmask(batch_size: int, vocab_size: int) -> np.ndarray:
"""Allocate a token bitmask array.

Args:
batch_size: Batch size
vocab_size: Vocabulary size

Returns:
Numpy array of shape [batch_size, vocab_size // 32] with dtype int32
"""
num_int32_per_vocab = (vocab_size + 31) // 32
return np.zeros((batch_size, num_int32_per_vocab), dtype=np.int32)


def fill_token_bitmask(
matcher: LLInterpreter,
vocab_mask: np.ndarray,
batch_idx: int,
):
"""Fill the bitmask for a specific batch index using llguidance matcher.

Args:
matcher: LLMatcher or LLInterpreter instance
vocab_mask: Bitmask array of shape [batch_size, vocab_size // 32], dtype=int32
batch_idx: Index in the batch to fill
"""
assert vocab_mask.dtype == np.int32, "Mask must be int32"
assert vocab_mask.ndim == 2, "Mask must be 2D"
v = vocab_mask[batch_idx, :]
matcher.unsafe_compute_mask_ptr(
v.ctypes.data,
v.nbytes,
)


@jax.jit
def apply_token_bitmask(
logits: jnp.ndarray,
vocab_mask: jnp.ndarray,
) -> jnp.ndarray:
"""Apply token bitmask to logits.

Sets logits to -inf where the bitmask bit is 0.

Args:
logits: Logits array of shape [batch_size, vocab_size]
vocab_mask: Packed bitmask array of shape [batch_size, vocab_size // 32]

Returns:
Masked logits array of shape [batch_size, vocab_size]
"""
if vocab_mask is None:
return logits

# Unpack the bitmask from int32 to bool (full length = num_int32 * 32)
unpacked_mask_full = unpack_bitmask(vocab_mask) # [Bmask, num_int32*32]
vocab_size = logits.shape[-1]
mask_len = unpacked_mask_full.shape[-1]

# Match vocab dimension statically: pad with False or crop as needed
if mask_len < vocab_size:
pad = vocab_size - mask_len
unpacked_mask = jnp.pad(
unpacked_mask_full,
((0, 0), (0, pad)),
mode="constant",
constant_values=False,
)
elif mask_len > vocab_size:
unpacked_mask = unpacked_mask_full[:, :vocab_size]
else:
unpacked_mask = unpacked_mask_full

# Apply mask: set logits to -inf where mask is False (broadcast batch if needed)
masked_logits = jnp.where(unpacked_mask, logits, -jnp.inf)
return masked_logits


def unpack_bitmask(vocab_mask: jnp.ndarray) -> jnp.ndarray:
"""Unpack int32 bitmask to boolean array (no dynamic slicing).

Args:
vocab_mask: Packed bitmask [batch_size, num_int32]

Returns:
Boolean mask [batch_size, num_int32 * 32]
"""
# For each int32, extract 32 bits
bit_indices = jnp.arange(32)[None, :] # [1, 32]

def unpack_batch_item(mask_row):
# mask_row: [num_int32]
bits = jnp.bitwise_and(mask_row[:, None], 1 << bit_indices) != 0 # [num_int32, 32]
return bits.reshape(-1) # [num_int32 * 32]

# Apply to all batch items
unpacked = jax.vmap(unpack_batch_item)(vocab_mask) # [batch, num_int32*32]
return unpacked
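A small end-to-end usage sketch of the helpers above, packing a mask by hand (bit `i` of int32 word `i // 32` marks token `i` as allowed, matching `unpack_bitmask`) instead of calling llguidance; the toy vocabulary size and allowed tokens are arbitrary.

```python
import jax.numpy as jnp
import numpy as np

vocab_size, batch_size = 70, 1
mask = allocate_token_bitmask(batch_size, vocab_size)   # shape [1, 3], dtype int32

# Allow only tokens 0 and 33.
for tok in (0, 33):
    mask[0, tok // 32] |= np.int32(1 << (tok % 32))

logits = jnp.zeros((batch_size, vocab_size))
masked = apply_token_bitmask(logits, jnp.asarray(mask))

# Allowed tokens keep their logits; disallowed tokens are set to -inf.
assert bool(jnp.isfinite(masked[0, 0])) and not bool(jnp.isfinite(masked[0, 1]))
```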