Feat: Partial support for xAI Grok with OOM errors #239
This PR adds support for the xAI Grok-2 model. However, executing the model currently still fails with an OOM error, which appears to be caused by the model not being correctly sharded across multiple hosts/TPUs.

Because of the OOM issue, we currently set the number of hidden layers to 1 manually and use dummy weights for fast development:
`sglang-jax/python/sgl_jax/srt/models/grok.py`, lines 589 to 591 at `7edb5d6`
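As context for the suspected sharding issue above, here is a minimal JAX sketch (not the sgl-jax code) contrasting a tensor-parallel-sharded weight with a replicated one. If the model's weights end up fully replicated, every device holds a complete copy and per-device memory stays near the full model size no matter the `tp_size`, which matches the symptom described here:

```python
# Minimal sketch (not sgl-jax code): sharded vs. replicated weights in JAX.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("tensor",))
w = jnp.zeros((8192, 8192), dtype=jnp.bfloat16)

# Tensor-parallel sharding: each device holds a 1/tp_size column slice.
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "tensor")))

# Replicated (the suspected failure mode): every device holds the full matrix.
w_replicated = jax.device_put(w, NamedSharding(mesh, P()))

# The per-device footprint of the two differs by a factor of the mesh size.
print(w_sharded.sharding, w_replicated.sharding)
```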
The following commands can be used to execute the model directly on `tpu-v6e-32`:

```bash
pip3 install -e python
cd python/sgl_jax
python3 bench_one_batch.py \
    --model-path xai-org/grok-2 \
    --tokenizer-path Xenova/grok-1-tokenizer \
    --correct \
    --tp-size 32 \
    --mem-fraction-static 0.4 \
    --download-dir /mnt \
    --load-format dummy
```

Example outputs:
Clearly, each device has `12892451584 / 1024 / 1024 / 1024 ≈ 12 GB` of memory in use, which is about the same as the model's total memory size, suggesting each device holds a full copy of the weights.
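To double-check these per-device numbers independently of the benchmark output, per-device memory can also be inspected directly (a minimal sketch; on TPU backends, `jax.Device.memory_stats()` reports counters such as `bytes_in_use`):

```python
import jax

# Print HBM usage for every local device; on TPU, memory_stats()
# exposes counters such as "bytes_in_use" and "peak_bytes_in_use".
for device in jax.local_devices():
    stats = device.memory_stats()
    used_gib = stats["bytes_in_use"] / 1024**3
    print(f"{device}: {used_gib:.2f} GiB in use")
```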
To reproduce this behavior on a smaller TPU machine like `tpu-v6e-4`, the following commands can be used instead (essentially changing `--tp-size 32` to `--tp-size 4`):

```bash
pip3 install -e python
cd python/sgl_jax
python3 bench_one_batch.py \
    --model-path xai-org/grok-2 \
    --tokenizer-path Xenova/grok-1-tokenizer \
    --correct \
    --tp-size 4 \
    --mem-fraction-static 0.4 \
    --download-dir /mnt \
    --load-format dummy
```
You can also compare this with a single-TPU setup by setting `--tp-size 1`. In that configuration, the outputs show TPU0 using `14874966656 / 1024 / 1024 / 1024 ≈ 13.8 GB` while all of the other TPUs are almost untouched.
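For reference, here is a rough back-of-the-envelope of what per-device usage should look like if the weights were sharded correctly across the tensor-parallel axis (an estimate only; it ignores activations, the KV cache, and runtime overhead):

```python
# Rough expectation under ideal tensor-parallel sharding: per-device
# weight memory scales as ~1/tp_size of the single-device footprint.
single_device_bytes = 14_874_966_656  # measured with --tp-size 1
for tp in (1, 4, 32):
    gib = single_device_bytes / tp / 1024**3
    print(f"tp_size={tp}: ~{gib:.2f} GiB of weights per device")
```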
It is expected that with `tp_size > 1` we see much lower per-device memory usage than with `tp_size = 1`; the ~12 GB per device observed at `tp_size = 32` instead suggests the weights are being replicated rather than sharded.

@Prayer3th