diff --git a/.meta/mast/qwen3_1_7b_mast.yaml b/.meta/mast/qwen3_1_7b_mast.yaml
index 2cec3d3e1..27f434def 100644
--- a/.meta/mast/qwen3_1_7b_mast.yaml
+++ b/.meta/mast/qwen3_1_7b_mast.yaml
@@ -9,6 +9,7 @@ max_res_tokens: 512
 model: "/mnt/wsfuse/teamforge/hf/qwen3_1.7b"
 off_by_n: 1 # Off by one by default
 launcher: mast
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Main loop configuration
 rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -37,7 +38,7 @@ policy:
   model: /mnt/wsfuse/teamforge/hf/qwen3_1.7b
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   # TODO: Had to disable this becasue vLLm wouldn't like
   # needs to revisited.
   disable_custom_all_reduce: true
@@ -68,7 +69,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
@@ -112,7 +113,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
diff --git a/.meta/mast/qwen3_32b_mast.yaml b/.meta/mast/qwen3_32b_mast.yaml
index 6346d45d4..9a41b9f9f 100644
--- a/.meta/mast/qwen3_32b_mast.yaml
+++ b/.meta/mast/qwen3_32b_mast.yaml
@@ -9,6 +9,7 @@ max_res_tokens: 512
 model: "/mnt/wsfuse/teamforge/hf/qwen3_32b"
 off_by_n: 1 # Off by one by default
 launcher: mast
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Main loop configuration
 rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -37,7 +38,7 @@ policy:
   model: /mnt/wsfuse/teamforge/hf/qwen3_32b
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   # TODO: Had to disable this becasue vLLm wouldn't like
   # needs to revisited.
   disable_custom_all_reduce: true
@@ -67,7 +68,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 8
@@ -110,7 +111,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
diff --git a/.meta/mast/qwen3_4b_mast.yaml b/.meta/mast/qwen3_4b_mast.yaml
index 4ae673244..88e6dbfc9 100644
--- a/.meta/mast/qwen3_4b_mast.yaml
+++ b/.meta/mast/qwen3_4b_mast.yaml
@@ -9,6 +9,7 @@ max_res_tokens: 512
 model: "Qwen/Qwen3-4B"
 off_by_n: 1 # Off by one by default
 launcher: mast
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Main loop configuration
 rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -37,7 +38,7 @@ policy:
   model: /mnt/wsfuse/teamforge/hf/qwen3_4b
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   # TODO: Had to disable this becasue vLLm wouldn't like
   # needs to revisited.
   disable_custom_all_reduce: true
@@ -68,7 +69,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 4
@@ -112,7 +113,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
diff --git a/apps/grpo/llama3_8b.yaml b/apps/grpo/llama3_8b.yaml
index 0597bb238..6a887ebc3 100644
--- a/apps/grpo/llama3_8b.yaml
+++ b/apps/grpo/llama3_8b.yaml
@@ -8,6 +8,7 @@ max_req_tokens: 1024
 max_res_tokens: 2048
 model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Observability configuration
 metric_logging:
@@ -32,7 +33,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -59,7 +60,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: -1
@@ -100,7 +101,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index ebdd27787..b0c08a28c 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -8,6 +8,7 @@ max_req_tokens: 1024
 max_res_tokens: 2048
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Main loop configuration
 rollout_threads: 1 # Recommended to set equal to policy.num_replicas
@@ -36,7 +37,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -63,7 +64,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
@@ -101,7 +102,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml
index 0366b58a3..ee8ec96fb 100644
--- a/apps/grpo/qwen3_32b.yaml
+++ b/apps/grpo/qwen3_32b.yaml
@@ -9,6 +9,7 @@ max_req_tokens: 1024
 max_res_tokens: 1024
 model: "Qwen/Qwen3-32B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 provisioner:
   launcher: slurm
@@ -39,7 +40,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 4
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -66,7 +67,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
@@ -104,7 +105,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml
index 6736b0b01..9bec6a541 100644
--- a/apps/grpo/qwen3_8b.yaml
+++ b/apps/grpo/qwen3_8b.yaml
@@ -8,6 +8,7 @@ max_req_tokens: 1024
 max_res_tokens: 2048
 model: "Qwen/Qwen3-8B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM

 # Observability configuration
 metric_logging:
@@ -32,7 +33,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -59,7 +60,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: -1
@@ -100,7 +101,7 @@ ref_model:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
diff --git a/src/forge/util/config.py b/src/forge/util/config.py
index 0315ca525..c93c4c575 100644
--- a/src/forge/util/config.py
+++ b/src/forge/util/config.py
@@ -18,6 +18,9 @@
 # Add support for summing lists of numbers, e.g. ${sum:${max_req_tokens},${max_res_tokens}}
 OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True)

+# Add support for boolean negation, e.g. ${not:${compile}}
+OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)
+

 def _has_component(node: Any) -> bool:
     """Check if a node has a _component_ field."""
diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
index 5e0bcf17f..80e408f03 100644
--- a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
+++ b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
@@ -5,6 +5,7 @@ max_req_tokens: 512
 max_res_tokens: 512
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM


 # Policy configuration
@@ -13,7 +14,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params:
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -40,7 +41,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
index 1f05daab4..d4964baad 100644
--- a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
+++ b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
@@ -7,6 +7,7 @@ max_req_tokens: 512
 max_res_tokens: 512
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM


 # Policy configuration
@@ -15,7 +16,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 4
   pipeline_parallel_size: 1
-  enforce_eager: false
+  enforce_eager: ${not:${compile}}
   sampling_params:
     n: ${group_size}
     max_tokens: ${max_res_tokens}
@@ -42,7 +43,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1
diff --git a/tests/sandbox/weight_sync/qwen3_1_7b.yaml b/tests/sandbox/weight_sync/qwen3_1_7b.yaml
index ea28a1471..e18589eaa 100644
--- a/tests/sandbox/weight_sync/qwen3_1_7b.yaml
+++ b/tests/sandbox/weight_sync/qwen3_1_7b.yaml
@@ -5,6 +5,7 @@ model: "Qwen/Qwen3-1.7B"
 local_batch_size: 4
 max_req_tokens: 64
 max_res_tokens: 64
+compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM

 metric_logging:
   console:
@@ -16,7 +17,7 @@ policy:
   model: ${model}
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  enforce_eager: true
+  enforce_eager: ${not:${compile}}
   sampling_params:
     n: 1
     max_tokens: 32 # Just for verification forward pass
@@ -42,7 +43,7 @@ trainer:
   dtype: bfloat16
   gc_freq: 1
   compile:
-    enable: false
+    enable: ${compile}
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: 1 # Single GPU, no FSDP
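
Not part of the diff: a minimal sketch of how the new `${not:${compile}}` pattern is expected to resolve once the "not" resolver from src/forge/util/config.py is registered, assuming standard OmegaConf (>= 2.1) resolver semantics with uncached resolvers. The config keys below mirror the YAML pattern in the files above and are illustrative only.

# Standalone sketch (assumption: OmegaConf >= 2.1, resolvers uncached by default).
from omegaconf import OmegaConf

# Same registration as added in src/forge/util/config.py.
OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)

cfg = OmegaConf.create(
    {
        "compile": True,
        "policy": {"enforce_eager": "${not:${compile}}"},
        "trainer": {"compile": {"enable": "${compile}"}},
    }
)

# With compile=true, vLLM keeps CUDA graphs (eager off) and the trainer compiles.
assert cfg.policy.enforce_eager is False
assert cfg.trainer.compile.enable is True

# Flipping the single top-level flag toggles both consumers at once.
cfg.compile = False
assert cfg.policy.enforce_eager is True
assert cfg.trainer.compile.enable is False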