rusticl/kernel: prepare for nir caching
Signed-off-by: Karol Herbst <kherbst@redhat.com> Acked-by: Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15439>
This commit is contained in:
@@ -164,7 +164,6 @@ pub fn create_kernel(
|
||||
// kernel_name such as the number of arguments, the argument types are not the same for all
|
||||
// devices for which the program executable has been built.
|
||||
let devs = get_devices_with_valid_build(&p)?;
|
||||
let nirs = p.nirs(&name);
|
||||
let kernel_args: HashSet<_> = devs.iter().map(|d| p.args(d, &name)).collect();
|
||||
if kernel_args.len() != 1 {
|
||||
return Err(CL_INVALID_KERNEL_DEFINITION);
|
||||
@@ -173,7 +172,6 @@ pub fn create_kernel(
|
||||
Ok(cl_kernel::from_arc(Kernel::new(
|
||||
name,
|
||||
p,
|
||||
nirs,
|
||||
kernel_args.into_iter().next().unwrap(),
|
||||
)))
|
||||
}
|
||||
@@ -206,13 +204,11 @@ pub fn create_kernels_in_program(
|
||||
if !kernels.is_null() {
|
||||
// we just assume the client isn't stupid
|
||||
unsafe {
|
||||
let nirs = p.nirs(&name);
|
||||
kernels
|
||||
.add(num_kernels as usize)
|
||||
.write(cl_kernel::from_arc(Kernel::new(
|
||||
name,
|
||||
p.clone(),
|
||||
nirs,
|
||||
kernel_args.into_iter().next().unwrap(),
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ pub enum KernelArgValue {
|
||||
LocalMem(usize),
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
#[derive(Hash, PartialEq, Eq, Clone)]
|
||||
pub enum KernelArgType {
|
||||
Constant, // for anything passed by value
|
||||
Image,
|
||||
@@ -53,7 +53,7 @@ pub enum InternalKernelArgType {
|
||||
OrderArray,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Hash, PartialEq, Eq, Clone)]
|
||||
pub struct KernelArg {
|
||||
spirv: spirv::SPIRVKernelArg,
|
||||
pub kind: KernelArgType,
|
||||
@@ -70,7 +70,7 @@ pub struct InternalKernelArg {
|
||||
}
|
||||
|
||||
impl KernelArg {
|
||||
fn from_spirv_nir(spirv: Vec<spirv::SPIRVKernelArg>, nir: &mut NirShader) -> Vec<Self> {
|
||||
fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
|
||||
let nir_arg_map: HashMap<_, _> = nir
|
||||
.variables_with_mode(
|
||||
nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
|
||||
@@ -79,7 +79,7 @@ impl KernelArg {
|
||||
.collect();
|
||||
let mut res = Vec::new();
|
||||
|
||||
for (i, s) in spirv.into_iter().enumerate() {
|
||||
for (i, s) in spirv.iter().enumerate() {
|
||||
let nir = nir_arg_map.get(&(i as i32)).unwrap();
|
||||
let kind = match s.address_qualifier {
|
||||
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
|
||||
@@ -109,7 +109,7 @@ impl KernelArg {
|
||||
};
|
||||
|
||||
res.push(Self {
|
||||
spirv: s,
|
||||
spirv: s.clone(),
|
||||
size: unsafe { glsl_get_cl_size(nir.type_) } as usize,
|
||||
// we'll update it later in the 2nd pass
|
||||
kind: kind,
|
||||
@@ -454,6 +454,43 @@ fn lower_and_optimize_nir_late(
|
||||
res
|
||||
}
|
||||
|
||||
fn convert_spirv_to_nir(
|
||||
p: &Program,
|
||||
name: &str,
|
||||
args: Vec<spirv::SPIRVKernelArg>,
|
||||
) -> (
|
||||
HashMap<Arc<Device>, NirShader>,
|
||||
Vec<KernelArg>,
|
||||
Vec<InternalKernelArg>,
|
||||
) {
|
||||
let mut nirs = HashMap::new();
|
||||
let mut args_set = HashSet::new();
|
||||
let mut internal_args_set = HashSet::new();
|
||||
|
||||
// TODO: we could run this in parallel?
|
||||
for d in p.devs_with_build() {
|
||||
let mut nir = p.to_nir(name, d);
|
||||
|
||||
lower_and_optimize_nir_pre_inputs(d, &mut nir, &d.lib_clc);
|
||||
|
||||
let mut args = KernelArg::from_spirv_nir(&args, &mut nir);
|
||||
let mut internal_args = lower_and_optimize_nir_late(d, &mut nir, args.len());
|
||||
KernelArg::assign_locations(&mut args, &mut internal_args, &mut nir);
|
||||
|
||||
args_set.insert(args);
|
||||
internal_args_set.insert(internal_args);
|
||||
nirs.insert(d.clone(), nir);
|
||||
}
|
||||
|
||||
// we want the same (internal) args for every compiled kernel, for now
|
||||
assert!(args_set.len() == 1);
|
||||
assert!(internal_args_set.len() == 1);
|
||||
let args = args_set.into_iter().next().unwrap();
|
||||
let internal_args = internal_args_set.into_iter().next().unwrap();
|
||||
|
||||
(nirs, args, internal_args)
|
||||
}
|
||||
|
||||
fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
|
||||
let val;
|
||||
(val, *buf) = (*buf).split_at(S);
|
||||
@@ -463,30 +500,15 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
|
||||
}
|
||||
|
||||
impl Kernel {
|
||||
pub fn new(
|
||||
name: String,
|
||||
prog: Arc<Program>,
|
||||
mut nirs: HashMap<Arc<Device>, NirShader>,
|
||||
args: Vec<spirv::SPIRVKernelArg>,
|
||||
) -> Arc<Kernel> {
|
||||
nirs.iter_mut()
|
||||
.for_each(|(d, n)| lower_and_optimize_nir_pre_inputs(d, n, &d.lib_clc));
|
||||
pub fn new(name: String, prog: Arc<Program>, args: Vec<spirv::SPIRVKernelArg>) -> Arc<Kernel> {
|
||||
let (mut nirs, args, internal_args) = convert_spirv_to_nir(&prog, &name, args);
|
||||
|
||||
let nir = nirs.values_mut().next().unwrap();
|
||||
let wgs = nir.workgroup_size();
|
||||
let work_group_size = [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize];
|
||||
let mut args = KernelArg::from_spirv_nir(args, nir);
|
||||
|
||||
// can't use vec!...
|
||||
let values = args.iter().map(|_| RefCell::new(None)).collect();
|
||||
let internal_args: HashSet<_> = nirs
|
||||
.iter_mut()
|
||||
.map(|(d, n)| lower_and_optimize_nir_late(d, n, args.len()))
|
||||
.collect();
|
||||
// we want the same internal args for every compiled kernel, for now
|
||||
assert!(internal_args.len() == 1);
|
||||
let mut internal_args = internal_args.into_iter().next().unwrap();
|
||||
|
||||
nirs.values_mut()
|
||||
.for_each(|n| KernelArg::assign_locations(&mut args, &mut internal_args, n));
|
||||
|
||||
Arc::new(Self {
|
||||
base: CLObjectBase::new(),
|
||||
|
||||
@@ -436,27 +436,30 @@ impl Program {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn nirs(&self, kernel: &str) -> HashMap<Arc<Device>, NirShader> {
|
||||
pub fn devs_with_build(&self) -> Vec<&Arc<Device>> {
|
||||
let mut lock = self.build_info();
|
||||
let mut res = HashMap::new();
|
||||
for d in &self.devs {
|
||||
let info = Self::dev_build_info(&mut lock, d);
|
||||
if info.status != CL_BUILD_SUCCESS as cl_build_status {
|
||||
continue;
|
||||
}
|
||||
let nir = info
|
||||
.spirv
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.to_nir(
|
||||
kernel,
|
||||
d.screen
|
||||
.nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
|
||||
&d.lib_clc,
|
||||
)
|
||||
.unwrap();
|
||||
res.insert(d.clone(), nir);
|
||||
}
|
||||
res
|
||||
self.devs
|
||||
.iter()
|
||||
.filter(|d| {
|
||||
let info = Self::dev_build_info(&mut lock, d);
|
||||
info.status == CL_BUILD_SUCCESS as cl_build_status
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn to_nir(&self, kernel: &str, d: &Arc<Device>) -> NirShader {
|
||||
let mut lock = self.build_info();
|
||||
let info = Self::dev_build_info(&mut lock, d);
|
||||
assert_eq!(info.status, CL_BUILD_SUCCESS as cl_build_status);
|
||||
info.spirv
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.to_nir(
|
||||
kernel,
|
||||
d.screen
|
||||
.nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
|
||||
&d.lib_clc,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user