diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index 124978e8acd..49a88db2925 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -1478,7 +1478,13 @@ impl Kernel { } } - fn optimize_local_size(&self, d: &Device, grid: &mut [usize; 3], block: &mut [u32; 3]) { + fn optimize_local_size( + &self, + d: &Device, + work_dim: u32, + grid: &mut [usize; 3], + block: &mut [u32; 3], + ) { if !block.contains(&0) { for i in 0..3 { // we already made sure everything is fine @@ -1492,10 +1498,10 @@ impl Kernel { usize_block[i] = block[i] as usize; } - self.suggest_local_size(d, 3, grid, &mut usize_block); + self.suggest_local_size(d, work_dim as usize, grid, &mut usize_block); for i in 0..3 { - block[i] = usize_block[i] as u32; + block[i] = 1.max(usize_block[i] as u32); } } @@ -1549,7 +1555,7 @@ impl Kernel { let api_grid = grid; - self.optimize_local_size(q.device, &mut grid, &mut block); + self.optimize_local_size(q.device, work_dim, &mut grid, &mut block); Ok(Box::new(move |cl_ctx, ctx| { let hw_max_grid = ctx.dev.max_grid_size();