diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index 87b7fde283d..7636231510c 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -411,21 +411,18 @@ fn lower_and_optimize_nir( lib_clc: &NirShader, ) -> (Vec, Vec) { let address_bits_ptr_type; + let address_bits_base_type; let global_address_format; let shared_address_format; - let host_bits_base_type = if size_of::() == 8 { - glsl_base_type::GLSL_TYPE_UINT64 - } else { - glsl_base_type::GLSL_TYPE_UINT - }; - if dev.address_bits() == 64 { address_bits_ptr_type = unsafe { glsl_uint64_t_type() }; + address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64; global_address_format = nir_address_format::nir_address_format_64bit_global; shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit; } else { address_bits_ptr_type = unsafe { glsl_uint_type() }; + address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT; global_address_format = nir_address_format::nir_address_format_32bit_global; shared_address_format = nir_address_format::nir_address_format_32bit_offset; } @@ -553,7 +550,7 @@ fn lower_and_optimize_nir( lower_state.base_global_invoc_id_loc = args.len() + internal_args.len() - 1; nir.add_var( nir_variable_mode::nir_var_uniform, - unsafe { glsl_vector_type(host_bits_base_type, 3) }, + unsafe { glsl_vector_type(address_bits_base_type, 3) }, lower_state.base_global_invoc_id_loc, "base_global_invocation_id", ); @@ -568,7 +565,7 @@ fn lower_and_optimize_nir( lower_state.base_workgroup_id_loc = args.len() + internal_args.len() - 1; nir.add_var( nir_variable_mode::nir_var_uniform, - unsafe { glsl_vector_type(host_bits_base_type, 3) }, + unsafe { glsl_vector_type(address_bits_base_type, 3) }, lower_state.base_workgroup_id_loc, "base_workgroup_id", ); @@ -962,11 +959,14 @@ impl Kernel { let mut img_formats: Vec = Vec::new(); let mut img_orders: Vec = Vec::new(); - let host_null_v3 = &[0u8; 3 * size_of::()]; - let null_ptr = if q.device.address_bits() == 64 { - [0u8; 8].as_slice() + let null_ptr; + let null_ptr_v3; + if q.device.address_bits() == 64 { + null_ptr = [0u8; 8].as_slice(); + null_ptr_v3 = [0u8; 24].as_slice(); } else { - [0u8; 4].as_slice() + null_ptr = [0u8; 4].as_slice(); + null_ptr_v3 = [0u8; 12].as_slice(); }; self.optimize_local_size(q.device, &mut grid, &mut block); @@ -989,9 +989,11 @@ impl Kernel { KernelArgValue::Buffer(buffer) => { let res = buffer.get_res_of_dev(q.device)?; if q.device.address_bits() == 64 { - input.extend_from_slice(&buffer.offset.to_ne_bytes()); + let offset: u64 = buffer.offset as u64; + input.extend_from_slice(&offset.to_ne_bytes()); } else { - input.extend_from_slice(&(buffer.offset as u32).to_ne_bytes()); + let offset: u32 = buffer.offset as u32; + input.extend_from_slice(&offset.to_ne_bytes()); } resource_info.push((res.clone(), arg.offset)); } @@ -1048,9 +1050,12 @@ impl Kernel { variable_local_size = align(variable_local_size, pot.next_power_of_two() as u64); if q.device.address_bits() == 64 { - input.extend_from_slice(&variable_local_size.to_ne_bytes()); + let variable_local_size: [u8; 8] = variable_local_size.to_ne_bytes(); + input.extend_from_slice(&variable_local_size); } else { - input.extend_from_slice(&(variable_local_size as u32).to_ne_bytes()); + let variable_local_size: [u8; 4] = + (variable_local_size as u32).to_ne_bytes(); + input.extend_from_slice(&variable_local_size); } variable_local_size += *size as u64; } @@ -1085,11 +1090,27 @@ impl Kernel { )); } InternalKernelArgType::GlobalWorkOffsets => { - input.extend_from_slice(unsafe { as_byte_slice(&offsets) }); + if q.device.address_bits() == 64 { + input.extend_from_slice(unsafe { + as_byte_slice(&[ + offsets[0] as u64, + offsets[1] as u64, + offsets[2] as u64, + ]) + }); + } else { + input.extend_from_slice(unsafe { + as_byte_slice(&[ + offsets[0] as u32, + offsets[1] as u32, + offsets[2] as u32, + ]) + }); + } } InternalKernelArgType::WorkGroupOffsets => { workgroup_id_offset_loc = Some(input.len()); - input.extend_from_slice(host_null_v3); + input.extend_from_slice(null_ptr_v3); } InternalKernelArgType::PrintfBuffer => { let buf = Arc::new( @@ -1192,9 +1213,15 @@ impl Kernel { let this_offsets = [x * hw_max_grid[0], y * hw_max_grid[1], z * hw_max_grid[2]]; - input[workgroup_id_offset_loc - ..workgroup_id_offset_loc + (size_of::() * 3)] - .copy_from_slice(unsafe { as_byte_slice(&this_offsets) }); + if q.device.address_bits() == 64 { + let val = this_offsets.map(|v| v as u64); + input[workgroup_id_offset_loc..workgroup_id_offset_loc + 24] + .copy_from_slice(unsafe { as_byte_slice(&val) }); + } else { + let val = this_offsets.map(|v| v as u32); + input[workgroup_id_offset_loc..workgroup_id_offset_loc + 12] + .copy_from_slice(unsafe { as_byte_slice(&val) }); + } } let this_grid = [