diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index aff4152435..7acca2a39b 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -9,7 +9,7 @@ use crate::builder_spirv::{SpirvFunctionCursor, SpirvValue, SpirvValueExt}; use crate::spirv_type::SpirvType; use rspirv::dr::Operand; use rspirv::spirv::{ - Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word, + BuiltIn, Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word, }; use rustc_abi::FieldsShape; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods, MiscCodegenMethods as _}; @@ -916,6 +916,11 @@ impl<'tcx> CodegenCx<'tcx> { ); } + // Check builtin-specific type requirements. + if let Some(builtin) = attrs.builtin { + self.check_builtin_type(hir_param.ty_span, value_layout.ty, builtin); + } + if let Ok(storage_class) = storage_class { self.check_for_bad_types( execution_model, @@ -1083,4 +1088,15 @@ impl<'tcx> CodegenCx<'tcx> { } } } + + /// Check that builtin variables have the correct type. + fn check_builtin_type(&self, span: Span, rust_ty: Ty<'tcx>, builtin: Spanned) { + // LocalInvocationIndex must be a u32. + if builtin.value == BuiltIn::LocalInvocationIndex && rust_ty != self.tcx.types.u32 { + self.tcx.dcx().span_err( + span, + format!("`#[spirv(local_invocation_index)]` must be a `u32`, not `{rust_ty}`"), + ); + } + } } diff --git a/tests/compiletests/ui/arch/shared/dce_shared.rs b/tests/compiletests/ui/arch/shared/dce_shared.rs index 2c639ae5fc..f06aa6bbcc 100644 --- a/tests/compiletests/ui/arch/shared/dce_shared.rs +++ b/tests/compiletests/ui/arch/shared/dce_shared.rs @@ -21,10 +21,10 @@ pub fn main( #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut f32, #[spirv(workgroup)] used_shared: &mut f32, #[spirv(workgroup)] dce_shared: &mut [i32; 2], - #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(local_invocation_index)] inv_id: u32, ) { unsafe { - let inv_id = inv_id.x as usize; + let inv_id = inv_id as usize; if inv_id == 0 { *used_shared = *input; } diff --git a/tests/compiletests/ui/arch/shared/dce_shared.stderr b/tests/compiletests/ui/arch/shared/dce_shared.stderr index 532063d2ac..681d038026 100644 --- a/tests/compiletests/ui/arch/shared/dce_shared.stderr +++ b/tests/compiletests/ui/arch/shared/dce_shared.stderr @@ -20,16 +20,15 @@ OpDecorate %4 BuiltIn LocalInvocationIndex %11 = OpTypePointer Workgroup %9 %12 = OpTypeInt 32 0 %13 = OpConstant %12 2 -%14 = OpTypeVector %12 3 -%15 = OpTypePointer Input %14 -%16 = OpTypeVoid -%17 = OpTypeFunction %16 -%18 = OpTypePointer StorageBuffer %9 +%14 = OpTypePointer Input %12 +%15 = OpTypeVoid +%16 = OpTypeFunction %15 +%17 = OpTypePointer StorageBuffer %9 %2 = OpVariable %10 StorageBuffer -%19 = OpConstant %12 0 +%18 = OpConstant %12 0 %3 = OpVariable %10 StorageBuffer -%4 = OpVariable %15 Input -%20 = OpTypeBool +%4 = OpVariable %14 Input +%19 = OpTypeBool %5 = OpVariable %11 Workgroup -%21 = OpConstant %12 264 -%22 = OpConstant %12 1 +%20 = OpConstant %12 264 +%21 = OpConstant %12 1 diff --git a/tests/compiletests/ui/arch/shared/reduction_array.rs b/tests/compiletests/ui/arch/shared/reduction_array.rs index 5ccd4136c5..a9a4238abb 100644 --- a/tests/compiletests/ui/arch/shared/reduction_array.rs +++ b/tests/compiletests/ui/arch/shared/reduction_array.rs @@ -57,10 +57,10 @@ pub fn main( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value], #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value, #[spirv(workgroup)] shared: &mut [Value; WG_SIZE], - #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(local_invocation_index)] inv_id: u32, ) { unsafe { - let inv_id = inv_id.x as usize; + let inv_id = inv_id as usize; shared[inv_id] = input[inv_id]; workgroup_memory_barrier_with_group_sync(); diff --git a/tests/compiletests/ui/arch/shared/reduction_big_struct.rs b/tests/compiletests/ui/arch/shared/reduction_big_struct.rs index 809508ca68..5e651cfb9d 100644 --- a/tests/compiletests/ui/arch/shared/reduction_big_struct.rs +++ b/tests/compiletests/ui/arch/shared/reduction_big_struct.rs @@ -65,10 +65,10 @@ pub fn main( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value], #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value, #[spirv(workgroup)] shared: &mut [Value; WG_SIZE], - #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(local_invocation_index)] inv_id: u32, ) { unsafe { - let inv_id = inv_id.x as usize; + let inv_id = inv_id as usize; shared[inv_id] = input[inv_id]; workgroup_memory_barrier_with_group_sync(); diff --git a/tests/compiletests/ui/arch/shared/reduction_u32.rs b/tests/compiletests/ui/arch/shared/reduction_u32.rs index 0f60e495df..7e114bb093 100644 --- a/tests/compiletests/ui/arch/shared/reduction_u32.rs +++ b/tests/compiletests/ui/arch/shared/reduction_u32.rs @@ -22,10 +22,10 @@ pub fn main( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value], #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value, #[spirv(workgroup)] shared: &mut [Value; WG_SIZE], - #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(local_invocation_index)] inv_id: u32, ) { unsafe { - let inv_id = inv_id.x as usize; + let inv_id = inv_id as usize; shared[inv_id] = input[inv_id]; workgroup_memory_barrier_with_group_sync(); diff --git a/tests/compiletests/ui/arch/shared/reduction_vec.rs b/tests/compiletests/ui/arch/shared/reduction_vec.rs index 7fe18d1af3..10f851268d 100644 --- a/tests/compiletests/ui/arch/shared/reduction_vec.rs +++ b/tests/compiletests/ui/arch/shared/reduction_vec.rs @@ -22,10 +22,10 @@ pub fn main( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] input: &[Value], #[spirv(descriptor_set = 0, binding = 1, storage_buffer)] output: &mut Value, #[spirv(workgroup)] shared: &mut [Value; WG_SIZE], - #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(local_invocation_index)] inv_id: u32, ) { unsafe { - let inv_id = inv_id.x as usize; + let inv_id = inv_id as usize; shared[inv_id] = input[inv_id]; workgroup_memory_barrier_with_group_sync(); diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs index 5db5ba963e..9dd1ae0827 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs @@ -30,7 +30,7 @@ pub struct Zst; #[spirv(compute(threads(32)))] pub fn main( - #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(local_invocation_id)] inv_id: UVec3, #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut UVec3, ) { unsafe { diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs index 953a4429e8..27c84ffaaf 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs @@ -29,14 +29,15 @@ fn disassembly(my_struct: MyStruct) -> bool { #[spirv(compute(threads(32)))] pub fn main( - #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(local_invocation_index)] inv_id: u32, + #[spirv(local_invocation_id)] inv_id_3d: UVec3, #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut u32, ) { unsafe { let my_struct = MyStruct { - a: inv_id.x as f32, - b: inv_id, - c: Nested(5i32 - inv_id.x as i32), + a: inv_id as f32, + b: inv_id_3d, + c: Nested(5i32 - inv_id as i32), d: Zst, }; diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs index ada500fb06..113e0f40fa 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs @@ -43,11 +43,11 @@ fn disassembly(my_struct: MyEnum, id: u32) -> MyEnum { #[spirv(compute(threads(32)))] pub fn main( - #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(local_invocation_index)] inv_id: u32, #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut MyEnum, ) { unsafe { - let my_enum = MyEnum::from(inv_id.x % 3); + let my_enum = MyEnum::from(inv_id % 3); *output = disassembly(my_enum, 5); } } diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs index a40c26175a..6315697c4b 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs @@ -29,14 +29,15 @@ fn disassembly(my_struct: MyStruct, id: u32) -> MyStruct { #[spirv(compute(threads(32)))] pub fn main( - #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(local_invocation_index)] inv_id: u32, + #[spirv(local_invocation_id)] inv_id_3d: UVec3, #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut MyStruct, ) { unsafe { let my_struct = MyStruct { - a: inv_id.x as f32, - b: inv_id, - c: Nested(5i32 - inv_id.x as i32), + a: inv_id as f32, + b: inv_id_3d, + c: Nested(5i32 - inv_id as i32), d: Zst, }; diff --git a/tests/compiletests/ui/spirv-attr/all-builtins.rs b/tests/compiletests/ui/spirv-attr/all-builtins.rs index 396fd850f4..4f5475d996 100644 --- a/tests/compiletests/ui/spirv-attr/all-builtins.rs +++ b/tests/compiletests/ui/spirv-attr/all-builtins.rs @@ -44,7 +44,7 @@ pub fn vertex( #[spirv(frag_stencil_ref_ext)] frag_stencil_ref_ext: &mut u32, #[spirv(instance_index)] instance_index: u32, #[spirv(layer_per_view_nv)] layer_per_view_nv: u32, - #[spirv(local_invocation_index)] local_invocation_index: UVec3, + #[spirv(local_invocation_index)] local_invocation_index: u32, #[spirv(mesh_view_count_nv)] mesh_view_count_nv: u32, #[spirv(mesh_view_indices_nv)] mesh_view_indices_nv: u32, #[spirv(point_size)] point_size: &mut u32, diff --git a/tests/compiletests/ui/spirv-attr/local-invocation-index-type.rs b/tests/compiletests/ui/spirv-attr/local-invocation-index-type.rs new file mode 100644 index 0000000000..c07ff05bff --- /dev/null +++ b/tests/compiletests/ui/spirv-attr/local-invocation-index-type.rs @@ -0,0 +1,7 @@ +// build-fail + +use spirv_std::glam::UVec3; +use spirv_std::spirv; + +#[spirv(compute(threads(1)))] +pub fn main(#[spirv(local_invocation_index)] index: UVec3) {} diff --git a/tests/compiletests/ui/spirv-attr/local-invocation-index-type.stderr b/tests/compiletests/ui/spirv-attr/local-invocation-index-type.stderr new file mode 100644 index 0000000000..60469cd8fb --- /dev/null +++ b/tests/compiletests/ui/spirv-attr/local-invocation-index-type.stderr @@ -0,0 +1,8 @@ +error: `#[spirv(local_invocation_index)]` must be a `u32`, not `spirv_std::glam::UVec3` + --> $DIR/local-invocation-index-type.rs:7:53 + | +LL | pub fn main(#[spirv(local_invocation_index)] index: UVec3) {} + | ^^^^^ + +error: aborting due to 1 previous error +