Commit 7b8580a

easy - add compile flag to configs (#634)

Authored by: Felipe Mello (felipemello1)
Co-authored-by: Felipe Mello <[email protected]>
Parent: 9bce287

The change introduces a single top-level compile flag in each config. trainer.compile.enable (and ref_model.compile.enable where present) now follow the flag directly, and the policy's enforce_eager option becomes its negation, ${not:${compile}}, via a new OmegaConf resolver registered in src/forge/util/config.py.

11 files changed: +40 / -27 lines

.meta/mast/qwen3_1_7b_mast.yaml (4 additions, 3 deletions)

@@ -9,6 +9,7 @@ max_res_tokens: 512
 model: "/mnt/wsfuse/teamforge/hf/qwen3_1.7b"
 off_by_n: 1 # Off by one by default
 launcher: mast
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
 
 # Main loop configuration
 rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -37,7 +38,7 @@ policy:
   model: /mnt/wsfuse/teamforge/hf/qwen3_1.7b
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   # TODO: Had to disable this becasue vLLm wouldn't like
   # needs to revisited.
   disable_custom_all_reduce: true
@@ -68,7 +69,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
@@ -112,7 +113,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
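
Every YAML file in this commit gets the same three-part change: one top-level compile flag, trainer/ref_model compile.enable interpolating it directly, and vLLM's enforce_eager derived as its negation. As a minimal sketch of how the interpolations resolve (assuming OmegaConf; the "not" resolver mirrors the one registered in src/forge/util/config.py later in this diff, and the config keys are trimmed stand-ins, not a full config):

from omegaconf import OmegaConf

# Mirrors the resolver added in src/forge/util/config.py below.
OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)

# Trimmed stand-in for the pattern this commit applies to each config.
cfg = OmegaConf.create(
    {
        "compile": True,
        "policy": {"enforce_eager": "${not:${compile}}"},
        "trainer": {"compile": {"enable": "${compile}"}},
    }
)

assert cfg.policy.enforce_eager is False  # compile on -> eager off, CUDA graphs on
assert cfg.trainer.compile.enable is True  # torch.compile enabled for the trainer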

.meta/mast/qwen3_32b_mast.yaml (4 additions, 3 deletions)

@@ -9,6 +9,7 @@ max_res_tokens: 512
 model: "/mnt/wsfuse/teamforge/hf/qwen3_32b"
 off_by_n: 1 # Off by one by default
 launcher: mast
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
 
 # Main loop configuration
 rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -37,7 +38,7 @@ policy:
   model: /mnt/wsfuse/teamforge/hf/qwen3_32b
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   # TODO: Had to disable this becasue vLLm wouldn't like
   # needs to revisited.
   disable_custom_all_reduce: true
@@ -67,7 +68,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 8
@@ -110,7 +111,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

.meta/mast/qwen3_4b_mast.yaml (4 additions, 3 deletions)

@@ -9,6 +9,7 @@ max_res_tokens: 512
 model: "Qwen/Qwen3-4B"
 off_by_n: 1 # Off by one by default
 launcher: mast
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
 
 # Main loop configuration
 rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -37,7 +38,7 @@ policy:
   model: /mnt/wsfuse/teamforge/hf/qwen3_4b
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   # TODO: Had to disable this becasue vLLm wouldn't like
   # needs to revisited.
   disable_custom_all_reduce: true
@@ -68,7 +69,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 4
@@ -112,7 +113,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

apps/grpo/llama3_8b.yaml (4 additions, 3 deletions)

@@ -8,6 +8,7 @@ max_req_tokens: 1024
 max_res_tokens: 2048
 model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
 
 # Observability configuration
 metric_logging:
@@ -32,7 +33,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -59,7 +60,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: -1
@@ -100,7 +101,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

apps/grpo/qwen3_1_7b.yaml (4 additions, 3 deletions)

@@ -8,6 +8,7 @@ max_req_tokens: 1024
 max_res_tokens: 2048
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
 
 # Main loop configuration
 rollout_threads: 1 # Recommended to set equal to policy.num_replicas
@@ -36,7 +37,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -63,7 +64,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
@@ -101,7 +102,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

apps/grpo/qwen3_32b.yaml (4 additions, 3 deletions)

@@ -9,6 +9,7 @@ max_req_tokens: 1024
 max_res_tokens: 1024
 model: "Qwen/Qwen3-32B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
 
 provisioner:
   launcher: slurm
@@ -39,7 +40,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 4
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -66,7 +67,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
@@ -104,7 +105,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

apps/grpo/qwen3_8b.yaml (4 additions, 3 deletions)

@@ -8,6 +8,7 @@ max_req_tokens: 1024
 max_res_tokens: 2048
 model: "Qwen/Qwen3-8B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
 
 # Observability configuration
 metric_logging:
@@ -32,7 +33,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -59,7 +60,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: -1
@@ -100,7 +101,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

src/forge/util/config.py (3 additions, 0 deletions)

@@ -18,6 +18,9 @@
 # Add support for summing lists of numbers, e.g. ${sum:${max_req_tokens},${max_res_tokens}}
 OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True)
 
+# Add support for boolean negation, e.g. ${not:${compile}}
+OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)
+
 
 def _has_component(node: Any) -> bool:
     """Check if a node has a _component_ field."""

tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml (3 additions, 2 deletions)

@@ -5,6 +5,7 @@ max_req_tokens: 512
 max_res_tokens: 512
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM
 
 
 # Policy configuration
@@ -13,7 +14,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params:
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -40,7 +41,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1

tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml (3 additions, 2 deletions)

@@ -7,6 +7,7 @@ max_req_tokens: 512
 max_res_tokens: 512
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM
 
 
 # Policy configuration
@@ -15,7 +16,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 4
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params:
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -42,7 +43,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
