From ea7d5c1d4b7c14ceaf9879d0ff489b4f0116abc3 Mon Sep 17 00:00:00 2001 From: Karol Herbst Date: Sun, 17 Apr 2022 14:52:06 +0200 Subject: [PATCH] rusticl/kernel: prepare for nir caching Signed-off-by: Karol Herbst Acked-by: Alyssa Rosenzweig Part-of: --- src/gallium/frontends/rusticl/api/kernel.rs | 4 -- src/gallium/frontends/rusticl/core/kernel.rs | 70 ++++++++++++------- src/gallium/frontends/rusticl/core/program.rs | 45 ++++++------ 3 files changed, 70 insertions(+), 49 deletions(-) diff --git a/src/gallium/frontends/rusticl/api/kernel.rs b/src/gallium/frontends/rusticl/api/kernel.rs index 5240d0b3d9f..7258601f288 100644 --- a/src/gallium/frontends/rusticl/api/kernel.rs +++ b/src/gallium/frontends/rusticl/api/kernel.rs @@ -164,7 +164,6 @@ pub fn create_kernel( // kernel_name such as the number of arguments, the argument types are not the same for all // devices for which the program executable has been built. let devs = get_devices_with_valid_build(&p)?; - let nirs = p.nirs(&name); let kernel_args: HashSet<_> = devs.iter().map(|d| p.args(d, &name)).collect(); if kernel_args.len() != 1 { return Err(CL_INVALID_KERNEL_DEFINITION); @@ -173,7 +172,6 @@ pub fn create_kernel( Ok(cl_kernel::from_arc(Kernel::new( name, p, - nirs, kernel_args.into_iter().next().unwrap(), ))) } @@ -206,13 +204,11 @@ pub fn create_kernels_in_program( if !kernels.is_null() { // we just assume the client isn't stupid unsafe { - let nirs = p.nirs(&name); kernels .add(num_kernels as usize) .write(cl_kernel::from_arc(Kernel::new( name, p.clone(), - nirs, kernel_args.into_iter().next().unwrap(), ))); } diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index 8959b543406..cecc0fd9c0e 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -32,7 +32,7 @@ pub enum KernelArgValue { LocalMem(usize), } -#[derive(PartialEq, Eq, Clone)] +#[derive(Hash, PartialEq, Eq, Clone)] pub enum KernelArgType { Constant, // for anything passed by value Image, @@ -53,7 +53,7 @@ pub enum InternalKernelArgType { OrderArray, } -#[derive(Clone)] +#[derive(Hash, PartialEq, Eq, Clone)] pub struct KernelArg { spirv: spirv::SPIRVKernelArg, pub kind: KernelArgType, @@ -70,7 +70,7 @@ pub struct InternalKernelArg { } impl KernelArg { - fn from_spirv_nir(spirv: Vec, nir: &mut NirShader) -> Vec { + fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec { let nir_arg_map: HashMap<_, _> = nir .variables_with_mode( nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image, @@ -79,7 +79,7 @@ impl KernelArg { .collect(); let mut res = Vec::new(); - for (i, s) in spirv.into_iter().enumerate() { + for (i, s) in spirv.iter().enumerate() { let nir = nir_arg_map.get(&(i as i32)).unwrap(); let kind = match s.address_qualifier { clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => { @@ -109,7 +109,7 @@ impl KernelArg { }; res.push(Self { - spirv: s, + spirv: s.clone(), size: unsafe { glsl_get_cl_size(nir.type_) } as usize, // we'll update it later in the 2nd pass kind: kind, @@ -454,6 +454,43 @@ fn lower_and_optimize_nir_late( res } +fn convert_spirv_to_nir( + p: &Program, + name: &str, + args: Vec, +) -> ( + HashMap, NirShader>, + Vec, + Vec, +) { + let mut nirs = HashMap::new(); + let mut args_set = HashSet::new(); + let mut internal_args_set = HashSet::new(); + + // TODO: we could run this in parallel? + for d in p.devs_with_build() { + let mut nir = p.to_nir(name, d); + + lower_and_optimize_nir_pre_inputs(d, &mut nir, &d.lib_clc); + + let mut args = KernelArg::from_spirv_nir(&args, &mut nir); + let mut internal_args = lower_and_optimize_nir_late(d, &mut nir, args.len()); + KernelArg::assign_locations(&mut args, &mut internal_args, &mut nir); + + args_set.insert(args); + internal_args_set.insert(internal_args); + nirs.insert(d.clone(), nir); + } + + // we want the same (internal) args for every compiled kernel, for now + assert!(args_set.len() == 1); + assert!(internal_args_set.len() == 1); + let args = args_set.into_iter().next().unwrap(); + let internal_args = internal_args_set.into_iter().next().unwrap(); + + (nirs, args, internal_args) +} + fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] { let val; (val, *buf) = (*buf).split_at(S); @@ -463,30 +500,15 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] { } impl Kernel { - pub fn new( - name: String, - prog: Arc, - mut nirs: HashMap, NirShader>, - args: Vec, - ) -> Arc { - nirs.iter_mut() - .for_each(|(d, n)| lower_and_optimize_nir_pre_inputs(d, n, &d.lib_clc)); + pub fn new(name: String, prog: Arc, args: Vec) -> Arc { + let (mut nirs, args, internal_args) = convert_spirv_to_nir(&prog, &name, args); + let nir = nirs.values_mut().next().unwrap(); let wgs = nir.workgroup_size(); let work_group_size = [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize]; - let mut args = KernelArg::from_spirv_nir(args, nir); + // can't use vec!... let values = args.iter().map(|_| RefCell::new(None)).collect(); - let internal_args: HashSet<_> = nirs - .iter_mut() - .map(|(d, n)| lower_and_optimize_nir_late(d, n, args.len())) - .collect(); - // we want the same internal args for every compiled kernel, for now - assert!(internal_args.len() == 1); - let mut internal_args = internal_args.into_iter().next().unwrap(); - - nirs.values_mut() - .for_each(|n| KernelArg::assign_locations(&mut args, &mut internal_args, n)); Arc::new(Self { base: CLObjectBase::new(), diff --git a/src/gallium/frontends/rusticl/core/program.rs b/src/gallium/frontends/rusticl/core/program.rs index 1cb1d9a72c0..6c36b754370 100644 --- a/src/gallium/frontends/rusticl/core/program.rs +++ b/src/gallium/frontends/rusticl/core/program.rs @@ -436,27 +436,30 @@ impl Program { }) } - pub fn nirs(&self, kernel: &str) -> HashMap, NirShader> { + pub fn devs_with_build(&self) -> Vec<&Arc> { let mut lock = self.build_info(); - let mut res = HashMap::new(); - for d in &self.devs { - let info = Self::dev_build_info(&mut lock, d); - if info.status != CL_BUILD_SUCCESS as cl_build_status { - continue; - } - let nir = info - .spirv - .as_ref() - .unwrap() - .to_nir( - kernel, - d.screen - .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE), - &d.lib_clc, - ) - .unwrap(); - res.insert(d.clone(), nir); - } - res + self.devs + .iter() + .filter(|d| { + let info = Self::dev_build_info(&mut lock, d); + info.status == CL_BUILD_SUCCESS as cl_build_status + }) + .collect() + } + + pub fn to_nir(&self, kernel: &str, d: &Arc) -> NirShader { + let mut lock = self.build_info(); + let info = Self::dev_build_info(&mut lock, d); + assert_eq!(info.status, CL_BUILD_SUCCESS as cl_build_status); + info.spirv + .as_ref() + .unwrap() + .to_nir( + kernel, + d.screen + .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE), + &d.lib_clc, + ) + .unwrap() } }