11 files changed: +40 −27 lines (integration_tests/fixtures)

@@ -9,6 +9,7 @@ max_res_tokens: 512
 model: "/mnt/wsfuse/teamforge/hf/qwen3_1.7b"
 off_by_n: 1 # Off by one by default
 launcher: mast
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Main loop configuration
 rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -37,7 +38,7 @@ policy:
   model: /mnt/wsfuse/teamforge/hf/qwen3_1.7b
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   # TODO: Had to disable this because vLLM wouldn't like it;
   # needs to be revisited.
   disable_custom_all_reduce: true
@@ -68,7 +69,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
@@ -112,7 +113,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
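The pattern is identical in every fixture: one top-level compile flag drives the trainer/ref_model compile.enable directly and, negated, vLLM's enforce_eager. For reference, a minimal sketch of how the interpolations resolve — assuming OmegaConf with the "not" resolver registered as in the Python change further down, and with the config trimmed to just the relevant keys:

from omegaconf import OmegaConf

# Mirrors the resolver registered in this PR's Python change.
OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)

# Trimmed-down stand-in for one of the fixture configs above.
cfg = OmegaConf.create(
    """
    compile: true
    policy:
      enforce_eager: ${not:${compile}}
    trainer:
      compile:
        enable: ${compile}
    """
)

print(cfg.policy.enforce_eager)    # False -> vLLM keeps CUDA graphs enabled
print(cfg.trainer.compile.enable)  # True  -> torch.compile on for the trainer
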

@@ -9,6 +9,7 @@ max_res_tokens: 512
 model: "/mnt/wsfuse/teamforge/hf/qwen3_32b"
 off_by_n: 1 # Off by one by default
 launcher: mast
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Main loop configuration
 rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -37,7 +38,7 @@ policy:
   model: /mnt/wsfuse/teamforge/hf/qwen3_32b
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   # TODO: Had to disable this because vLLM wouldn't like it;
   # needs to be revisited.
   disable_custom_all_reduce: true
@@ -67,7 +68,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 8
@@ -110,7 +111,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

@@ -9,6 +9,7 @@ max_res_tokens: 512
 model: "Qwen/Qwen3-4B"
 off_by_n: 1 # Off by one by default
 launcher: mast
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Main loop configuration
 rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -37,7 +38,7 @@ policy:
   model: /mnt/wsfuse/teamforge/hf/qwen3_4b
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   # TODO: Had to disable this because vLLM wouldn't like it;
   # needs to be revisited.
   disable_custom_all_reduce: true
@@ -68,7 +69,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 4
@@ -112,7 +113,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

@@ -8,6 +8,7 @@ max_req_tokens: 1024
 max_res_tokens: 2048
 model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Observability configuration
 metric_logging:
@@ -32,7 +33,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -59,7 +60,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: -1
@@ -100,7 +101,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

@@ -8,6 +8,7 @@ max_req_tokens: 1024
 max_res_tokens: 2048
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Main loop configuration
 rollout_threads: 1 # Recommended to set equal to policy.num_replicas
@@ -36,7 +37,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -63,7 +64,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
@@ -101,7 +102,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

@@ -9,6 +9,7 @@ max_req_tokens: 1024
 max_res_tokens: 1024
 model: "Qwen/Qwen3-32B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 provisioner:
   launcher: slurm
@@ -39,7 +40,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 4
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -66,7 +67,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
@@ -104,7 +105,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

@@ -8,6 +8,7 @@ max_req_tokens: 1024
 max_res_tokens: 2048
 model: "Qwen/Qwen3-8B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Observability configuration
 metric_logging:
@@ -32,7 +33,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -59,7 +60,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: -1
@@ -100,7 +101,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

@@ -18,6 +18,9 @@
 # Add support for summing lists of numbers, e.g. ${sum:${max_req_tokens},${max_res_tokens}}
 OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True)

+# Add support for boolean negation, e.g. ${not:${compile}}
+OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)
+

 def _has_component(node: Any) -> bool:
     """Check if a node has a _component_ field."""
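Worth noting: OmegaConf resolvers are evaluated lazily at access time, and register_new_resolver does not cache results by default, so a derived field tracks later changes to compile. A quick sanity-check sketch of the new resolver in isolation, under those assumptions:

from omegaconf import OmegaConf

OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)

cfg = OmegaConf.create({"compile": True, "enforce_eager": "${not:${compile}}"})
assert cfg.enforce_eager is False  # resolved on access

cfg.compile = False                # e.g. flipped by an override
assert cfg.enforce_eager is True   # derived field follows automatically
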

@@ -5,6 +5,7 @@ max_req_tokens: 512
 max_res_tokens: 512
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM


 # Policy configuration
@@ -13,7 +14,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params:
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -40,7 +41,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

@@ -7,6 +7,7 @@ max_req_tokens: 512
 max_res_tokens: 512
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM


 # Policy configuration
@@ -15,7 +16,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 4
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params:
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -42,7 +43,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1