diff --git a/crates/cuda_builder/src/lib.rs b/crates/cuda_builder/src/lib.rs index 6441b827..81908cdf 100644 --- a/crates/cuda_builder/src/lib.rs +++ b/crates/cuda_builder/src/lib.rs @@ -194,6 +194,10 @@ pub struct CudaBuilder { /// An optional path where to dump LLVM IR of the final output the codegen will feed to libnvvm. Usually /// used for debugging. pub final_module_path: Option, + /// The threshold for LLVM's loop unrolling optimization pass. Higher values allow more + /// aggressive unrolling, which can improve performance but increases code size. + /// When `None`, LLVM uses its default threshold. + pub unroll_threshold: Option, } impl CudaBuilder { @@ -216,6 +220,7 @@ impl CudaBuilder { debug: DebugInfo::None, build_args: vec![], final_module_path: None, + unroll_threshold: None, } } @@ -351,6 +356,13 @@ impl CudaBuilder { self } + /// Sets the threshold for LLVM's loop unrolling optimization pass. Higher values allow more + /// aggressive unrolling, which can improve performance but increases code size. + pub fn unroll_threshold(mut self, threshold: u32) -> Self { + self.unroll_threshold = Some(threshold); + self + } + /// Runs rustc to build the codegen and codegens the gpu crate, returning the path of the final /// ptx file. If [`ptx_file_copy_path`](Self::ptx_file_copy_path) is set, this returns the copied path. pub fn build(self) -> Result { @@ -748,6 +760,10 @@ fn invoke_rustc(builder: &CudaBuilder) -> Result { llvm_args.push(path.to_str().unwrap().to_string()); } + if let Some(threshold) = builder.unroll_threshold { + llvm_args.push(format!("-unroll-threshold={threshold}")); + } + if builder.debug != DebugInfo::None { let (nvvm_flag, rustc_flag) = builder.debug.into_nvvm_and_rustc_options(); llvm_args.push(nvvm_flag); diff --git a/crates/rustc_codegen_nvvm/src/context.rs b/crates/rustc_codegen_nvvm/src/context.rs index 1a5582df..c4d89dc6 100644 --- a/crates/rustc_codegen_nvvm/src/context.rs +++ b/crates/rustc_codegen_nvvm/src/context.rs @@ -654,6 +654,7 @@ pub struct CodegenArgs { pub use_constant_memory_space: bool, pub final_module_path: Option, pub disassemble: Option, + pub unroll_threshold: Option, } impl CodegenArgs { @@ -712,6 +713,11 @@ impl CodegenArgs { skip_next = true; } else if let Some(entry) = arg.strip_prefix("--disassemble-entry=") { cg_args.disassemble = Some(DisassembleMode::Entry(entry.to_string())); + } else if let Some(threshold) = arg.strip_prefix("-unroll-threshold=") { + cg_args.unroll_threshold = Some(threshold.parse().unwrap_or_else(|_| { + sess.dcx() + .fatal("-unroll-threshold requires a valid integer value") + })); } else { // Do this only after all the other flags above have been tried. match NvvmOption::from_str(arg) { diff --git a/crates/rustc_codegen_nvvm/src/init.rs b/crates/rustc_codegen_nvvm/src/init.rs index 3b13b8a3..f6782b43 100644 --- a/crates/rustc_codegen_nvvm/src/init.rs +++ b/crates/rustc_codegen_nvvm/src/init.rs @@ -107,9 +107,12 @@ unsafe fn configure_llvm(sess: &Session) { // Use non-zero `import-instr-limit` multiplier for cold callsites. add("-import-cold-multiplier=0.1", false); - // for arg in sess_args { - // add(&(*arg), true); - // } + // Forward unroll-threshold if specified in llvm_args + for arg in &sess.opts.cg.llvm_args { + if arg.starts_with("-unroll-threshold=") { + add(arg, true); + } + } } unsafe { llvm::LLVMInitializePasses() };