Skip to content

Commit 38cf2df

Browse files
committed
fix bugs in model runner
1 parent 4bc7b94 commit 38cf2df

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

python/sgl_jax/srt/model_executor/model_runner.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,16 @@ def run_model_wrapper(forward_batch, logits_metadata):
186186
logits_metadata,
187187
)
188188

189+
def run_sampler_wrapper(*args):
190+
return jitted_sampler(
191+
sampler_def,
192+
sampler_state_def,
193+
sampler_state_leaves,
194+
*args,
195+
)
196+
189197
self.jitted_run_model = run_model_wrapper
190-
self.jitted_sampler = partial(
191-
jitted_sampler, self._sampler_def, self._sampler_state_def
192-
)
198+
self.jitted_sampler = run_sampler_wrapper
193199

194200
def get_available_device_memory(self):
195201
min_available_device_memory = get_available_device_memory(

0 commit comments

Comments
 (0)