
Conversation

@Kipsora Kipsora commented Oct 14, 2025

This PR adds support for the xAI Grok-2 model. However, execution of the Grok model currently fails with an OOM error, which appears to be caused by the model not being correctly sharded across multiple hosts/TPUs.

Because of the OOM error, we currently set the number of hidden layers to 1 and load dummy weights for fast development:

# TODO (chhzh123): remove this
config.num_hidden_layers = 1
print(f"config: {config}")

The following commands can be used to run the model directly on tpu-v6e-32:

pip3 install -e python
cd python/sgl_jax
python3 bench_one_batch.py \
  --model-path xai-org/grok-2 \
  --tokenizer-path Xenova/grok-1-tokenizer \
  --correct \
  --tp-size 32 \
  --mem-fraction-static 0.4 \
  --download-dir /mnt \
  --load-format dummy

Example Outputs:

...
TPU_24(process=6,(0,6,0,0)) {'num_allocs': 71, 'bytes_in_use': 12892489984, 'peak_bytes_in_use': 14976714112, 'largest_alloc_size': 2147483648, 'bytes_limit': 33550221312, 'bytes_reserved': 67371008, 'peak_bytes_reserved': 67371008, 'bytes_reservable_limit': 29253468032, 'largest_free_block_bytes': 16301195136}
TPU_25(process=6,(1,6,0,0)) {'num_allocs': 67, 'bytes_in_use': 12892451584, 'peak_bytes_in_use': 14976678272, 'largest_alloc_size': 2147483648, 'bytes_limit': 33550221312, 'bytes_reserved': 67371008, 'peak_bytes_reserved': 67371008, 'bytes_reservable_limit': 29253542016, 'largest_free_block_bytes': 16301269120}
TPU_28(process=6,(0,7,0,0)) {'num_allocs': 67, 'bytes_in_use': 12892451584, 'peak_bytes_in_use': 14976678272, 'largest_alloc_size': 2147483648, 'bytes_limit': 33550221312, 'bytes_reserved': 67371008, 'peak_bytes_reserved': 67371008, 'bytes_reservable_limit': 29253542016, 'largest_free_block_bytes': 16301269120}
TPU_29(process=6,(1,7,0,0)) {'num_allocs': 67, 'bytes_in_use': 12892451584, 'peak_bytes_in_use': 14976678272, 'largest_alloc_size': 2147483648, 'bytes_limit': 33550221312, 'bytes_reserved': 67371008, 'peak_bytes_reserved': 67371008, 'bytes_reservable_limit': 29253542016, 'largest_free_block_bytes': 16301269120}
...

Clearly, each device has 12892451584 / 1024 / 1024 / 1024 ≈ 12 GiB in use, which is about the same as the model's total memory size. If the weights were properly sharded across 32 devices, each device should only hold a small fraction of them; see the sketch below for one way to check this.
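
A minimal sketch of how the sharding of the loaded parameters could be inspected directly in JAX. Here `params` is a placeholder for whatever pytree of weights the model runner ends up holding (an assumption about sgl-jax internals); `tree_leaves_with_path`, `keystr`, and the `.sharding` attribute are standard JAX APIs:

import jax

def report_sharding(params, max_entries=8):
    # `params`: placeholder name for the loaded weight pytree (assumption).
    for path, leaf in jax.tree_util.tree_leaves_with_path(params)[:max_entries]:
        # A fully replicated array shows the same full-shape shard on every device;
        # a correctly TP-sharded weight shows per-device slices smaller than leaf.shape.
        print(jax.tree_util.keystr(path), leaf.shape, leaf.sharding)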

To reproduce this behavior on a smaller TPU machine like tpu-v6e-4, use the following commands instead (essentially changing --tp-size 32 to --tp-size 4):

pip3 install -e python
cd python/sgl_jax
python3 bench_one_batch.py \
  --model-path xai-org/grok-2 \
  --tokenizer-path Xenova/grok-1-tokenizer \
  --correct \
  --tp-size 4 \
  --mem-fraction-static 0.4 \
  --download-dir /mnt \
  --load-format dummy

You can also compare this with a single-TPU setup by setting --tp-size 1, which produces:

TPU_0(process=0,(0,0,0,0)) {'num_allocs': 65, 'bytes_in_use': 14874966656, 'peak_bytes_in_use': 14874967168, 'largest_alloc_size': 2147483648, 'bytes_limit': 33550235648, 'bytes_reserved': 67371008, 'peak_bytes_reserved': 67371008, 'bytes_reservable_limit': 29253612416, 'largest_free_block_bytes': 16301339520}
TPU_1(process=0,(1,0,0,0)) {'num_allocs': 2, 'bytes_in_use': 32384, 'peak_bytes_in_use': 32384, 'largest_alloc_size': 30720, 'bytes_limit': 33550235648, 'bytes_reserved': 0, 'peak_bytes_reserved': 0, 'bytes_reservable_limit': 33550235648, 'largest_free_block_bytes': 33550203264}
TPU_2(process=0,(0,1,0,0)) {'num_allocs': 2, 'bytes_in_use': 32384, 'peak_bytes_in_use': 32384, 'largest_alloc_size': 30720, 'bytes_limit': 33550235648, 'bytes_reserved': 0, 'peak_bytes_reserved': 0, 'bytes_reservable_limit': 33550235648, 'largest_free_block_bytes': 33550203264}
TPU_3(process=0,(1,1,0,0)) {'num_allocs': 2, 'bytes_in_use': 32384, 'peak_bytes_in_use': 32384, 'largest_alloc_size': 30720, 'bytes_limit': 33550235648, 'bytes_reserved': 0, 'peak_bytes_reserved': 0, 'bytes_reservable_limit': 33550235648, 'largest_free_block_bytes': 33550203264}

Here TPU_0 uses 14874966656 / 1024 / 1024 / 1024 ≈ 13.9 GiB while all the other TPUs are almost untouched. With tp_size > 1 we would expect much lower per-device memory usage than with tp_size = 1.
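
For reference, per-device numbers in the format shown above can be collected with JAX's public device API; bench_one_batch.py presumably reports something similar, but that is an assumption:

import jax

for device in jax.local_devices():
    stats = device.memory_stats()  # dict with 'bytes_in_use', 'peak_bytes_in_use', ...
    print(device, f"{stats['bytes_in_use'] / 1024 ** 3:.2f} GiB in use")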

@Prayer3th

@Kipsora Kipsora requested a review from Prayer3th October 14, 2025 02:37

@Kipsora Kipsora changed the title Partial support for xAI Grok with OOM errors Feat: Partial support for xAI Grok with OOM errors Oct 14, 2025
@Prayer3th (Collaborator) commented:

I'm wondering if this TPU memory usage figure includes the KV cache size, since it might be affected by the mem-fraction-static parameter.
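
For what it's worth, a back-of-envelope check (assuming the static pool is sized as mem-fraction-static times the per-device bytes_limit, which is an assumption about how sgl-jax interprets the flag) lands in the same ballpark as the ~12 GiB observed per device:

bytes_limit = 33550221312      # per-device limit from the logs above
mem_fraction_static = 0.4      # value passed on the command line
print(f"{bytes_limit * mem_fraction_static / 1024 ** 3:.1f} GiB")  # ~12.5 GiB vs ~12.0 GiB in use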
