From d02dfe0f71dfe4ca22aa209c9802e533cd56360d Mon Sep 17 00:00:00 2001 From: Karol Herbst Date: Tue, 28 May 2024 18:25:06 +0200 Subject: [PATCH] rusticl/kernel/launch: fix mapping usize types to GPU pointer sizes I incorrectly assumed the API side defines how those values are sized, but it's actually the GPU's pointer size. The API is simply reduced to 32 bit ranges in 32 bit mode, but has to still pass in 64 bit values to the GPU. Also use explicit types in a couple of places to prevent such mistakes in the future. Fixes: 204c287327f ("rusticl/kernel: properly handle grid and offsets being usize") Part-of: --- src/gallium/frontends/rusticl/core/kernel.rs | 69 ++++++++++++++------ 1 file changed, 48 insertions(+), 21 deletions(-) diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index 87b7fde283d..7636231510c 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -411,21 +411,18 @@ fn lower_and_optimize_nir( lib_clc: &NirShader, ) -> (Vec, Vec) { let address_bits_ptr_type; + let address_bits_base_type; let global_address_format; let shared_address_format; - let host_bits_base_type = if size_of::() == 8 { - glsl_base_type::GLSL_TYPE_UINT64 - } else { - glsl_base_type::GLSL_TYPE_UINT - }; - if dev.address_bits() == 64 { address_bits_ptr_type = unsafe { glsl_uint64_t_type() }; + address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64; global_address_format = nir_address_format::nir_address_format_64bit_global; shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit; } else { address_bits_ptr_type = unsafe { glsl_uint_type() }; + address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT; global_address_format = nir_address_format::nir_address_format_32bit_global; shared_address_format = nir_address_format::nir_address_format_32bit_offset; } @@ -553,7 +550,7 @@ fn lower_and_optimize_nir( lower_state.base_global_invoc_id_loc = args.len() + internal_args.len() - 1; nir.add_var( nir_variable_mode::nir_var_uniform, - unsafe { glsl_vector_type(host_bits_base_type, 3) }, + unsafe { glsl_vector_type(address_bits_base_type, 3) }, lower_state.base_global_invoc_id_loc, "base_global_invocation_id", ); @@ -568,7 +565,7 @@ fn lower_and_optimize_nir( lower_state.base_workgroup_id_loc = args.len() + internal_args.len() - 1; nir.add_var( nir_variable_mode::nir_var_uniform, - unsafe { glsl_vector_type(host_bits_base_type, 3) }, + unsafe { glsl_vector_type(address_bits_base_type, 3) }, lower_state.base_workgroup_id_loc, "base_workgroup_id", ); @@ -962,11 +959,14 @@ impl Kernel { let mut img_formats: Vec = Vec::new(); let mut img_orders: Vec = Vec::new(); - let host_null_v3 = &[0u8; 3 * size_of::()]; - let null_ptr = if q.device.address_bits() == 64 { - [0u8; 8].as_slice() + let null_ptr; + let null_ptr_v3; + if q.device.address_bits() == 64 { + null_ptr = [0u8; 8].as_slice(); + null_ptr_v3 = [0u8; 24].as_slice(); } else { - [0u8; 4].as_slice() + null_ptr = [0u8; 4].as_slice(); + null_ptr_v3 = [0u8; 12].as_slice(); }; self.optimize_local_size(q.device, &mut grid, &mut block); @@ -989,9 +989,11 @@ impl Kernel { KernelArgValue::Buffer(buffer) => { let res = buffer.get_res_of_dev(q.device)?; if q.device.address_bits() == 64 { - input.extend_from_slice(&buffer.offset.to_ne_bytes()); + let offset: u64 = buffer.offset as u64; + input.extend_from_slice(&offset.to_ne_bytes()); } else { - input.extend_from_slice(&(buffer.offset as u32).to_ne_bytes()); + let offset: u32 = buffer.offset as u32; + input.extend_from_slice(&offset.to_ne_bytes()); } resource_info.push((res.clone(), arg.offset)); } @@ -1048,9 +1050,12 @@ impl Kernel { variable_local_size = align(variable_local_size, pot.next_power_of_two() as u64); if q.device.address_bits() == 64 { - input.extend_from_slice(&variable_local_size.to_ne_bytes()); + let variable_local_size: [u8; 8] = variable_local_size.to_ne_bytes(); + input.extend_from_slice(&variable_local_size); } else { - input.extend_from_slice(&(variable_local_size as u32).to_ne_bytes()); + let variable_local_size: [u8; 4] = + (variable_local_size as u32).to_ne_bytes(); + input.extend_from_slice(&variable_local_size); } variable_local_size += *size as u64; } @@ -1085,11 +1090,27 @@ impl Kernel { )); } InternalKernelArgType::GlobalWorkOffsets => { - input.extend_from_slice(unsafe { as_byte_slice(&offsets) }); + if q.device.address_bits() == 64 { + input.extend_from_slice(unsafe { + as_byte_slice(&[ + offsets[0] as u64, + offsets[1] as u64, + offsets[2] as u64, + ]) + }); + } else { + input.extend_from_slice(unsafe { + as_byte_slice(&[ + offsets[0] as u32, + offsets[1] as u32, + offsets[2] as u32, + ]) + }); + } } InternalKernelArgType::WorkGroupOffsets => { workgroup_id_offset_loc = Some(input.len()); - input.extend_from_slice(host_null_v3); + input.extend_from_slice(null_ptr_v3); } InternalKernelArgType::PrintfBuffer => { let buf = Arc::new( @@ -1192,9 +1213,15 @@ impl Kernel { let this_offsets = [x * hw_max_grid[0], y * hw_max_grid[1], z * hw_max_grid[2]]; - input[workgroup_id_offset_loc - ..workgroup_id_offset_loc + (size_of::() * 3)] - .copy_from_slice(unsafe { as_byte_slice(&this_offsets) }); + if q.device.address_bits() == 64 { + let val = this_offsets.map(|v| v as u64); + input[workgroup_id_offset_loc..workgroup_id_offset_loc + 24] + .copy_from_slice(unsafe { as_byte_slice(&val) }); + } else { + let val = this_offsets.map(|v| v as u32); + input[workgroup_id_offset_loc..workgroup_id_offset_loc + 12] + .copy_from_slice(unsafe { as_byte_slice(&val) }); + } } let this_grid = [