diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index 8f745556bb6..d92956385dd 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -1239,6 +1239,31 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] { val.try_into().unwrap() } +/// Helper class to build an execution environment for a single kernel invocation. +struct KernelExecBuilder { + dev: &'static Device, + input: Vec, +} + +impl KernelExecBuilder { + fn new(dev: &'static Device) -> Self { + Self { + dev: dev, + input: Vec::new(), + } + } + + fn add_pointer(&mut self, address: u64) { + if self.dev.address_bits() == 64 { + let address: u64 = address; + self.input.extend_from_slice(&address.to_ne_bytes()); + } else { + let address: u32 = address as u32; + self.input.extend_from_slice(&address.to_ne_bytes()); + } + } +} + impl Kernel { pub fn new(name: String, prog: Arc, prog_build: &ProgramBuild) -> Arc { let kernel_info = Arc::clone(prog_build.kernel_info.get(&name).unwrap()); @@ -1386,7 +1411,7 @@ impl Kernel { let nir_kernel_build = &nir_kernel_builds[variant]; let mut workgroup_id_offset_loc = None; - let mut input = Vec::new(); + let mut exec_builder = KernelExecBuilder::new(ctx.dev); // Set it once so we get the alignment padding right let static_local_size: u64 = nir_kernel_build.shared_size; let mut variable_local_size: u64 = static_local_size; @@ -1410,25 +1435,15 @@ impl Kernel { }; let mut resource_info = Vec::new(); - fn add_pointer(ctx: &QueueContext, input: &mut Vec, address: u64) { - if ctx.dev.address_bits() == 64 { - let address: u64 = address; - input.extend_from_slice(&address.to_ne_bytes()); - } else { - let address: u32 = address as u32; - input.extend_from_slice(&address.to_ne_bytes()); - } - } fn add_global<'a>( - ctx: &QueueContext, - input: &mut Vec, + exec_builder: &mut KernelExecBuilder, resource_info: &mut Vec<(&'a PipeResource, usize)>, res: &'a PipeResourceOwned, offset: usize, ) { - resource_info.push((res.borrow(), input.len())); - add_pointer(ctx, input, offset as u64); + resource_info.push((res.borrow(), exec_builder.input.len())); + exec_builder.add_pointer(offset as u64); } fn add_sysval(ctx: &QueueContext, input: &mut Vec, vals: &[usize; 3]) { @@ -1466,8 +1481,8 @@ impl Kernel { false }; - if !is_opaque && arg.offset > input.len() { - input.resize(arg.offset, 0); + if !is_opaque && arg.offset > exec_builder.input.len() { + exec_builder.input.resize(arg.offset, 0); } match arg.kind { @@ -1478,11 +1493,11 @@ impl Kernel { }; match value { - KernelArgValue::Constant(c) => input.extend_from_slice(c), + KernelArgValue::Constant(c) => exec_builder.input.extend_from_slice(c), KernelArgValue::BDA(address) => { bdas.push(*address); if !api_arg.dead { - add_pointer(ctx, &mut input, *address); + exec_builder.add_pointer(*address); } } KernelArgValue::Buffer(buffer) => { @@ -1509,8 +1524,7 @@ impl Kernel { } else { let res = buffer.get_res_for_access(ctx, rw)?; add_global( - ctx, - &mut input, + &mut exec_builder, &mut resource_info, res, buffer.offset(), @@ -1524,7 +1538,7 @@ impl Kernel { } if !api_arg.dead { - add_pointer(ctx, &mut input, handle as u64); + exec_builder.add_pointer(handle as u64); } } KernelArgValue::Image(image) => { @@ -1556,11 +1570,11 @@ impl Kernel { if ctx.dev.address_bits() == 64 { let variable_local_size: [u8; 8] = variable_local_size.to_ne_bytes(); - input.extend_from_slice(&variable_local_size); + exec_builder.input.extend_from_slice(&variable_local_size); } else { let variable_local_size: [u8; 4] = (variable_local_size as u32).to_ne_bytes(); - input.extend_from_slice(&variable_local_size); + exec_builder.input.extend_from_slice(&variable_local_size); } variable_local_size += *size as u64; } @@ -1574,7 +1588,7 @@ impl Kernel { KernelArgType::MemGlobal | KernelArgType::MemConstant ) { - input.extend_from_slice(null_ptr); + exec_builder.input.extend_from_slice(null_ptr); } } } @@ -1582,38 +1596,46 @@ impl Kernel { CompiledKernelArgType::ConstantBuffer => { assert!(nir_kernel_build.constant_buffer.is_some()); let res = nir_kernel_build.constant_buffer.as_ref().unwrap(); - add_global(ctx, &mut input, &mut resource_info, res, 0); + add_global(&mut exec_builder, &mut resource_info, res, 0); } CompiledKernelArgType::GlobalWorkOffsets => { - add_sysval(ctx, &mut input, &offsets); + add_sysval(ctx, &mut exec_builder.input, &offsets); } CompiledKernelArgType::WorkGroupOffsets => { - workgroup_id_offset_loc = Some(input.len()); - input.extend_from_slice(null_ptr_v3); + workgroup_id_offset_loc = Some(exec_builder.input.len()); + exec_builder.input.extend_from_slice(null_ptr_v3); } CompiledKernelArgType::GlobalWorkSize => { - add_sysval(ctx, &mut input, &api_grid); + add_sysval(ctx, &mut exec_builder.input, &api_grid); } CompiledKernelArgType::PrintfBuffer => { let res = printf_buf.as_ref().unwrap(); - add_global(ctx, &mut input, &mut resource_info, res, 0); + add_global(&mut exec_builder, &mut resource_info, res, 0); } CompiledKernelArgType::InlineSampler(cl) => { samplers.push(Sampler::cl_to_pipe(cl)); } CompiledKernelArgType::FormatArray => { - input.extend_from_slice(unsafe { as_byte_slice(&tex_formats) }); - input.extend_from_slice(unsafe { as_byte_slice(&img_formats) }); + exec_builder + .input + .extend_from_slice(unsafe { as_byte_slice(&tex_formats) }); + exec_builder + .input + .extend_from_slice(unsafe { as_byte_slice(&img_formats) }); } CompiledKernelArgType::OrderArray => { - input.extend_from_slice(unsafe { as_byte_slice(&tex_orders) }); - input.extend_from_slice(unsafe { as_byte_slice(&img_orders) }); + exec_builder + .input + .extend_from_slice(unsafe { as_byte_slice(&tex_orders) }); + exec_builder + .input + .extend_from_slice(unsafe { as_byte_slice(&img_orders) }); } CompiledKernelArgType::WorkDim => { - input.extend_from_slice(&[work_dim as u8; 1]); + exec_builder.input.extend_from_slice(&[work_dim as u8; 1]); } CompiledKernelArgType::NumWorkgroups => { - input.extend_from_slice(unsafe { + exec_builder.input.extend_from_slice(unsafe { as_byte_slice(&[grid[0] as u32, grid[1] as u32, grid[2] as u32]) }); } @@ -1650,7 +1672,7 @@ impl Kernel { let mut globals: Vec<*mut u32> = Vec::with_capacity(resource_info.len()); for (res, offset) in resource_info { resources.push(res); - globals.push(unsafe { input.as_mut_ptr().byte_add(offset) }.cast()); + globals.push(unsafe { exec_builder.input.as_mut_ptr().byte_add(offset) }.cast()); } ctx.bind_kernel(&nir_kernel_builds, variant)?; @@ -1668,11 +1690,13 @@ impl Kernel { if ctx.dev.address_bits() == 64 { let val = this_offsets.map(|v| v as u64); - input[workgroup_id_offset_loc..workgroup_id_offset_loc + 24] + exec_builder.input + [workgroup_id_offset_loc..workgroup_id_offset_loc + 24] .copy_from_slice(unsafe { as_byte_slice(&val) }); } else { let val = this_offsets.map(|v| v as u32); - input[workgroup_id_offset_loc..workgroup_id_offset_loc + 12] + exec_builder.input + [workgroup_id_offset_loc..workgroup_id_offset_loc + 12] .copy_from_slice(unsafe { as_byte_slice(&val) }); } } @@ -1683,7 +1707,7 @@ impl Kernel { cmp::min(hw_max_grid[2], grid[2] - hw_max_grid[2] * z) as u32, ]; - ctx.update_cb0(&input)?; + ctx.update_cb0(&exec_builder.input)?; ctx.launch_grid( work_dim, block,