
Commit 7ef752b

feat: support structured output
1 parent 262cd5b commit 7ef752b

27 files changed: +1450 −108 lines changed

docs/basic_usage/features/radix_cache.md

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ Key components:
 - Uses min-heap to select leaf nodes by `last_access_time`
 - Only evicts nodes with `lock_ref == 0`
 
-The radix tree is stored on CPU while KV data uses GPU memory pools.
+The radix tree is stored on CPU while KV data uses TPU memory pools.
 
 ## Usage
 
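
A minimal sketch of the eviction policy described in the doc snippet above, with a hypothetical LeafNode stand-in for a radix-tree leaf (names are illustrative, not from this repo): leaves go onto a min-heap keyed by last_access_time, and only nodes with lock_ref == 0 are evicted.

import heapq
from dataclasses import dataclass

@dataclass
class LeafNode:
    # Hypothetical stand-in for a radix-tree leaf; real nodes also hold KV indices.
    last_access_time: float
    lock_ref: int = 0

def evict_leaves(leaves: list[LeafNode], num_to_evict: int) -> list[LeafNode]:
    # Min-heap keyed by last_access_time: least-recently-used leaves pop first.
    heap = [(node.last_access_time, i, node) for i, node in enumerate(leaves)]
    heapq.heapify(heap)
    evicted = []
    while heap and len(evicted) < num_to_evict:
        _, _, node = heapq.heappop(heap)
        if node.lock_ref != 0:  # still referenced by a running request; skip
            continue
        evicted.append(node)  # caller then frees the node's KV slots in the device pool
    return evicted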

docs/features/chunked_prefill.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ The primary objectives of chunked prefill are:
 
 Traditional LLM inference processes entire prefill requests as single units, which can cause issues with very long sequences:
 - **Prefill phase**: Processes entire input prompt in parallel, but long sequences can exceed memory limits
-- **Memory constraints**: Large prefill requests may not fit in available GPU memory or batch size limits
+- **Memory constraints**: Large prefill requests may not fit in available TPU memory or batch size limits
 
 Chunked prefill addresses this by breaking large prefill requests into smaller chunks that can be processed sequentially, allowing the system to handle much longer input sequences without running out of memory.
 
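
As a rough sketch of the idea (not this repo's scheduler code; forward_fn and chunk_size are hypothetical), a long prompt is split into fixed-size chunks that are prefilled one after another, each chunk extending the KV cache built by the previous ones:

def chunked_prefill(forward_fn, input_ids: list[int], chunk_size: int = 2048):
    # forward_fn(chunk, kv_cache) is an assumed model call that appends the
    # chunk's KV entries to kv_cache and returns the last position's logits.
    kv_cache: list = []
    last_logits = None
    for start in range(0, len(input_ids), chunk_size):
        chunk = input_ids[start : start + chunk_size]
        last_logits = forward_fn(chunk, kv_cache)
    return last_logits  # used to sample the first decode token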

python/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ dependencies = [
     "flax~=0.12.0",
     "huggingface-hub~=0.34.3",
     "jinja2~=3.1.6",
+    "llguidance~=1.3.0",
     "modelscope~=1.28.2",
     "msgpack-python~=0.5.6",
     "numpy~=2.2.6",

python/sgl_jax/bench_one_batch.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 # Usage (correctness test):
 python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
 
-## Reference output (of the correctness test above, can be gpu dependent):
+## Reference output (of the correctness test above, can be tpu dependent):
 input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
 
 prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],

python/sgl_jax/srt/configs/model_config.py

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ def get_total_num_kv_heads(self) -> int:
         return self.hf_text_config.num_attention_heads
 
     def get_num_kv_heads(self, tensor_parallel_size) -> int:
-        """Returns the number of KV heads per GPU."""
+        """Returns the number of KV heads per TP size."""
         from sgl_jax.srt.utils.jax_utils import get_num_kv_heads_by_tp
 
         total_num_kv_heads = self.get_total_num_kv_heads()
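
For context on the docstring change, here is a sketch of the usual way KV heads are divided across tensor-parallel ranks. This is an assumption about the common convention; the actual logic lives in get_num_kv_heads_by_tp in sgl_jax.srt.utils.jax_utils and may differ.

def num_kv_heads_per_tp_rank(total_num_kv_heads: int, tensor_parallel_size: int) -> int:
    # Shard KV heads evenly when there are enough of them; otherwise each rank
    # keeps a replicated copy of a head (e.g. MQA with a single KV head).
    if total_num_kv_heads >= tensor_parallel_size:
        return max(1, total_num_kv_heads // tensor_parallel_size)
    return 1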

python/sgl_jax/srt/constrained/__init__.py

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
"""Constrained decoding with grammar backends."""

from sgl_jax.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)
from sgl_jax.srt.constrained.llguidance_backend import GuidanceBackend, GuidanceGrammar

__all__ = [
    "BaseGrammarBackend",
    "BaseGrammarObject",
    "GuidanceBackend",
    "GuidanceGrammar",
]

python/sgl_jax/srt/constrained/base_grammar_backend.py

Lines changed: 179 additions & 0 deletions

@@ -0,0 +1,179 @@
"""Base classes for grammar-constrained decoding backends."""

import concurrent.futures as futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import numpy as np

from sgl_jax.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)


class BaseGrammarObject:
    """Base class for grammar objects that maintain state during generation."""

    def __init__(self):
        self.finished = False

    def accept_token(self, token: int):
        raise NotImplementedError()

    def allocate_vocab_mask(self, vocab_size: int, batch_size: int):
        raise NotImplementedError()

    def fill_vocab_mask(self, vocab_mask: np.ndarray, idx: int):
        raise NotImplementedError()

    def is_terminated(self) -> bool:
        raise NotImplementedError()

    def copy(self):
        raise NotImplementedError()


class BaseGrammarBackend:
    """Base class for grammar backends with async compilation support."""

    def __init__(self, num_threads: int = 4):
        """Initialize the grammar backend.

        Args:
            num_threads: Number of threads for async grammar compilation.
        """
        self.executor = ThreadPoolExecutor(max_workers=num_threads)
        self.cache: dict[tuple[str, str], Any] = {}

    def get_cached_or_future_value(self, key: tuple[str, str]) -> tuple[Any, bool]:
        """Get a cached grammar object or submit async compilation.

        Args:
            key: Tuple of (constraint_type, constraint_string),
                e.g., ("json", schema_str) or ("regex", pattern)

        Returns:
            Tuple of (grammar_object or Future, cache_hit: bool)
        """
        if key in self.cache:
            value = self.cache[key]
            # Check if it's a completed grammar or still a Future
            if isinstance(value, futures.Future):
                return value, False  # Still compiling
            else:
                return value, True  # Cache hit

        # Not in cache, submit async compilation
        key_type, key_string = key
        future = self.executor.submit(self._dispatch, key_type, key_string)
        self.cache[key] = future
        return future, False

    def set_cache(self, key: tuple[str, str], value: BaseGrammarObject):
        """Store a compiled grammar in the cache.

        Args:
            key: Cache key
            value: Compiled grammar object
        """
        self.cache[key] = value

    def _dispatch(self, key_type: str, key_string: str) -> BaseGrammarObject:
        """Dispatch grammar creation based on type.

        Args:
            key_type: Type of constraint ("json", "regex", "ebnf", "structural_tag")
            key_string: Constraint string (JSON schema, regex pattern, etc.)

        Returns:
            Compiled grammar object
        """
        if key_type == "json":
            return self.dispatch_json(key_string)
        elif key_type == "regex":
            return self.dispatch_regex(key_string)
        elif key_type == "ebnf":
            return self.dispatch_ebnf(key_string)
        elif key_type == "structural_tag":
            return self.dispatch_structural_tag(key_string)
        else:
            raise ValueError(f"Unknown constraint type: {key_type}")

    def dispatch_json(self, key_string: str) -> BaseGrammarObject:
        """Create a grammar from JSON schema.

        Args:
            key_string: JSON schema string

        Returns:
            Grammar object
        """
        raise NotImplementedError()

    def dispatch_regex(self, key_string: str) -> BaseGrammarObject:
        """Create a grammar from regex pattern.

        Args:
            key_string: Regex pattern string

        Returns:
            Grammar object
        """
        raise NotImplementedError()

    def dispatch_ebnf(self, key_string: str) -> BaseGrammarObject:
        """Create a grammar from EBNF definition.

        Args:
            key_string: EBNF grammar string

        Returns:
            Grammar object
        """
        raise NotImplementedError()

    def dispatch_structural_tag(self, key_string: str) -> BaseGrammarObject:
        """Create a grammar from structural tag configuration.

        Args:
            key_string: JSON string of structural tag config

        Returns:
            Grammar object
        """
        raise NotImplementedError()

    def shutdown(self):
        """Shutdown the thread pool executor."""
        self.executor.shutdown(wait=False)


# Sentinel object for invalid/failed grammar compilation
INVALID_GRAMMAR_OBJ = BaseGrammarObject()


def create_grammar_backend(
    server_args: ServerArgs,
    tokenizer,
    vocab_size: int,
    eos_token_ids: set | None = None,
) -> BaseGrammarBackend | None:
    name = server_args.grammar_backend

    if name == "llguidance":
        from sgl_jax.srt.constrained.llguidance_backend import GuidanceBackend

        grammar_backend = GuidanceBackend(
            tokenizer=tokenizer,
            num_threads=4,
            n_vocab=vocab_size,
            any_whitespace=not server_args.constrained_json_disable_any_whitespace,
            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
        )
    elif name == "none":
        return None
    else:
        raise ValueError(f"Invalid grammar backend: {name}")

    return grammar_backend
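
A small usage sketch of the caching API above. EchoJsonBackend is a toy subclass invented for illustration (real use goes through GuidanceBackend via create_grammar_backend): the first lookup of a key submits async compilation and returns a Future; storing the result with set_cache turns later lookups into cache hits.

import concurrent.futures as futures

from sgl_jax.srt.constrained.base_grammar_backend import BaseGrammarBackend, BaseGrammarObject

class EchoJsonBackend(BaseGrammarBackend):
    # Toy backend: "compiles" a JSON schema by wrapping it in a grammar object.
    def dispatch_json(self, key_string: str) -> BaseGrammarObject:
        obj = BaseGrammarObject()
        obj.schema = key_string  # stand-in for a compiled grammar
        return obj

backend = EchoJsonBackend(num_threads=2)
key = ("json", '{"type": "object", "properties": {"name": {"type": "string"}}}')

value, cache_hit = backend.get_cached_or_future_value(key)  # first call: Future, cache_hit=False
if isinstance(value, futures.Future):
    grammar = value.result()         # wait for the worker thread to finish compiling
    backend.set_cache(key, grammar)  # later lookups with the same key hit the cache

_, cache_hit = backend.get_cached_or_future_value(key)
assert cache_hit
backend.shutdown()
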
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
"""JAX-based bitmask operations for vocabulary masking on TPU."""

import jax
import jax.numpy as jnp
import numpy as np
from llguidance import LLInterpreter


def allocate_token_bitmask(batch_size: int, vocab_size: int) -> np.ndarray:
    """Allocate a token bitmask array.

    Args:
        batch_size: Batch size
        vocab_size: Vocabulary size

    Returns:
        Numpy array of shape [batch_size, vocab_size // 32] with dtype int32
    """
    num_int32_per_vocab = (vocab_size + 31) // 32
    return np.zeros((batch_size, num_int32_per_vocab), dtype=np.int32)


def fill_token_bitmask(
    matcher: LLInterpreter,
    vocab_mask: np.ndarray,
    batch_idx: int,
):
    """Fill the bitmask for a specific batch index using the llguidance matcher.

    Args:
        matcher: LLMatcher or LLInterpreter instance
        vocab_mask: Bitmask array of shape [batch_size, vocab_size // 32], dtype=int32
        batch_idx: Index in the batch to fill
    """
    assert vocab_mask.dtype == np.int32, "Mask must be int32"
    assert vocab_mask.ndim == 2, "Mask must be 2D"
    v = vocab_mask[batch_idx, :]
    matcher.unsafe_compute_mask_ptr(
        v.ctypes.data,
        v.nbytes,
    )


@jax.jit
def apply_token_bitmask(
    logits: jnp.ndarray,
    vocab_mask: jnp.ndarray,
) -> jnp.ndarray:
    """Apply a token bitmask to logits.

    Sets logits to -inf where the bitmask bit is 0.

    Args:
        logits: Logits array of shape [batch_size, vocab_size]
        vocab_mask: Packed bitmask array of shape [batch_size, vocab_size // 32]

    Returns:
        Masked logits array of shape [batch_size, vocab_size]
    """
    if vocab_mask is None:
        return logits

    # Unpack the bitmask from int32 to bool (full length = num_int32 * 32)
    unpacked_mask_full = unpack_bitmask(vocab_mask)  # [batch, num_int32 * 32]
    vocab_size = logits.shape[-1]
    mask_len = unpacked_mask_full.shape[-1]

    # Match vocab dimension statically: pad with False or crop as needed
    if mask_len < vocab_size:
        pad = vocab_size - mask_len
        unpacked_mask = jnp.pad(
            unpacked_mask_full,
            ((0, 0), (0, pad)),
            mode="constant",
            constant_values=False,
        )
    elif mask_len > vocab_size:
        unpacked_mask = unpacked_mask_full[:, :vocab_size]
    else:
        unpacked_mask = unpacked_mask_full

    # Apply mask: set logits to -inf where mask is False (broadcast batch if needed)
    masked_logits = jnp.where(unpacked_mask, logits, -jnp.inf)
    return masked_logits


def unpack_bitmask(vocab_mask: jnp.ndarray) -> jnp.ndarray:
    """Unpack an int32 bitmask to a boolean array (no dynamic slicing).

    Args:
        vocab_mask: Packed bitmask [batch_size, num_int32]

    Returns:
        Boolean mask [batch_size, num_int32 * 32]
    """
    # For each int32, extract 32 bits
    bit_indices = jnp.arange(32)[None, :]  # [1, 32]

    def unpack_batch_item(mask_row):
        # mask_row: [num_int32]
        bits = jnp.bitwise_and(mask_row[:, None], 1 << bit_indices) != 0  # [num_int32, 32]
        return bits.reshape(-1)  # [num_int32 * 32]

    # Apply to all batch items
    unpacked = jax.vmap(unpack_batch_item)(vocab_mask)  # [batch, num_int32 * 32]
    return unpacked
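
A small end-to-end sketch of the masking path above, with the bitmask filled by hand instead of by an llguidance matcher: allow a few token ids, set their bits in the packed int32 mask, and check that apply_token_bitmask drives every other logit to -inf. allocate_token_bitmask and apply_token_bitmask are the functions defined in this new module.

import numpy as np
import jax.numpy as jnp

batch_size, vocab_size = 1, 64
allowed = [2, 5, 40]  # token ids the grammar would permit at this step

mask = allocate_token_bitmask(batch_size, vocab_size)  # [1, 2] int32, all zeros
for tok in allowed:
    mask[0, tok // 32] |= 1 << (tok % 32)  # set the token's bit in its int32 word

logits = jnp.zeros((batch_size, vocab_size))
masked = apply_token_bitmask(logits, jnp.asarray(mask))

assert bool(jnp.isneginf(masked[0, 0]))  # disallowed token is masked out
assert float(masked[0, 5]) == 0.0        # allowed token keeps its original logit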
