diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index a191b754476..2bfce7b52fe 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -1244,6 +1244,7 @@ struct KernelExecBuilder<'a> { dev: &'static Device, input: Vec, resource_info: Vec<(&'a PipeResource, usize)>, + workgroup_id_offset_loc: Option, } impl<'a> KernelExecBuilder<'a> { @@ -1252,6 +1253,7 @@ impl<'a> KernelExecBuilder<'a> { dev: dev, input: Vec::new(), resource_info: Vec::new(), + workgroup_id_offset_loc: None, } } @@ -1304,6 +1306,27 @@ impl<'a> KernelExecBuilder<'a> { (resources, globals) } + + /// Marks the current position in the kernel input buffer as the location of the workgroup id + /// offsets for use with `set_workgroup_id_offset`. + fn mark_workgroup_id_offset(&mut self) { + self.workgroup_id_offset_loc = Some(self.input.len()); + } + + /// Sets the workgroup id offsets within the kernel input buffer to the provided values. + fn set_workgroup_id_offset(&mut self, offset: [usize; 3]) { + if let Some(workgroup_id_offset_loc) = self.workgroup_id_offset_loc { + if self.dev.address_bits() == 64 { + let val = offset.map(|v| v as u64); + self.input[workgroup_id_offset_loc..workgroup_id_offset_loc + 24] + .copy_from_slice(unsafe { as_byte_slice(&val) }); + } else { + let val = offset.map(|v| v as u32); + self.input[workgroup_id_offset_loc..workgroup_id_offset_loc + 12] + .copy_from_slice(unsafe { as_byte_slice(&val) }); + } + } + } } impl Kernel { @@ -1452,7 +1475,6 @@ impl Kernel { }; let nir_kernel_build = &nir_kernel_builds[variant]; - let mut workgroup_id_offset_loc = None; let mut exec_builder = KernelExecBuilder::new(ctx.dev); // Set it once so we get the alignment padding right let static_local_size: u64 = nir_kernel_build.shared_size; @@ -1619,7 +1641,7 @@ impl Kernel { exec_builder.add_sysval(&offsets); } CompiledKernelArgType::WorkGroupOffsets => { - workgroup_id_offset_loc = Some(exec_builder.input.len()); + exec_builder.mark_workgroup_id_offset(); exec_builder.add_values(null_ptr_v3); } CompiledKernelArgType::GlobalWorkSize => { @@ -1688,22 +1710,11 @@ impl Kernel { for z in 0..grid[2].div_ceil(hw_max_grid[2]) { for y in 0..grid[1].div_ceil(hw_max_grid[1]) { for x in 0..grid[0].div_ceil(hw_max_grid[0]) { - if let Some(workgroup_id_offset_loc) = workgroup_id_offset_loc { - let this_offsets = - [x * hw_max_grid[0], y * hw_max_grid[1], z * hw_max_grid[2]]; + let this_offsets = + [x * hw_max_grid[0], y * hw_max_grid[1], z * hw_max_grid[2]]; - if ctx.dev.address_bits() == 64 { - let val = this_offsets.map(|v| v as u64); - exec_builder.input - [workgroup_id_offset_loc..workgroup_id_offset_loc + 24] - .copy_from_slice(unsafe { as_byte_slice(&val) }); - } else { - let val = this_offsets.map(|v| v as u32); - exec_builder.input - [workgroup_id_offset_loc..workgroup_id_offset_loc + 12] - .copy_from_slice(unsafe { as_byte_slice(&val) }); - } - } + // Each iteration we need to update the kernel side workgroup id offsets. + exec_builder.set_workgroup_id_offset(this_offsets); let this_grid = [ cmp::min(hw_max_grid[0], grid[0] - hw_max_grid[0] * x) as u32,