rusticl/kernel: prepare for nir caching

Signed-off-by: Karol Herbst <kherbst@redhat.com>
Acked-by: Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15439>
This commit is contained in:
Karol Herbst
2022-04-17 14:52:06 +02:00
committed by Marge Bot
parent 0da5e8704b
commit ea7d5c1d4b
3 changed files with 70 additions and 49 deletions
@@ -164,7 +164,6 @@ pub fn create_kernel(
// kernel_name such as the number of arguments, the argument types are not the same for all
// devices for which the program executable has been built.
let devs = get_devices_with_valid_build(&p)?;
let nirs = p.nirs(&name);
let kernel_args: HashSet<_> = devs.iter().map(|d| p.args(d, &name)).collect();
if kernel_args.len() != 1 {
return Err(CL_INVALID_KERNEL_DEFINITION);
@@ -173,7 +172,6 @@ pub fn create_kernel(
Ok(cl_kernel::from_arc(Kernel::new(
name,
p,
nirs,
kernel_args.into_iter().next().unwrap(),
)))
}
@@ -206,13 +204,11 @@ pub fn create_kernels_in_program(
if !kernels.is_null() {
// we just assume the client isn't stupid
unsafe {
let nirs = p.nirs(&name);
kernels
.add(num_kernels as usize)
.write(cl_kernel::from_arc(Kernel::new(
name,
p.clone(),
nirs,
kernel_args.into_iter().next().unwrap(),
)));
}
+46 -24
View File
@@ -32,7 +32,7 @@ pub enum KernelArgValue {
LocalMem(usize),
}
#[derive(PartialEq, Eq, Clone)]
#[derive(Hash, PartialEq, Eq, Clone)]
pub enum KernelArgType {
Constant, // for anything passed by value
Image,
@@ -53,7 +53,7 @@ pub enum InternalKernelArgType {
OrderArray,
}
#[derive(Clone)]
#[derive(Hash, PartialEq, Eq, Clone)]
pub struct KernelArg {
spirv: spirv::SPIRVKernelArg,
pub kind: KernelArgType,
@@ -70,7 +70,7 @@ pub struct InternalKernelArg {
}
impl KernelArg {
fn from_spirv_nir(spirv: Vec<spirv::SPIRVKernelArg>, nir: &mut NirShader) -> Vec<Self> {
fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
let nir_arg_map: HashMap<_, _> = nir
.variables_with_mode(
nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
@@ -79,7 +79,7 @@ impl KernelArg {
.collect();
let mut res = Vec::new();
for (i, s) in spirv.into_iter().enumerate() {
for (i, s) in spirv.iter().enumerate() {
let nir = nir_arg_map.get(&(i as i32)).unwrap();
let kind = match s.address_qualifier {
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
@@ -109,7 +109,7 @@ impl KernelArg {
};
res.push(Self {
spirv: s,
spirv: s.clone(),
size: unsafe { glsl_get_cl_size(nir.type_) } as usize,
// we'll update it later in the 2nd pass
kind: kind,
@@ -454,6 +454,43 @@ fn lower_and_optimize_nir_late(
res
}
/// Builds the NIR shader for kernel `name` on every device of `p` that has a
/// valid build, and derives the kernel's argument lists from it.
///
/// For each device the shader is produced by `Program::to_nir`, run through
/// `lower_and_optimize_nir_pre_inputs`, used to compute the `KernelArg`s from
/// the SPIR-V argument descriptions in `args`, run through
/// `lower_and_optimize_nir_late`, and finally has its argument locations
/// assigned.
///
/// Returns the per-device shaders together with the (internal) kernel args.
///
/// # Panics
///
/// Panics unless all devices produced exactly one distinct list of kernel args
/// and exactly one distinct list of internal kernel args.
fn convert_spirv_to_nir(
    p: &Program,
    name: &str,
    args: Vec<spirv::SPIRVKernelArg>,
) -> (
    HashMap<Arc<Device>, NirShader>,
    Vec<KernelArg>,
    Vec<InternalKernelArg>,
) {
    let mut shaders = HashMap::new();
    // Sets are used so that differing per-device results are detected below.
    let mut args_set = HashSet::new();
    let mut internal_args_set = HashSet::new();

    // TODO: we could run this in parallel?
    for dev in p.devs_with_build() {
        let mut shader = p.to_nir(name, dev);
        lower_and_optimize_nir_pre_inputs(dev, &mut shader, &dev.lib_clc);

        let mut dev_args = KernelArg::from_spirv_nir(&args, &mut shader);
        let mut dev_internal_args =
            lower_and_optimize_nir_late(dev, &mut shader, dev_args.len());
        KernelArg::assign_locations(&mut dev_args, &mut dev_internal_args, &mut shader);

        args_set.insert(dev_args);
        internal_args_set.insert(dev_internal_args);
        shaders.insert(dev.clone(), shader);
    }

    // we want the same (internal) args for every compiled kernel, for now
    assert!(args_set.len() == 1);
    assert!(internal_args_set.len() == 1);

    (
        shaders,
        args_set.into_iter().next().unwrap(),
        internal_args_set.into_iter().next().unwrap(),
    )
}
fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
let val;
(val, *buf) = (*buf).split_at(S);
@@ -463,30 +500,15 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
}
impl Kernel {
pub fn new(
name: String,
prog: Arc<Program>,
mut nirs: HashMap<Arc<Device>, NirShader>,
args: Vec<spirv::SPIRVKernelArg>,
) -> Arc<Kernel> {
nirs.iter_mut()
.for_each(|(d, n)| lower_and_optimize_nir_pre_inputs(d, n, &d.lib_clc));
pub fn new(name: String, prog: Arc<Program>, args: Vec<spirv::SPIRVKernelArg>) -> Arc<Kernel> {
let (mut nirs, args, internal_args) = convert_spirv_to_nir(&prog, &name, args);
let nir = nirs.values_mut().next().unwrap();
let wgs = nir.workgroup_size();
let work_group_size = [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize];
let mut args = KernelArg::from_spirv_nir(args, nir);
// can't use vec!...
let values = args.iter().map(|_| RefCell::new(None)).collect();
let internal_args: HashSet<_> = nirs
.iter_mut()
.map(|(d, n)| lower_and_optimize_nir_late(d, n, args.len()))
.collect();
// we want the same internal args for every compiled kernel, for now
assert!(internal_args.len() == 1);
let mut internal_args = internal_args.into_iter().next().unwrap();
nirs.values_mut()
.for_each(|n| KernelArg::assign_locations(&mut args, &mut internal_args, n));
Arc::new(Self {
base: CLObjectBase::new(),
+24 -21
View File
@@ -436,27 +436,30 @@ impl Program {
})
}
pub fn nirs(&self, kernel: &str) -> HashMap<Arc<Device>, NirShader> {
pub fn devs_with_build(&self) -> Vec<&Arc<Device>> {
let mut lock = self.build_info();
let mut res = HashMap::new();
for d in &self.devs {
let info = Self::dev_build_info(&mut lock, d);
if info.status != CL_BUILD_SUCCESS as cl_build_status {
continue;
}
let nir = info
.spirv
.as_ref()
.unwrap()
.to_nir(
kernel,
d.screen
.nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
&d.lib_clc,
)
.unwrap();
res.insert(d.clone(), nir);
}
res
self.devs
.iter()
.filter(|d| {
let info = Self::dev_build_info(&mut lock, d);
info.status == CL_BUILD_SUCCESS as cl_build_status
})
.collect()
}
/// Converts the SPIR-V stored for device `d` into a NIR shader for `kernel`.
///
/// The conversion uses the compute-shader compiler options of the device's
/// screen and the device's `lib_clc`.
///
/// # Panics
///
/// Panics if the build status recorded for `d` is not `CL_BUILD_SUCCESS`, if
/// no SPIR-V is stored for `d`, or if the conversion does not yield a shader.
pub fn to_nir(&self, kernel: &str, d: &Arc<Device>) -> NirShader {
    let mut builds = self.build_info();
    let dev_build = Self::dev_build_info(&mut builds, d);
    assert_eq!(dev_build.status, CL_BUILD_SUCCESS as cl_build_status);

    // Unwrap the SPIR-V before querying the screen, as the original chain did.
    let spirv = dev_build.spirv.as_ref().unwrap();
    let options = d
        .screen
        .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE);

    spirv.to_nir(kernel, options, &d.lib_clc).unwrap()
}
}