1922 lines
67 KiB
Rust
1922 lines
67 KiB
Rust
// Copyright 2020 Red Hat.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
use crate::api::icd::*;
|
|
use crate::core::device::*;
|
|
use crate::core::event::*;
|
|
use crate::core::memory::*;
|
|
use crate::core::platform::*;
|
|
use crate::core::program::*;
|
|
use crate::core::queue::*;
|
|
use crate::impl_cl_type_trait;
|
|
|
|
use mesa_rust::compiler::clc::*;
|
|
use mesa_rust::compiler::nir::*;
|
|
use mesa_rust::nir_pass;
|
|
use mesa_rust::pipe::context::PipeContext;
|
|
use mesa_rust::pipe::context::RWFlags;
|
|
use mesa_rust::pipe::resource::*;
|
|
use mesa_rust::pipe::screen::ResourceType;
|
|
use mesa_rust_gen::*;
|
|
use mesa_rust_util::math::*;
|
|
use mesa_rust_util::serialize::*;
|
|
use rusticl_opencl_gen::*;
|
|
use spirv::SpirvKernelInfo;
|
|
|
|
use std::borrow::Borrow;
|
|
use std::cmp;
|
|
use std::collections::HashMap;
|
|
use std::collections::HashSet;
|
|
use std::convert::TryInto;
|
|
use std::ffi::CStr;
|
|
use std::fmt::Debug;
|
|
use std::fmt::Display;
|
|
use std::ops::Index;
|
|
use std::ops::Not;
|
|
use std::os::raw::c_void;
|
|
use std::ptr;
|
|
use std::slice;
|
|
use std::sync::Arc;
|
|
use std::sync::Mutex;
|
|
use std::sync::MutexGuard;
|
|
use std::sync::Weak;
|
|
|
|
// According to the CL spec we are not allowed to let any cl_kernel object hold any references on
|
|
// its arguments as this might make it unfeasible for applications to free the backing memory of
|
|
// memory objects allocated with `CL_USE_HOST_PTR`.
|
|
//
|
|
// However those arguments might temporarily get referenced by event objects, so we'll use Weak in
|
|
// order to upgrade the reference when needed. It's also safer to use Weak over raw pointers,
|
|
// because it makes it impossible to run into use-after-free issues.
|
|
//
|
|
// Technically we also need to do it for samplers, but there it's kinda pointless to take a weak
|
|
// reference as samplers don't have the same host_ptr or any similar problems as cl_mem objects.
|
|
#[derive(Clone)]
|
|
pub enum KernelArgValue {
|
|
None,
|
|
/// cl_ext_buffer_device_address
|
|
BDA(u64),
|
|
SVM(usize),
|
|
Buffer(Weak<Buffer>),
|
|
Constant(Vec<u8>),
|
|
Image(Weak<Image>),
|
|
LocalMem(usize),
|
|
Sampler(Arc<Sampler>),
|
|
}
|
|
|
|
#[repr(u8)]
|
|
#[derive(Hash, PartialEq, Eq, Clone, Copy)]
|
|
pub enum KernelArgType {
|
|
Constant(/* size */ u16), // for anything passed by value
|
|
Image,
|
|
RWImage,
|
|
Sampler,
|
|
Texture,
|
|
MemGlobal,
|
|
MemConstant,
|
|
MemLocal,
|
|
}
|
|
|
|
impl KernelArgType {
|
|
fn deserialize(blob: &mut blob_reader) -> Option<Self> {
|
|
// SAFETY: we get 0 on an overrun, but we verify that later and act accordingly.
|
|
let res = match unsafe { blob_read_uint8(blob) } {
|
|
0 => {
|
|
// SAFETY: same here
|
|
let size = unsafe { blob_read_uint16(blob) };
|
|
KernelArgType::Constant(size)
|
|
}
|
|
1 => KernelArgType::Image,
|
|
2 => KernelArgType::RWImage,
|
|
3 => KernelArgType::Sampler,
|
|
4 => KernelArgType::Texture,
|
|
5 => KernelArgType::MemGlobal,
|
|
6 => KernelArgType::MemConstant,
|
|
7 => KernelArgType::MemLocal,
|
|
_ => return None,
|
|
};
|
|
|
|
blob.overrun.not().then_some(res)
|
|
}
|
|
|
|
fn serialize(&self, blob: &mut blob) {
|
|
unsafe {
|
|
match self {
|
|
KernelArgType::Constant(size) => {
|
|
blob_write_uint8(blob, 0);
|
|
blob_write_uint16(blob, *size)
|
|
}
|
|
KernelArgType::Image => blob_write_uint8(blob, 1),
|
|
KernelArgType::RWImage => blob_write_uint8(blob, 2),
|
|
KernelArgType::Sampler => blob_write_uint8(blob, 3),
|
|
KernelArgType::Texture => blob_write_uint8(blob, 4),
|
|
KernelArgType::MemGlobal => blob_write_uint8(blob, 5),
|
|
KernelArgType::MemConstant => blob_write_uint8(blob, 6),
|
|
KernelArgType::MemLocal => blob_write_uint8(blob, 7),
|
|
};
|
|
}
|
|
}
|
|
|
|
fn is_opaque(&self) -> bool {
|
|
matches!(
|
|
self,
|
|
KernelArgType::Image
|
|
| KernelArgType::RWImage
|
|
| KernelArgType::Texture
|
|
| KernelArgType::Sampler
|
|
)
|
|
}
|
|
}
|
|
|
|
#[derive(Hash, PartialEq, Eq, Clone)]
|
|
enum CompiledKernelArgType {
|
|
APIArg(usize),
|
|
ConstantBuffer,
|
|
GlobalWorkOffsets,
|
|
GlobalWorkSize,
|
|
PrintfBuffer,
|
|
InlineSampler((cl_addressing_mode, cl_filter_mode, bool)),
|
|
FormatArray,
|
|
OrderArray,
|
|
WorkDim,
|
|
WorkGroupOffsets,
|
|
NumWorkgroups,
|
|
}
|
|
|
|
#[derive(Hash, PartialEq, Eq, Clone)]
|
|
pub struct KernelArg {
|
|
spirv: spirv::SPIRVKernelArg,
|
|
pub kind: KernelArgType,
|
|
pub dead: bool,
|
|
}
|
|
|
|
impl KernelArg {
|
|
fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
|
|
let nir_arg_map: HashMap<_, _> = nir
|
|
.variables_with_mode(
|
|
nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
|
|
)
|
|
.map(|v| (v.data.location, v))
|
|
.collect();
|
|
let mut res = Vec::new();
|
|
|
|
for (i, s) in spirv.iter().enumerate() {
|
|
let nir = nir_arg_map.get(&(i as i32)).unwrap();
|
|
let kind = match s.address_qualifier {
|
|
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
|
|
if unsafe { glsl_type_is_sampler(nir.type_) } {
|
|
KernelArgType::Sampler
|
|
} else {
|
|
let size = unsafe { glsl_get_cl_size(nir.type_) } as u16;
|
|
// nir types of non opaque types are never sized 0
|
|
KernelArgType::Constant(size)
|
|
}
|
|
}
|
|
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
|
|
KernelArgType::MemConstant
|
|
}
|
|
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
|
|
KernelArgType::MemLocal
|
|
}
|
|
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
|
|
if unsafe { glsl_type_is_image(nir.type_) } {
|
|
let access = nir.data.access();
|
|
if access == gl_access_qualifier::ACCESS_NON_WRITEABLE.0 {
|
|
KernelArgType::Texture
|
|
} else if access == gl_access_qualifier::ACCESS_NON_READABLE.0 {
|
|
KernelArgType::Image
|
|
} else {
|
|
KernelArgType::RWImage
|
|
}
|
|
} else {
|
|
KernelArgType::MemGlobal
|
|
}
|
|
}
|
|
};
|
|
|
|
res.push(Self {
|
|
spirv: s.clone(),
|
|
// we'll update it later in the 2nd pass
|
|
kind: kind,
|
|
dead: true,
|
|
});
|
|
}
|
|
res
|
|
}
|
|
|
|
fn serialize(args: &[Self], blob: &mut blob) {
|
|
unsafe {
|
|
blob_write_uint16(blob, args.len() as u16);
|
|
|
|
for arg in args {
|
|
arg.spirv.serialize(blob);
|
|
blob_write_uint8(blob, arg.dead.into());
|
|
arg.kind.serialize(blob);
|
|
}
|
|
}
|
|
}
|
|
|
|
fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
|
|
// SAFETY: we check the overrun status, blob_read returns 0 in such a case.
|
|
let len = unsafe { blob_read_uint16(blob) } as usize;
|
|
let mut res = Vec::with_capacity(len);
|
|
|
|
for _ in 0..len {
|
|
let spirv = spirv::SPIRVKernelArg::deserialize(blob)?;
|
|
// SAFETY: we check the overrun status
|
|
let dead = unsafe { blob_read_uint8(blob) } != 0;
|
|
let kind = KernelArgType::deserialize(blob)?;
|
|
|
|
res.push(Self {
|
|
spirv: spirv,
|
|
kind: kind,
|
|
dead: dead,
|
|
});
|
|
}
|
|
|
|
blob.overrun.not().then_some(res)
|
|
}
|
|
}
|
|
|
|
#[derive(Hash, PartialEq, Eq, Clone)]
|
|
struct CompiledKernelArg {
|
|
kind: CompiledKernelArgType,
|
|
/// The binding for image/sampler args, the offset into the input buffer
|
|
/// for anything else.
|
|
offset: usize,
|
|
dead: bool,
|
|
}
|
|
|
|
impl CompiledKernelArg {
|
|
fn assign_locations(compiled_args: &mut [Self], nir: &mut NirShader) {
|
|
for var in nir.variables_with_mode(
|
|
nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
|
|
) {
|
|
let arg = &mut compiled_args[var.data.location as usize];
|
|
let t = var.type_;
|
|
|
|
arg.dead = false;
|
|
arg.offset = if unsafe {
|
|
glsl_type_is_image(t) || glsl_type_is_texture(t) || glsl_type_is_sampler(t)
|
|
} {
|
|
var.data.binding
|
|
} else {
|
|
var.data.driver_location
|
|
} as usize;
|
|
}
|
|
}
|
|
|
|
fn serialize(args: &[Self], blob: &mut blob) {
|
|
unsafe {
|
|
blob_write_uint16(blob, args.len() as u16);
|
|
for arg in args {
|
|
blob_write_uint32(blob, arg.offset as u32);
|
|
blob_write_uint8(blob, arg.dead.into());
|
|
match arg.kind {
|
|
CompiledKernelArgType::ConstantBuffer => blob_write_uint8(blob, 0),
|
|
CompiledKernelArgType::GlobalWorkOffsets => blob_write_uint8(blob, 1),
|
|
CompiledKernelArgType::PrintfBuffer => blob_write_uint8(blob, 2),
|
|
CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm)) => {
|
|
blob_write_uint8(blob, 3);
|
|
blob_write_uint8(blob, norm.into());
|
|
blob_write_uint32(blob, addr_mode);
|
|
blob_write_uint32(blob, filter_mode)
|
|
}
|
|
CompiledKernelArgType::FormatArray => blob_write_uint8(blob, 4),
|
|
CompiledKernelArgType::OrderArray => blob_write_uint8(blob, 5),
|
|
CompiledKernelArgType::WorkDim => blob_write_uint8(blob, 6),
|
|
CompiledKernelArgType::WorkGroupOffsets => blob_write_uint8(blob, 7),
|
|
CompiledKernelArgType::NumWorkgroups => blob_write_uint8(blob, 8),
|
|
CompiledKernelArgType::GlobalWorkSize => blob_write_uint8(blob, 9),
|
|
CompiledKernelArgType::APIArg(idx) => {
|
|
blob_write_uint8(blob, 10);
|
|
blob_write_uint32(blob, idx as u32)
|
|
}
|
|
};
|
|
}
|
|
}
|
|
}
|
|
|
|
fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
|
|
unsafe {
|
|
let len = blob_read_uint16(blob) as usize;
|
|
let mut res = Vec::with_capacity(len);
|
|
|
|
for _ in 0..len {
|
|
let offset = blob_read_uint32(blob) as usize;
|
|
let dead = blob_read_uint8(blob) != 0;
|
|
|
|
let kind = match blob_read_uint8(blob) {
|
|
0 => CompiledKernelArgType::ConstantBuffer,
|
|
1 => CompiledKernelArgType::GlobalWorkOffsets,
|
|
2 => CompiledKernelArgType::PrintfBuffer,
|
|
3 => {
|
|
let norm = blob_read_uint8(blob) != 0;
|
|
let addr_mode = blob_read_uint32(blob);
|
|
let filter_mode = blob_read_uint32(blob);
|
|
CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm))
|
|
}
|
|
4 => CompiledKernelArgType::FormatArray,
|
|
5 => CompiledKernelArgType::OrderArray,
|
|
6 => CompiledKernelArgType::WorkDim,
|
|
7 => CompiledKernelArgType::WorkGroupOffsets,
|
|
8 => CompiledKernelArgType::NumWorkgroups,
|
|
9 => CompiledKernelArgType::GlobalWorkSize,
|
|
10 => {
|
|
let idx = blob_read_uint32(blob) as usize;
|
|
CompiledKernelArgType::APIArg(idx)
|
|
}
|
|
_ => return None,
|
|
};
|
|
|
|
res.push(Self {
|
|
kind: kind,
|
|
offset: offset,
|
|
dead: dead,
|
|
});
|
|
}
|
|
|
|
Some(res)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, PartialEq, Eq, Hash)]
|
|
pub struct KernelInfo {
|
|
pub args: Vec<KernelArg>,
|
|
pub attributes_string: String,
|
|
work_group_size: [usize; 3],
|
|
work_group_size_hint: [u32; 3],
|
|
subgroup_size: usize,
|
|
num_subgroups: usize,
|
|
}
|
|
|
|
/// Wraps around a compute state object which is safe to share between pipe_contexts.
|
|
pub struct SharedCSOWrapper {
|
|
cso_ptr: *mut c_void,
|
|
dev: &'static Device,
|
|
}
|
|
|
|
impl SharedCSOWrapper {
|
|
/// # Safety
|
|
///
|
|
/// The returned value is only safe to be executed on a pipe_context when the device supports
|
|
/// shareable shaders.
|
|
unsafe fn new(dev: &'static Device, nir: &NirShader) -> Self {
|
|
let cso_ptr = dev
|
|
.helper_ctx()
|
|
.create_compute_state(nir, nir.shared_size());
|
|
|
|
Self {
|
|
cso_ptr: cso_ptr,
|
|
dev: dev,
|
|
}
|
|
}
|
|
|
|
/// # Safety
|
|
///
|
|
/// `self` needs to live until another CSOWrapper is bound to `ctx`
|
|
pub unsafe fn bind_to_ctx(&self, ctx: &PipeContext) {
|
|
// SAFETY: We make it the callers responsibility to uphold the safety requirements.
|
|
unsafe {
|
|
ctx.bind_compute_state(self.cso_ptr);
|
|
}
|
|
}
|
|
|
|
fn get_cso_info(&self) -> pipe_compute_state_object_info {
|
|
self.dev.helper_ctx().compute_state_info(self.cso_ptr)
|
|
}
|
|
}
|
|
|
|
impl Drop for SharedCSOWrapper {
|
|
fn drop(&mut self) {
|
|
self.dev.helper_ctx().delete_compute_state(self.cso_ptr);
|
|
}
|
|
}
|
|
|
|
pub enum KernelDevStateVariant {
|
|
Cso(SharedCSOWrapper),
|
|
Nir(NirShader),
|
|
}
|
|
|
|
/// Selects which NIR build of a kernel gets used.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum NirKernelVariant {
    /// Usable under any circumstance.
    Default,

    /// Optimized variant making the following assumptions:
    ///  - global_id_offsets are 0
    ///  - workgroup_offsets are 0
    ///  - local_size is info.local_size_hint
    Optimized,
}

impl Display for NirKernelVariant {
    /// Prints the bare variant name, e.g. `Optimized`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The derived Debug output is exactly the variant name, which is what we want to show.
        <Self as Debug>::fmt(self, f)
    }
}
|
|
|
|
pub struct NirKernelBuilds {
|
|
default_build: NirKernelBuild,
|
|
optimized: Option<NirKernelBuild>,
|
|
/// merged info with worst case values
|
|
info: pipe_compute_state_object_info,
|
|
}
|
|
|
|
impl Index<NirKernelVariant> for NirKernelBuilds {
|
|
type Output = NirKernelBuild;
|
|
|
|
fn index(&self, index: NirKernelVariant) -> &Self::Output {
|
|
match index {
|
|
NirKernelVariant::Default => &self.default_build,
|
|
NirKernelVariant::Optimized => self.optimized.as_ref().unwrap_or(&self.default_build),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl NirKernelBuilds {
|
|
fn new(default_build: NirKernelBuild, optimized: Option<NirKernelBuild>) -> Self {
|
|
let mut info = default_build.info;
|
|
if let Some(build) = &optimized {
|
|
info.max_threads = cmp::min(info.max_threads, build.info.max_threads);
|
|
info.simd_sizes &= build.info.simd_sizes;
|
|
info.private_memory = cmp::max(info.private_memory, build.info.private_memory);
|
|
info.preferred_simd_size =
|
|
cmp::max(info.preferred_simd_size, build.info.preferred_simd_size);
|
|
}
|
|
|
|
Self {
|
|
default_build: default_build,
|
|
optimized: optimized,
|
|
info: info,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct NirKernelBuild {
|
|
nir_or_cso: KernelDevStateVariant,
|
|
constant_buffer: Option<PipeResourceOwned>,
|
|
info: pipe_compute_state_object_info,
|
|
shared_size: u64,
|
|
printf_info: Option<NirPrintfInfo>,
|
|
compiled_args: Vec<CompiledKernelArg>,
|
|
}
|
|
|
|
// SAFETY: `CSOWrapper` is only safe to use if the device supports `pipe_caps.shareable_shaders` and
|
|
// we make sure to set `nir_or_cso` to `KernelDevStateVariant::Cso` only if that's the case.
|
|
unsafe impl Send for NirKernelBuild {}
|
|
unsafe impl Sync for NirKernelBuild {}
|
|
|
|
impl NirKernelBuild {
|
|
fn new(dev: &'static Device, mut out: CompilationResult) -> Self {
|
|
// SAFETY: we only use the cso when dev supports shareable shaders, otherwise we just
|
|
// extract some info and throw it away, which is safe.
|
|
let cso = unsafe { SharedCSOWrapper::new(dev, &out.nir) };
|
|
let info = cso.get_cso_info();
|
|
let cb = Self::create_nir_constant_buffer(dev, &out.nir);
|
|
let shared_size = out.nir.shared_size() as u64;
|
|
let printf_info = out.nir.take_printf_info();
|
|
|
|
let nir_or_cso = if !dev.shareable_shaders() {
|
|
KernelDevStateVariant::Nir(out.nir)
|
|
} else {
|
|
KernelDevStateVariant::Cso(cso)
|
|
};
|
|
|
|
NirKernelBuild {
|
|
nir_or_cso: nir_or_cso,
|
|
constant_buffer: cb,
|
|
info: info,
|
|
shared_size: shared_size,
|
|
printf_info: printf_info,
|
|
compiled_args: out.compiled_args,
|
|
}
|
|
}
|
|
|
|
fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option<PipeResourceOwned> {
|
|
let buf = nir.get_constant_buffer();
|
|
let len = buf.len() as u32;
|
|
|
|
if len > 0 {
|
|
// TODO bind as constant buffer
|
|
let res = dev
|
|
.screen()
|
|
.resource_create_buffer(len, ResourceType::Immutable, PIPE_BIND_GLOBAL, 0)
|
|
.unwrap();
|
|
|
|
dev.helper_ctx()
|
|
.exec(|ctx| ctx.buffer_subdata(&res, 0, buf.as_ptr().cast(), len))
|
|
.wait();
|
|
|
|
Some(res)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
pub fn nir_or_cso(&self) -> &KernelDevStateVariant {
|
|
&self.nir_or_cso
|
|
}
|
|
}
|
|
|
|
pub struct Kernel {
|
|
pub base: CLObjectBase<CL_INVALID_KERNEL>,
|
|
pub prog: Arc<Program>,
|
|
pub name: String,
|
|
values: Mutex<Vec<Option<KernelArgValue>>>,
|
|
pub bdas: Mutex<Vec<cl_mem_device_address_ext>>,
|
|
pub svms: Mutex<HashSet<usize>>,
|
|
builds: HashMap<&'static Device, Arc<NirKernelBuilds>>,
|
|
pub kernel_info: Arc<KernelInfo>,
|
|
}
|
|
|
|
impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL);
|
|
|
|
fn create_kernel_arr<T>(vals: &[usize], val: T) -> CLResult<[T; 3]>
|
|
where
|
|
T: std::convert::TryFrom<usize> + Copy,
|
|
{
|
|
let mut res = [val; 3];
|
|
for (i, v) in vals.iter().enumerate() {
|
|
res[i] = (*v).try_into().or(Err(CL_OUT_OF_RESOURCES))?;
|
|
}
|
|
|
|
Ok(res)
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct CompilationResult {
|
|
nir: NirShader,
|
|
compiled_args: Vec<CompiledKernelArg>,
|
|
}
|
|
|
|
impl CompilationResult {
|
|
fn deserialize(reader: &mut blob_reader, d: &Device) -> Option<Self> {
|
|
let nir = NirShader::deserialize(
|
|
reader,
|
|
d.screen()
|
|
.nir_shader_compiler_options(mesa_shader_stage::MESA_SHADER_COMPUTE),
|
|
)?;
|
|
let compiled_args = CompiledKernelArg::deserialize(reader)?;
|
|
|
|
Some(Self {
|
|
nir: nir,
|
|
compiled_args,
|
|
})
|
|
}
|
|
|
|
fn serialize(&self, blob: &mut blob) {
|
|
self.nir.serialize(blob);
|
|
CompiledKernelArg::serialize(&self.compiled_args, blob);
|
|
}
|
|
}
|
|
|
|
fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
|
|
let nir_options = unsafe {
|
|
&*dev
|
|
.screen
|
|
.nir_shader_compiler_options(mesa_shader_stage::MESA_SHADER_COMPUTE)
|
|
};
|
|
|
|
while {
|
|
let mut progress = false;
|
|
|
|
progress |= nir_pass!(nir, nir_copy_prop);
|
|
progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
|
|
progress |= nir_pass!(nir, nir_opt_dead_write_vars);
|
|
|
|
if nir_options.lower_to_scalar {
|
|
nir_pass!(
|
|
nir,
|
|
nir_lower_alu_to_scalar,
|
|
nir_options.lower_to_scalar_filter,
|
|
ptr::null(),
|
|
);
|
|
nir_pass!(nir, nir_lower_phis_to_scalar, None, ptr::null());
|
|
}
|
|
|
|
progress |= nir_pass!(nir, nir_opt_deref);
|
|
if has_explicit_types {
|
|
progress |= nir_pass!(nir, nir_opt_memcpy);
|
|
}
|
|
progress |= nir_pass!(nir, nir_opt_dce);
|
|
progress |= nir_pass!(nir, nir_opt_undef);
|
|
progress |= nir_pass!(nir, nir_opt_constant_folding);
|
|
progress |= nir_pass!(nir, nir_opt_cse);
|
|
nir_pass!(nir, nir_split_var_copies);
|
|
progress |= nir_pass!(nir, nir_lower_var_copies);
|
|
progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
|
|
nir_pass!(nir, nir_lower_alu);
|
|
progress |= nir_pass!(nir, nir_opt_phi_precision);
|
|
progress |= nir_pass!(nir, nir_opt_algebraic);
|
|
progress |= nir_pass!(nir, nir_opt_algebraic_integer_promotion);
|
|
progress |= nir_pass!(
|
|
nir,
|
|
nir_opt_if,
|
|
nir_opt_if_options::nir_opt_if_optimize_phi_true_false,
|
|
);
|
|
progress |= nir_pass!(nir, nir_opt_dead_cf);
|
|
progress |= nir_pass!(nir, nir_opt_remove_phis);
|
|
// we don't want to be too aggressive here, but it kills a bit of CFG
|
|
let peephole_select_options = nir_opt_peephole_select_options {
|
|
limit: 8,
|
|
indirect_load_ok: true,
|
|
expensive_alu_ok: true,
|
|
..Default::default()
|
|
};
|
|
progress |= nir_pass!(nir, nir_opt_peephole_select, &peephole_select_options);
|
|
progress |= nir_pass!(
|
|
nir,
|
|
nir_lower_vec3_to_vec4,
|
|
nir_variable_mode::nir_var_mem_generic | nir_variable_mode::nir_var_uniform,
|
|
);
|
|
|
|
if nir_options.max_unroll_iterations != 0 {
|
|
progress |= nir_pass!(nir, nir_opt_loop_unroll);
|
|
}
|
|
nir.sweep_mem();
|
|
progress
|
|
} {}
|
|
}
|
|
|
|
/// # Safety
|
|
///
|
|
/// Only safe to call when `var` is a valid pointer to a valid [`nir_variable`]
|
|
unsafe extern "C" fn can_remove_var(var: *mut nir_variable, _: *mut c_void) -> bool {
|
|
// SAFETY: It is the caller's responsibility to provide a valid and aligned pointer
|
|
let var_type = unsafe { (*var).type_ };
|
|
// SAFETY: `nir_variable`'s type invariant guarantees that the `type_` field is valid and
|
|
// properly aligned.
|
|
unsafe {
|
|
!glsl_type_is_image(var_type)
|
|
&& !glsl_type_is_texture(var_type)
|
|
&& !glsl_type_is_sampler(var_type)
|
|
}
|
|
}
|
|
|
|
fn compile_nir_to_args(
|
|
dev: &Device,
|
|
mut nir: NirShader,
|
|
args: &[spirv::SPIRVKernelArg],
|
|
lib_clc: &NirShader,
|
|
) -> (Vec<KernelArg>, NirShader) {
|
|
// this is a hack until we support fp16 properly and check for denorms inside vstore/vload_half
|
|
nir.preserve_fp16_denorms();
|
|
|
|
// Set to rtne for now until drivers are able to report their preferred rounding mode, that also
|
|
// matches what we report via the API.
|
|
nir.set_fp_rounding_mode_rtne();
|
|
|
|
nir_pass!(nir, nir_scale_fdiv);
|
|
nir.set_workgroup_size_variable_if_zero();
|
|
nir.structurize();
|
|
nir_pass!(
|
|
nir,
|
|
nir_lower_variable_initializers,
|
|
nir_variable_mode::nir_var_function_temp
|
|
);
|
|
|
|
while {
|
|
let mut progress = false;
|
|
nir_pass!(nir, nir_split_var_copies);
|
|
progress |= nir_pass!(nir, nir_copy_prop);
|
|
progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
|
|
progress |= nir_pass!(nir, nir_opt_dead_write_vars);
|
|
progress |= nir_pass!(nir, nir_opt_deref);
|
|
progress |= nir_pass!(nir, nir_opt_dce);
|
|
progress |= nir_pass!(nir, nir_opt_undef);
|
|
progress |= nir_pass!(nir, nir_opt_constant_folding);
|
|
progress |= nir_pass!(nir, nir_opt_cse);
|
|
progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
|
|
progress |= nir_pass!(nir, nir_opt_algebraic);
|
|
progress
|
|
} {}
|
|
nir.inline(lib_clc);
|
|
nir.cleanup_functions();
|
|
// that should free up tons of memory
|
|
nir.sweep_mem();
|
|
|
|
nir_pass!(nir, nir_dedup_inline_samplers);
|
|
|
|
let printf_opts = nir_lower_printf_options {
|
|
max_buffer_size: dev.printf_buffer_size() as u32,
|
|
..Default::default()
|
|
};
|
|
nir_pass!(nir, nir_lower_printf, &printf_opts);
|
|
|
|
opt_nir(&mut nir, dev, false);
|
|
|
|
(KernelArg::from_spirv_nir(args, &mut nir), nir)
|
|
}
|
|
|
|
fn compile_nir_prepare_for_variants(
|
|
dev: &Device,
|
|
nir: &mut NirShader,
|
|
compiled_args: &mut Vec<CompiledKernelArg>,
|
|
) {
|
|
// assign locations for inline samplers.
|
|
// IMPORTANT: this needs to happen before nir_remove_dead_variables.
|
|
let mut last_loc = -1;
|
|
for v in nir
|
|
.variables_with_mode(nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image)
|
|
{
|
|
if unsafe { !glsl_type_is_sampler(v.type_) } {
|
|
last_loc = v.data.location;
|
|
continue;
|
|
}
|
|
let s = unsafe { v.data.anon_1.sampler };
|
|
if s.is_inline_sampler() != 0 {
|
|
last_loc += 1;
|
|
v.data.location = last_loc;
|
|
|
|
compiled_args.push(CompiledKernelArg {
|
|
kind: CompiledKernelArgType::InlineSampler(Sampler::nir_to_cl(
|
|
s.addressing_mode(),
|
|
s.filter_mode(),
|
|
s.normalized_coordinates(),
|
|
)),
|
|
offset: 0,
|
|
dead: true,
|
|
});
|
|
} else {
|
|
last_loc = v.data.location;
|
|
}
|
|
}
|
|
|
|
let dv_opts = nir_remove_dead_variables_options {
|
|
can_remove_var: Some(can_remove_var),
|
|
..Default::default()
|
|
};
|
|
|
|
nir_pass!(
|
|
nir,
|
|
nir_remove_dead_variables,
|
|
nir_variable_mode::nir_var_uniform
|
|
| nir_variable_mode::nir_var_image
|
|
| nir_variable_mode::nir_var_mem_constant
|
|
| nir_variable_mode::nir_var_mem_shared
|
|
| nir_variable_mode::nir_var_function_temp,
|
|
&dv_opts,
|
|
);
|
|
|
|
nir_pass!(nir, nir_lower_readonly_images_to_tex, true);
|
|
nir_pass!(
|
|
nir,
|
|
nir_lower_cl_images,
|
|
!dev.images_as_deref(),
|
|
!dev.samplers_as_deref(),
|
|
);
|
|
|
|
nir_pass!(
|
|
nir,
|
|
nir_lower_vars_to_explicit_types,
|
|
nir_variable_mode::nir_var_mem_constant,
|
|
Some(glsl_get_cl_type_size_align),
|
|
);
|
|
|
|
// has to run before adding internal kernel arguments
|
|
nir.extract_constant_initializers();
|
|
|
|
// needed to convert variables to load intrinsics
|
|
nir_pass!(nir, nir_lower_system_values);
|
|
|
|
// Run here so we can decide if it makes sense to compile a variant, e.g. read system values.
|
|
nir.gather_info();
|
|
}
|
|
|
|
/// Compiles one variant of a kernel in place. In order, it:
/// - adds the internal arguments the variant needs, appending them to `res.compiled_args`;
/// - lowers memory and IO to explicit addressing;
/// - runs the optimization loop;
/// - records where every compiled argument ended up;
/// - hands the NIR to the driver's `finalize_nir`.
///
/// * `variant` - with [`NirKernelVariant::Optimized`], no base global invocation id / base
///   work-group id system values are requested, and a non-zero work-group size hint becomes
///   the fixed work-group size.
/// * `args` - the API arguments, indexed by [`CompiledKernelArgType::APIArg`].
/// * `name` - the kernel name, only used for debug output.
fn compile_nir_variant(
    res: &mut CompilationResult,
    dev: &Device,
    variant: NirKernelVariant,
    args: &[KernelArg],
    name: &str,
) {
    let mut lower_state = rusticl_lower_state::default();
    let compiled_args = &mut res.compiled_args;
    let nir = &mut res.nir;

    // Pointer-sized types and the address formats depend on the device's address width.
    let address_bits_ptr_type;
    let address_bits_base_type;
    let global_address_format;
    let shared_address_format;

    if dev.address_bits() == 64 {
        address_bits_ptr_type = unsafe { glsl_uint64_t_type() };
        address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64;
        global_address_format = nir_address_format::nir_address_format_64bit_global;
        shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
    } else {
        address_bits_ptr_type = unsafe { glsl_uint_type() };
        address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT;
        global_address_format = nir_address_format::nir_address_format_32bit_global;
        shared_address_format = nir_address_format::nir_address_format_32bit_offset;
    }

    let nir_options = unsafe {
        &*dev
            .screen
            .nir_shader_compiler_options(mesa_shader_stage::MESA_SHADER_COMPUTE)
    };

    // The optimized variant bakes a non-zero work-group size hint in as the work-group size.
    if variant == NirKernelVariant::Optimized {
        let wgsh = nir.workgroup_size_hint();
        if wgsh != [0; 3] {
            nir.set_workgroup_size(wgsh);
        }
    }

    // Only the default variant requests the base global invocation id and base work-group id;
    // the optimized variant assumes both offsets are 0.
    let mut compute_options = nir_lower_compute_system_values_options::default();
    compute_options.set_has_global_size(true);
    if variant != NirKernelVariant::Optimized {
        compute_options.set_has_base_global_invocation_id(true);
        compute_options.set_has_base_workgroup_id(true);
    }
    nir_pass!(nir, nir_lower_compute_system_values, &compute_options);
    nir.gather_info();

    // Appends an internal argument to `compiled_args` and stores its index in `var_loc` (a
    // `lower_state` field). It then adds a matching uniform variable to the NIR, passing that
    // index to `add_var` (NOTE(review): presumably as the variable's location).
    let mut add_var = |nir: &mut NirShader,
                       var_loc: &mut usize,
                       kind: CompiledKernelArgType,
                       glsl_type: *const glsl_type,
                       name| {
        *var_loc = compiled_args.len();
        compiled_args.push(CompiledKernelArg {
            kind: kind,
            offset: 0,
            dead: true,
        });
        nir.add_var(
            nir_variable_mode::nir_var_uniform,
            glsl_type,
            *var_loc,
            name,
        );
    };

    // Internal arguments, added only for the system values / features the kernel uses.
    if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_GLOBAL_INVOCATION_ID) {
        debug_assert_ne!(variant, NirKernelVariant::Optimized);
        add_var(
            nir,
            &mut lower_state.base_global_invoc_id_loc,
            CompiledKernelArgType::GlobalWorkOffsets,
            unsafe { glsl_vector_type(address_bits_base_type, 3) },
            c"base_global_invocation_id",
        )
    }

    if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_GROUP_SIZE) {
        add_var(
            nir,
            &mut lower_state.global_size_loc,
            CompiledKernelArgType::GlobalWorkSize,
            unsafe { glsl_vector_type(address_bits_base_type, 3) },
            c"global_size",
        )
    }

    if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_WORKGROUP_ID) {
        debug_assert_ne!(variant, NirKernelVariant::Optimized);
        add_var(
            nir,
            &mut lower_state.base_workgroup_id_loc,
            CompiledKernelArgType::WorkGroupOffsets,
            unsafe { glsl_vector_type(address_bits_base_type, 3) },
            c"base_workgroup_id",
        );
    }

    if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_NUM_WORKGROUPS) {
        add_var(
            nir,
            &mut lower_state.num_workgroups_loc,
            CompiledKernelArgType::NumWorkgroups,
            unsafe { glsl_vector_type(glsl_base_type::GLSL_TYPE_UINT, 3) },
            c"num_workgroups",
        );
    }

    if nir.has_constant() {
        add_var(
            nir,
            &mut lower_state.const_buf_loc,
            CompiledKernelArgType::ConstantBuffer,
            address_bits_ptr_type,
            c"constant_buffer_addr",
        );
    }
    if nir.has_printf() {
        add_var(
            nir,
            &mut lower_state.printf_buf_loc,
            CompiledKernelArgType::PrintfBuffer,
            address_bits_ptr_type,
            c"printf_buffer_addr",
        );
    }

    // One 16 bit entry per image and texture in both the format and the order array.
    if nir.num_images() > 0 || nir.num_textures() > 0 {
        let count = nir.num_images() + nir.num_textures();

        add_var(
            nir,
            &mut lower_state.format_arr_loc,
            CompiledKernelArgType::FormatArray,
            unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
            c"image_formats",
        );

        add_var(
            nir,
            &mut lower_state.order_arr_loc,
            CompiledKernelArgType::OrderArray,
            unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
            c"image_orders",
        );
    }

    if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_WORK_DIM) {
        add_var(
            nir,
            &mut lower_state.work_dim_loc,
            CompiledKernelArgType::WorkDim,
            unsafe { glsl_uint8_t_type() },
            c"work_dim",
        );
    }

    // need to run after first opt loop and remove_dead_variables to get rid of unnecessary scratch
    // memory
    nir_pass!(
        nir,
        nir_lower_vars_to_explicit_types,
        nir_variable_mode::nir_var_mem_shared
            | nir_variable_mode::nir_var_function_temp
            | nir_variable_mode::nir_var_shader_temp
            | nir_variable_mode::nir_var_uniform
            | nir_variable_mode::nir_var_mem_global
            | nir_variable_mode::nir_var_mem_generic,
        Some(glsl_get_cl_type_size_align),
    );

    opt_nir(nir, dev, true);
    nir_pass!(nir, nir_lower_memcpy);

    let dv_opts = nir_remove_dead_variables_options {
        can_remove_var: Some(can_remove_var),
        ..Default::default()
    };

    // we might have got rid of more function_temp or shared memory
    nir.reset_scratch_size();
    nir.reset_shared_size();
    nir_pass!(
        nir,
        nir_remove_dead_variables,
        nir_variable_mode::nir_var_function_temp | nir_variable_mode::nir_var_mem_shared,
        &dv_opts,
    );
    nir_pass!(
        nir,
        nir_lower_vars_to_explicit_types,
        nir_variable_mode::nir_var_function_temp
            | nir_variable_mode::nir_var_mem_shared
            | nir_variable_mode::nir_var_mem_generic,
        Some(glsl_get_cl_type_size_align),
    );

    nir_pass!(
        nir,
        nir_lower_explicit_io,
        nir_variable_mode::nir_var_mem_global | nir_variable_mode::nir_var_mem_constant,
        global_address_format,
    );

    // Consumes the internal argument locations recorded in `lower_state` above.
    nir_pass!(nir, rusticl_lower_intrinsics, &mut lower_state);
    nir_pass!(
        nir,
        nir_lower_explicit_io,
        nir_variable_mode::nir_var_mem_shared
            | nir_variable_mode::nir_var_function_temp
            | nir_variable_mode::nir_var_uniform,
        shared_address_format,
    );

    // Lower int64 here unless the driver has no int64 lowering options or lowers it late itself.
    if nir_options.lower_int64_options.0 != 0 && !nir_options.late_lower_int64 {
        nir_pass!(nir, nir_lower_int64);
    }

    if nir_options.lower_uniforms_to_ubo {
        nir_pass!(nir, rusticl_lower_inputs);
    }

    nir_pass!(nir, nir_lower_convert_alu_types, None);

    opt_nir(nir, dev, true);

    /* before passing it into drivers, assign locations as drivers might remove nir_variables or
     * other things we depend on
     */
    CompiledKernelArg::assign_locations(compiled_args, nir);

    /* update the has_variable_shared_mem info as we might have DCEed all of them */
    nir.set_has_variable_shared_mem(compiled_args.iter().any(|arg| {
        if let CompiledKernelArgType::APIArg(idx) = arg.kind {
            args[idx].kind == KernelArgType::MemLocal && !arg.dead
        } else {
            false
        }
    }));

    if Platform::dbg().nir {
        eprintln!("=== Printing nir variant '{variant}' for '{name}' before driver finalization");
        nir.print();
    }

    // Only print again if finalize_nir returned true.
    if dev.screen.finalize_nir(nir) {
        if Platform::dbg().nir {
            eprintln!(
                "=== Printing nir variant '{variant}' for '{name}' after driver finalization"
            );
            nir.print();
        }
    }

    nir_pass!(nir, nir_opt_dce);
    nir.sweep_mem();
}
|
|
|
|
fn compile_nir_remaining(
|
|
dev: &Device,
|
|
mut nir: NirShader,
|
|
args: &[KernelArg],
|
|
name: &str,
|
|
) -> (CompilationResult, Option<CompilationResult>) {
|
|
// add all API kernel args
|
|
let mut compiled_args: Vec<_> = (0..args.len())
|
|
.map(|idx| CompiledKernelArg {
|
|
kind: CompiledKernelArgType::APIArg(idx),
|
|
offset: 0,
|
|
dead: true,
|
|
})
|
|
.collect();
|
|
|
|
compile_nir_prepare_for_variants(dev, &mut nir, &mut compiled_args);
|
|
if Platform::dbg().nir {
|
|
eprintln!("=== Printing nir for '{name}' before specialization");
|
|
nir.print();
|
|
}
|
|
|
|
let mut default_build = CompilationResult {
|
|
nir: nir,
|
|
compiled_args: compiled_args,
|
|
};
|
|
|
|
// check if we even want to compile a variant before cloning the compilation state
|
|
let has_wgs_hint = default_build.nir.workgroup_size_variable()
|
|
&& default_build.nir.workgroup_size_hint() != [0; 3];
|
|
let has_offsets = default_build
|
|
.nir
|
|
.reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_INVOCATION_ID);
|
|
|
|
let mut optimized = (!Platform::dbg().no_variants && (has_offsets || has_wgs_hint))
|
|
.then(|| default_build.clone());
|
|
|
|
compile_nir_variant(
|
|
&mut default_build,
|
|
dev,
|
|
NirKernelVariant::Default,
|
|
args,
|
|
name,
|
|
);
|
|
if let Some(optimized) = &mut optimized {
|
|
compile_nir_variant(optimized, dev, NirKernelVariant::Optimized, args, name);
|
|
}
|
|
|
|
(default_build, optimized)
|
|
}
|
|
|
|
/// Result of translating one kernel from SPIR-V into NIR and compiling it for a device.
pub struct SPIRVToNirResult {
    /// Kernel metadata collected during translation: the argument list, the attribute string,
    /// the work group size and size hint, and the subgroup size and count.
    pub kernel_info: KernelInfo,
    /// The default build and, if one was compiled, the optimized variant.
    pub nir_kernel_builds: NirKernelBuilds,
}
|
|
|
|
impl SPIRVToNirResult {
|
|
fn new(
|
|
dev: &'static Device,
|
|
kernel_info: &clc_kernel_info,
|
|
args: Vec<KernelArg>,
|
|
default_build: CompilationResult,
|
|
optimized: Option<CompilationResult>,
|
|
) -> Self {
|
|
// TODO: we _should_ be able to parse them out of the SPIR-V, but clc doesn't handle
|
|
// indirections yet.
|
|
let nir = &default_build.nir;
|
|
let wgs = nir.workgroup_size();
|
|
let subgroup_size = nir.subgroup_size();
|
|
let num_subgroups = nir.num_subgroups();
|
|
|
|
let default_build = NirKernelBuild::new(dev, default_build);
|
|
let optimized = optimized.map(|opt| NirKernelBuild::new(dev, opt));
|
|
|
|
let kernel_info = KernelInfo {
|
|
args: args,
|
|
attributes_string: kernel_info.attribute_str(),
|
|
work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize],
|
|
work_group_size_hint: kernel_info.local_size_hint,
|
|
subgroup_size: subgroup_size as usize,
|
|
num_subgroups: num_subgroups as usize,
|
|
};
|
|
|
|
Self {
|
|
kernel_info: kernel_info,
|
|
nir_kernel_builds: NirKernelBuilds::new(default_build, optimized),
|
|
}
|
|
}
|
|
|
|
fn deserialize(bin: &[u8], d: &'static Device, kernel_info: &clc_kernel_info) -> Option<Self> {
|
|
let mut reader = blob_reader::default();
|
|
unsafe {
|
|
blob_reader_init(&mut reader, bin.as_ptr().cast(), bin.len());
|
|
}
|
|
|
|
let args = KernelArg::deserialize(&mut reader)?;
|
|
let default_build = CompilationResult::deserialize(&mut reader, d)?;
|
|
|
|
// SAFETY: on overrun this returns 0
|
|
let optimized = match unsafe { blob_read_uint8(&mut reader) } {
|
|
0 => None,
|
|
_ => Some(CompilationResult::deserialize(&mut reader, d)?),
|
|
};
|
|
|
|
reader
|
|
.overrun
|
|
.not()
|
|
.then(|| SPIRVToNirResult::new(d, kernel_info, args, default_build, optimized))
|
|
}
|
|
|
|
// we can't use Self here as the nir shader might be compiled to a cso already and we can't
|
|
// cache that.
|
|
fn serialize(
|
|
blob: &mut blob,
|
|
args: &[KernelArg],
|
|
default_build: &CompilationResult,
|
|
optimized: &Option<CompilationResult>,
|
|
) {
|
|
KernelArg::serialize(args, blob);
|
|
default_build.serialize(blob);
|
|
match optimized {
|
|
Some(variant) => {
|
|
unsafe { blob_write_uint8(blob, 1) };
|
|
variant.serialize(blob);
|
|
}
|
|
None => unsafe {
|
|
blob_write_uint8(blob, 0);
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
pub(super) fn convert_spirv_to_nir(
|
|
build: &DeviceProgramBuild,
|
|
name: &str,
|
|
args: &[spirv::SPIRVKernelArg],
|
|
spec_constants: &HashMap<u32, nir_const_value>,
|
|
dev: &'static Device,
|
|
) -> SPIRVToNirResult {
|
|
let cache = dev.screen().shader_cache();
|
|
let key = build.hash_key(cache.as_ref(), name, spec_constants);
|
|
let spirv_info = build.kernel_info(name).unwrap();
|
|
|
|
cache
|
|
.as_ref()
|
|
.and_then(|cache| cache.get(&mut key?))
|
|
.and_then(|entry| SPIRVToNirResult::deserialize(&entry, dev, spirv_info))
|
|
.unwrap_or_else(|| {
|
|
let nir = build.to_nir(name, dev, spec_constants);
|
|
|
|
if Platform::dbg().nir {
|
|
eprintln!("=== Printing nir for '{name}' after spirv_to_nir");
|
|
nir.print();
|
|
}
|
|
|
|
let (mut args, nir) = compile_nir_to_args(dev, nir, args, &dev.lib_clc);
|
|
let (default_build, optimized) = compile_nir_remaining(dev, nir, &args, name);
|
|
|
|
for build in [Some(&default_build), optimized.as_ref()].into_iter() {
|
|
let Some(build) = build else {
|
|
continue;
|
|
};
|
|
|
|
for arg in &build.compiled_args {
|
|
if let CompiledKernelArgType::APIArg(idx) = arg.kind {
|
|
args[idx].dead &= arg.dead;
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Some(cache) = cache {
|
|
let mut blob = blob::default();
|
|
unsafe {
|
|
blob_init(&mut blob);
|
|
SPIRVToNirResult::serialize(&mut blob, &args, &default_build, &optimized);
|
|
let bin = slice::from_raw_parts(blob.data, blob.size);
|
|
cache.put(bin, &mut key.unwrap());
|
|
blob_finish(&mut blob);
|
|
}
|
|
}
|
|
|
|
SPIRVToNirResult::new(dev, spirv_info, args, default_build, optimized)
|
|
})
|
|
}
|
|
|
|
/// Splits the first `S` bytes off the front of `buf` and returns them as a fixed-size array.
///
/// `buf` is advanced past the returned bytes.
///
/// # Panics
///
/// Panics if `buf` holds fewer than `S` bytes.
fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
    // copy the inner slice out so the returned chunk borrows the underlying data, not `buf`
    let data: &[u8] = *buf;
    let (head, tail) = data
        .split_first_chunk::<S>()
        .expect("extract: buffer shorter than the requested chunk");
    *buf = tail;
    head
}
|
|
|
|
/// Helper class to build an execution environment for a single kernel invocation.
///
/// Collects the raw kernel input bytes and remembers where global resources were referenced
/// inside them.
struct KernelExecBuilder<'a> {
    /// Device whose `address_bits()` decides whether pointers are written as 32 or 64 bits.
    dev: &'static Device,
    /// Kernel input bytes, appended by the `add_*` helpers (callers may also pad it directly).
    input: Vec<u8>,
    /// Each global resource paired with the byte offset in `input` where a pointer-sized value
    /// (the offset into that resource) was written for it.
    resource_info: Vec<(&'a PipeResource, usize)>,
}
|
|
|
|
impl<'a> KernelExecBuilder<'a> {
|
|
fn new(dev: &'static Device) -> Self {
|
|
Self {
|
|
dev: dev,
|
|
input: Vec::new(),
|
|
resource_info: Vec::new(),
|
|
}
|
|
}
|
|
|
|
fn add_global(&mut self, res: &'a PipeResourceOwned, offset: usize) {
|
|
self.resource_info.push((res.borrow(), self.input.len()));
|
|
self.add_pointer(offset as u64);
|
|
}
|
|
|
|
fn add_pointer(&mut self, address: u64) {
|
|
if self.dev.address_bits() == 64 {
|
|
let address: u64 = address;
|
|
self.input.extend_from_slice(&address.to_ne_bytes());
|
|
} else {
|
|
let address: u32 = address as u32;
|
|
self.input.extend_from_slice(&address.to_ne_bytes());
|
|
}
|
|
}
|
|
|
|
fn add_sysval(&mut self, vals: &[usize; 3]) {
|
|
if self.dev.address_bits() == 64 {
|
|
self.input
|
|
.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u64)) });
|
|
} else {
|
|
self.input
|
|
.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u32)) });
|
|
}
|
|
}
|
|
|
|
fn add_values(&mut self, value: &[u8]) {
|
|
self.input.extend_from_slice(value);
|
|
}
|
|
}
|
|
|
|
impl Kernel {
|
|
pub fn new(name: String, prog: Arc<Program>, prog_build: &ProgramBuild) -> Arc<Kernel> {
|
|
let kernel_info = Arc::clone(prog_build.kernel_info.get(&name).unwrap());
|
|
let builds = prog_build
|
|
.builds_by_device
|
|
.iter()
|
|
.filter_map(|(&dev, b)| b.kernels.get(&name).map(|k| (dev, Arc::clone(k))))
|
|
.collect();
|
|
|
|
let values = vec![None; kernel_info.args.len()];
|
|
Arc::new(Self {
|
|
base: CLObjectBase::new(RusticlTypes::Kernel),
|
|
prog: prog,
|
|
name: name,
|
|
values: Mutex::new(values),
|
|
bdas: Mutex::new(Vec::new()),
|
|
svms: Mutex::new(HashSet::new()),
|
|
builds: builds,
|
|
kernel_info: kernel_info,
|
|
})
|
|
}
|
|
|
|
pub fn suggest_local_size(
|
|
&self,
|
|
d: &Device,
|
|
work_dim: usize,
|
|
grid: &mut [usize],
|
|
block: &mut [usize],
|
|
) {
|
|
let mut threads = self.max_threads_per_block(d);
|
|
let dim_threads = d.max_block_sizes();
|
|
let subgroups = self.preferred_simd_size(d);
|
|
|
|
for i in 0..work_dim {
|
|
let t = cmp::min(threads, dim_threads[i]);
|
|
let gcd = gcd(t, grid[i]);
|
|
|
|
block[i] = gcd;
|
|
grid[i] /= gcd;
|
|
|
|
// update limits
|
|
threads /= block[i];
|
|
}
|
|
|
|
// if we didn't fill the subgroup we can do a bit better if we have threads remaining
|
|
let total_threads = block.iter().take(work_dim).product::<usize>();
|
|
if threads != 1 && total_threads < subgroups {
|
|
for i in 0..work_dim {
|
|
if grid[i] * total_threads < threads && grid[i] * block[i] <= dim_threads[i] {
|
|
block[i] *= grid[i];
|
|
grid[i] = 1;
|
|
// can only do it once as nothing is cleanly divisible
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn optimize_local_size(&self, d: &Device, grid: &mut [usize; 3], block: &mut [u32; 3]) {
|
|
if !block.contains(&0) {
|
|
for i in 0..3 {
|
|
// we already made sure everything is fine
|
|
grid[i] /= block[i] as usize;
|
|
}
|
|
return;
|
|
}
|
|
|
|
let mut usize_block = [0usize; 3];
|
|
for i in 0..3 {
|
|
usize_block[i] = block[i] as usize;
|
|
}
|
|
|
|
self.suggest_local_size(d, 3, grid, &mut usize_block);
|
|
|
|
for i in 0..3 {
|
|
block[i] = usize_block[i] as u32;
|
|
}
|
|
}
|
|
|
|
    // the painful part is, that host threads are allowed to modify the kernel object once it was
    // enqueued, so return a closure with all req data included.
    /// Prepares a launch of this kernel on `q` and returns a closure that performs it.
    ///
    /// Argument values, per-device builds, BDA and SVM pointers are cloned here. Buffer and image
    /// arguments are upgraded to strong references, so the closure keeps them alive.
    ///
    /// `block`, `grid` and `offsets` are converted with `create_kernel_arr`, and its errors are
    /// returned to the caller right away. Returns `CL_INVALID_KERNEL_ARGS` if a buffer or image
    /// argument can no longer be upgraded.
    pub fn launch(
        self: &Arc<Self>,
        q: &Arc<Queue>,
        work_dim: u32,
        block: &[usize],
        grid: &[usize],
        offsets: &[usize],
    ) -> CLResult<EventSig> {
        // Clone all the data we need to execute this kernel
        let kernel_info = Arc::clone(&self.kernel_info);
        let arg_values = self.arg_values().clone();
        let nir_kernel_builds = Arc::clone(&self.builds[q.device]);
        let mut bdas = self.bdas.lock().unwrap().clone();
        let svms = self.svms.lock().unwrap().clone();

        let mut buffer_arcs = HashMap::new();
        let mut image_arcs = HashMap::new();

        // need to preprocess buffer and image arguments so we hold a strong reference until the
        // event was processed.
        for arg in arg_values.iter() {
            match arg {
                Some(KernelArgValue::Buffer(buffer)) => {
                    buffer_arcs.insert(
                        // we use the ptr as the key, and also cast it to usize so we don't need to
                        // deal with Send + Sync here.
                        buffer.as_ptr() as usize,
                        buffer.upgrade().ok_or(CL_INVALID_KERNEL_ARGS)?,
                    );
                }
                Some(KernelArgValue::Image(image)) => {
                    image_arcs.insert(
                        image.as_ptr() as usize,
                        image.upgrade().ok_or(CL_INVALID_KERNEL_ARGS)?,
                    );
                }
                _ => {}
            }
        }

        // operations we want to report errors to the clients
        let mut block = create_kernel_arr::<u32>(block, 1)?;
        let mut grid = create_kernel_arr::<usize>(grid, 1)?;
        let offsets = create_kernel_arr::<usize>(offsets, 0)?;

        // keep the API global size; `grid` is turned into work group counts below
        let api_grid = grid;

        self.optimize_local_size(q.device, &mut grid, &mut block);

        Ok(Box::new(move |cl_ctx, ctx| {
            let hw_max_grid = ctx.dev.max_grid_size();

            // The optimized variant is only picked when:
            // - there are no global offsets;
            // - the whole grid fits into a single hardware dispatch;
            // - the block matches the work group size hint, if one was given.
            let variant = if offsets == [0; 3]
                && grid[0] <= hw_max_grid[0]
                && grid[1] <= hw_max_grid[1]
                && grid[2] <= hw_max_grid[2]
                && (kernel_info.work_group_size_hint == [0; 3]
                    || block == kernel_info.work_group_size_hint)
            {
                NirKernelVariant::Optimized
            } else {
                NirKernelVariant::Default
            };

            // NOTE(review): presumably indexing falls back to the default build when no
            // optimized variant was compiled; confirm in `NirKernelBuilds`' `Index` impl.
            let nir_kernel_build = &nir_kernel_builds[variant];
            // byte offset of the WorkGroupOffsets argument in the input, patched per dispatch
            let mut workgroup_id_offset_loc = None;
            let mut exec_builder = KernelExecBuilder::new(ctx.dev);
            // Set it once so we get the alignment padding right
            let static_local_size: u64 = nir_kernel_build.shared_size;
            let mut variable_local_size: u64 = static_local_size;
            let printf_size = ctx.dev.printf_buffer_size() as u32;
            let mut samplers = Vec::new();
            let mut iviews = Vec::new();
            let mut sviews = Vec::new();
            let mut tex_formats: Vec<u16> = Vec::new();
            let mut tex_orders: Vec<u16> = Vec::new();
            let mut img_formats: Vec<u16> = Vec::new();
            let mut img_orders: Vec<u16> = Vec::new();

            // zero-filled values the size of one pointer and of three pointers, matching the
            // device's address width
            let null_ptr;
            let null_ptr_v3;
            if ctx.dev.address_bits() == 64 {
                null_ptr = [0u8; 8].as_slice();
                null_ptr_v3 = [0u8; 24].as_slice();
            } else {
                null_ptr = [0u8; 4].as_slice();
                null_ptr_v3 = [0u8; 12].as_slice();
            };

            let mut printf_buf = None;
            if nir_kernel_build.printf_info.is_some() {
                let buf = ctx
                    .dev
                    .screen
                    .resource_create_buffer(printf_size, ResourceType::Staging, PIPE_BIND_GLOBAL, 0)
                    .unwrap();

                // Write 4 into the first byte. After the launch it is read back as the u32
                // length header, which presumably counts the 4 header bytes themselves.
                // NOTE(review): only one byte is written. This assumes the remaining header
                // bytes start out zeroed and a little-endian layout; confirm.
                let init_data: [u8; 1] = [4];
                ctx.buffer_subdata(&buf, 0, init_data.as_ptr().cast(), init_data.len() as u32);

                printf_buf = Some(buf);
            }

            // translate SVM pointers to their base first
            let mut svms: HashSet<_> = svms
                .into_iter()
                .filter_map(|svm_pointer| Some(cl_ctx.find_svm_alloc(svm_pointer)?.0 as usize))
                .collect();

            for arg in &nir_kernel_build.compiled_args {
                let is_opaque = if let CompiledKernelArgType::APIArg(idx) = arg.kind {
                    kernel_info.args[idx].kind.is_opaque()
                } else {
                    false
                };

                // Pad the input up to the argument's byte offset. For opaque arguments
                // `offset` is not used as an input byte offset (images use it as an index into
                // the format/order arrays below).
                if !is_opaque && arg.offset > exec_builder.input.len() {
                    exec_builder.input.resize(arg.offset, 0);
                }

                match arg.kind {
                    CompiledKernelArgType::APIArg(idx) => {
                        let api_arg = &kernel_info.args[idx];
                        // arguments without a value contribute nothing
                        let Some(value) = &arg_values[idx] else {
                            continue;
                        };

                        match value {
                            KernelArgValue::Constant(c) => exec_builder.add_values(c),
                            KernelArgValue::BDA(address) => {
                                bdas.push(*address);
                                if !api_arg.dead {
                                    exec_builder.add_pointer(*address);
                                }
                            }
                            KernelArgValue::Buffer(buffer) => {
                                let buffer = &buffer_arcs[&(buffer.as_ptr() as usize)];
                                let rw = if api_arg.spirv.address_qualifier
                                    == clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT
                                {
                                    RWFlags::RD
                                } else {
                                    RWFlags::RW
                                };

                                // if the argument is dead, based on what kind of memory it is, we
                                // might need to migrate and make it available to the invocation
                                // regardless.
                                if api_arg.dead {
                                    if let Some(address) = buffer.dev_address(ctx.dev) {
                                        let _ = buffer.get_res_for_access(ctx, rw)?;
                                        bdas.push(address.get());
                                    } else if buffer.is_svm() {
                                        let _ = buffer.get_res_for_access(ctx, rw)?;
                                        svms.insert(buffer.host_ptr() as usize);
                                    }
                                } else {
                                    let res = buffer.get_res_for_access(ctx, rw)?;
                                    exec_builder.add_global(res, buffer.offset());
                                }
                            }
                            &KernelArgValue::SVM(handle) => {
                                // get the base address so we deduplicate properly
                                if let Some((base, _)) = cl_ctx.find_svm_alloc(handle) {
                                    svms.insert(base as usize);
                                }

                                if !api_arg.dead {
                                    exec_builder.add_pointer(handle as u64);
                                }
                            }
                            KernelArgValue::Image(image) => {
                                let image = &image_arcs[&(image.as_ptr() as usize)];
                                // storage images (read-only or read-write) go into iviews,
                                // everything else is bound as a sampler view
                                let (formats, orders) = if api_arg.kind == KernelArgType::Image {
                                    iviews.push(image.image_view(ctx, false)?);
                                    (&mut img_formats, &mut img_orders)
                                } else if api_arg.kind == KernelArgType::RWImage {
                                    iviews.push(image.image_view(ctx, true)?);
                                    (&mut img_formats, &mut img_orders)
                                } else {
                                    sviews.push(image.sampler_view(ctx.ctx)?);
                                    (&mut tex_formats, &mut tex_orders)
                                };

                                assert!(arg.offset >= formats.len());

                                formats.resize(arg.offset, 0);
                                orders.resize(arg.offset, 0);

                                formats.push(image.image_format.image_channel_data_type as u16);
                                orders.push(image.image_format.image_channel_order as u16);
                            }
                            KernelArgValue::LocalMem(size) => {
                                // Align the running shared memory offset to the argument size
                                // rounded up to a power of two (capped at 0x80). Pass that
                                // offset as the argument value, then reserve `size` bytes.
                                // TODO 32 bit
                                let pot = cmp::min(*size, 0x80);
                                variable_local_size = variable_local_size
                                    .next_multiple_of(pot.next_power_of_two() as u64);
                                if ctx.dev.address_bits() == 64 {
                                    let variable_local_size: [u8; 8] =
                                        variable_local_size.to_ne_bytes();
                                    exec_builder.add_values(&variable_local_size);
                                } else {
                                    let variable_local_size: [u8; 4] =
                                        (variable_local_size as u32).to_ne_bytes();
                                    exec_builder.add_values(&variable_local_size);
                                }
                                variable_local_size += *size as u64;
                            }
                            KernelArgValue::Sampler(sampler) => {
                                samplers.push(sampler.pipe());
                            }
                            KernelArgValue::None => {
                                // live global/constant pointers set to NULL get a zero pointer
                                if !arg.dead
                                    && matches!(
                                        api_arg.kind,
                                        KernelArgType::MemGlobal | KernelArgType::MemConstant
                                    )
                                {
                                    exec_builder.add_values(null_ptr);
                                }
                            }
                        }
                    }
                    CompiledKernelArgType::ConstantBuffer => {
                        assert!(nir_kernel_build.constant_buffer.is_some());
                        let res = nir_kernel_build.constant_buffer.as_ref().unwrap();
                        exec_builder.add_global(res, 0);
                    }
                    CompiledKernelArgType::GlobalWorkOffsets => {
                        exec_builder.add_sysval(&offsets);
                    }
                    CompiledKernelArgType::WorkGroupOffsets => {
                        // placeholder, overwritten before every dispatch below
                        workgroup_id_offset_loc = Some(exec_builder.input.len());
                        exec_builder.add_values(null_ptr_v3);
                    }
                    CompiledKernelArgType::GlobalWorkSize => {
                        exec_builder.add_sysval(&api_grid);
                    }
                    CompiledKernelArgType::PrintfBuffer => {
                        let res = printf_buf.as_ref().unwrap();
                        exec_builder.add_global(res, 0);
                    }
                    CompiledKernelArgType::InlineSampler(cl) => {
                        samplers.push(Sampler::cl_to_pipe(cl));
                    }
                    CompiledKernelArgType::FormatArray => {
                        exec_builder.add_values(unsafe { as_byte_slice(&tex_formats) });
                        exec_builder.add_values(unsafe { as_byte_slice(&img_formats) });
                    }
                    CompiledKernelArgType::OrderArray => {
                        exec_builder.add_values(unsafe { as_byte_slice(&tex_orders) });
                        exec_builder.add_values(unsafe { as_byte_slice(&img_orders) });
                    }
                    CompiledKernelArgType::WorkDim => {
                        exec_builder.add_values(&[work_dim as u8; 1]);
                    }
                    CompiledKernelArgType::NumWorkgroups => {
                        exec_builder.add_values(unsafe {
                            as_byte_slice(&[grid[0] as u32, grid[1] as u32, grid[2] as u32])
                        });
                    }
                }
            }

            // dedup with a HashSet
            let bdas = bdas
                .into_iter()
                // Ignore invalid pointers as they are legal to be passed in, but illegal to
                // dereference.
                .filter_map(|address| cl_ctx.find_bda_alloc(ctx.dev, address))
                .collect::<HashSet<_>>();

            let mut bdas: Vec<_> = bdas
                .iter()
                .map(|buffer| buffer.get_res_for_access(ctx, RWFlags::RW))
                .collect::<CLResult<_>>()?;

            let svms_new = svms
                .into_iter()
                .filter_map(|svm| cl_ctx.copy_svm_to_dev(ctx, svm).transpose())
                .collect::<CLResult<Vec<_>>>()?;

            // uhhh
            // the SVM resources join the BDA resources in the list handed to launch_grid
            for svm in &svms_new {
                bdas.push(svm);
            }

            // subtract the shader local_size as we only request something on top of that.
            variable_local_size -= static_local_size;

            // Raw pointers into the input at every recorded global resource location.
            // Presumably the driver patches the resource addresses into those slots via
            // set_global_binding; confirm.
            let mut resources = Vec::with_capacity(exec_builder.resource_info.len());
            let mut globals: Vec<*mut u32> = Vec::with_capacity(exec_builder.resource_info.len());
            for (res, offset) in exec_builder.resource_info {
                resources.push(res);
                globals.push(unsafe { exec_builder.input.as_mut_ptr().byte_add(offset) }.cast());
            }

            ctx.bind_kernel(&nir_kernel_builds, variant)?;
            ctx.bind_sampler_states(samplers);
            ctx.bind_sampler_views(sviews);
            ctx.bind_shader_images(&iviews);
            ctx.set_global_binding(resources.as_mut_slice(), &mut globals);

            // Split the grid into dispatches no larger than the hardware maximum. Each dispatch
            // gets its own work group ID offset patched into the input, if the kernel has one.
            for z in 0..grid[2].div_ceil(hw_max_grid[2]) {
                for y in 0..grid[1].div_ceil(hw_max_grid[1]) {
                    for x in 0..grid[0].div_ceil(hw_max_grid[0]) {
                        if let Some(workgroup_id_offset_loc) = workgroup_id_offset_loc {
                            let this_offsets =
                                [x * hw_max_grid[0], y * hw_max_grid[1], z * hw_max_grid[2]];

                            if ctx.dev.address_bits() == 64 {
                                let val = this_offsets.map(|v| v as u64);
                                exec_builder.input
                                    [workgroup_id_offset_loc..workgroup_id_offset_loc + 24]
                                    .copy_from_slice(unsafe { as_byte_slice(&val) });
                            } else {
                                let val = this_offsets.map(|v| v as u32);
                                exec_builder.input
                                    [workgroup_id_offset_loc..workgroup_id_offset_loc + 12]
                                    .copy_from_slice(unsafe { as_byte_slice(&val) });
                            }
                        }

                        // the last dispatch along each axis only covers what is left
                        let this_grid = [
                            cmp::min(hw_max_grid[0], grid[0] - hw_max_grid[0] * x) as u32,
                            cmp::min(hw_max_grid[1], grid[1] - hw_max_grid[1] * y) as u32,
                            cmp::min(hw_max_grid[2], grid[2] - hw_max_grid[2] * z) as u32,
                        ];

                        ctx.update_cb0(&exec_builder.input)?;
                        ctx.launch_grid(
                            work_dim,
                            block,
                            this_grid,
                            variable_local_size as u32,
                            &bdas,
                        );

                        if Platform::dbg().sync_every_event {
                            ctx.flush().wait();
                        }
                    }
                }
            }

            ctx.clear_global_binding(globals.len() as u32);

            ctx.memory_barrier(PIPE_BARRIER_GLOBAL_BUFFER);

            if let Some(printf_buf) = &printf_buf {
                let tx = ctx
                    .buffer_map(printf_buf, 0, printf_size as i32, RWFlags::RD)
                    .ok_or(CL_OUT_OF_RESOURCES)?;
                let mut buf: &[u8] =
                    unsafe { slice::from_raw_parts(tx.ptr().cast(), printf_size as usize) };
                let length = u32::from_ne_bytes(*extract(&mut buf));

                // update our slice to make sure we don't go out of bounds
                // NOTE(review): if `length` is below 4 or larger than `printf_size`, this
                // subtraction or slice panics. Consider clamping the device-reported length.
                buf = &buf[0..(length - 4) as usize];
                if let Some(pf) = &nir_kernel_build.printf_info {
                    pf.u_printf(buf)
                }
            }

            Ok(())
        }))
    }
|
|
|
|
pub fn arg_values(&self) -> MutexGuard<'_, Vec<Option<KernelArgValue>>> {
|
|
self.values.lock().unwrap()
|
|
}
|
|
|
|
pub fn set_kernel_arg(&self, idx: usize, arg: KernelArgValue) -> CLResult<()> {
|
|
self.values
|
|
.lock()
|
|
.unwrap()
|
|
.get_mut(idx)
|
|
.ok_or(CL_INVALID_ARG_INDEX)?
|
|
.replace(arg);
|
|
Ok(())
|
|
}
|
|
|
|
pub fn access_qualifier(&self, idx: usize) -> cl_kernel_arg_access_qualifier {
|
|
let aq = self.kernel_info.args[idx].spirv.access_qualifier;
|
|
|
|
if aq
|
|
== clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
|
|
| clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE
|
|
{
|
|
CL_KERNEL_ARG_ACCESS_READ_WRITE
|
|
} else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ {
|
|
CL_KERNEL_ARG_ACCESS_READ_ONLY
|
|
} else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE {
|
|
CL_KERNEL_ARG_ACCESS_WRITE_ONLY
|
|
} else {
|
|
CL_KERNEL_ARG_ACCESS_NONE
|
|
}
|
|
}
|
|
|
|
pub fn address_qualifier(&self, idx: usize) -> cl_kernel_arg_address_qualifier {
|
|
match self.kernel_info.args[idx].spirv.address_qualifier {
|
|
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
|
|
CL_KERNEL_ARG_ADDRESS_PRIVATE
|
|
}
|
|
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
|
|
CL_KERNEL_ARG_ADDRESS_CONSTANT
|
|
}
|
|
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
|
|
CL_KERNEL_ARG_ADDRESS_LOCAL
|
|
}
|
|
clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
|
|
CL_KERNEL_ARG_ADDRESS_GLOBAL
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn type_qualifier(&self, idx: usize) -> cl_kernel_arg_type_qualifier {
|
|
let tq = self.kernel_info.args[idx].spirv.type_qualifier;
|
|
let zero = clc_kernel_arg_type_qualifier(0);
|
|
let mut res = CL_KERNEL_ARG_TYPE_NONE;
|
|
|
|
if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_CONST != zero {
|
|
res |= CL_KERNEL_ARG_TYPE_CONST;
|
|
}
|
|
|
|
if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_RESTRICT != zero {
|
|
res |= CL_KERNEL_ARG_TYPE_RESTRICT;
|
|
}
|
|
|
|
if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_VOLATILE != zero {
|
|
res |= CL_KERNEL_ARG_TYPE_VOLATILE;
|
|
}
|
|
|
|
res.into()
|
|
}
|
|
|
|
pub fn work_group_size(&self) -> [usize; 3] {
|
|
self.kernel_info.work_group_size
|
|
}
|
|
|
|
pub fn num_subgroups(&self) -> usize {
|
|
self.kernel_info.num_subgroups
|
|
}
|
|
|
|
pub fn subgroup_size(&self) -> usize {
|
|
self.kernel_info.subgroup_size
|
|
}
|
|
|
|
pub fn arg_name(&self, idx: usize) -> Option<&CStr> {
|
|
let name = &self.kernel_info.args[idx].spirv.name;
|
|
name.is_empty().not().then_some(name)
|
|
}
|
|
|
|
pub fn arg_type_name(&self, idx: usize) -> Option<&CStr> {
|
|
let type_name = &self.kernel_info.args[idx].spirv.type_name;
|
|
type_name.is_empty().not().then_some(type_name)
|
|
}
|
|
|
|
pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
|
|
self.builds.get(dev).unwrap().info.private_memory as cl_ulong
|
|
}
|
|
|
|
pub fn max_threads_per_block(&self, dev: &Device) -> usize {
|
|
self.builds.get(dev).unwrap().info.max_threads as usize
|
|
}
|
|
|
|
pub fn preferred_simd_size(&self, dev: &Device) -> usize {
|
|
self.builds.get(dev).unwrap().info.preferred_simd_size as usize
|
|
}
|
|
|
|
pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
|
|
// TODO: take alignment into account?
|
|
// this is purely informational so it shouldn't even matter
|
|
let local =
|
|
self.builds.get(dev).unwrap()[NirKernelVariant::Default].shared_size as cl_ulong;
|
|
let args: cl_ulong = self
|
|
.arg_values()
|
|
.iter()
|
|
.map(|arg| match arg {
|
|
Some(KernelArgValue::LocalMem(val)) => *val as cl_ulong,
|
|
// If the local memory size, for any pointer argument to the kernel declared with
|
|
// the __local address qualifier, is not specified, its size is assumed to be 0.
|
|
_ => 0,
|
|
})
|
|
.sum();
|
|
|
|
local + args
|
|
}
|
|
|
|
pub fn has_svm_devs(&self) -> bool {
|
|
self.prog.devs.iter().any(|dev| dev.api_svm_supported())
|
|
}
|
|
|
|
pub fn subgroup_sizes(&self, dev: &Device) -> impl ExactSizeIterator<Item = usize> + use<> {
|
|
SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes).map(|bit| 1 << bit)
|
|
}
|
|
|
|
pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
|
|
let subgroup_size = self.subgroup_size_for_block(dev, block);
|
|
if subgroup_size == 0 {
|
|
return 0;
|
|
}
|
|
|
|
let threads: usize = block.iter().product();
|
|
threads.div_ceil(subgroup_size)
|
|
}
|
|
|
|
pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize {
|
|
let mut subgroup_sizes = self.subgroup_sizes(dev);
|
|
|
|
// Replace with `ExactSizeIterator::is_empty()` when stable.
|
|
// See https://github.com/rust-lang/rust/issues/35428
|
|
if subgroup_sizes.len() == 0 {
|
|
return 0;
|
|
}
|
|
|
|
if subgroup_sizes.len() == 1 {
|
|
return subgroup_sizes.next().unwrap();
|
|
}
|
|
|
|
let block = [
|
|
*block.first().unwrap_or(&1) as u32,
|
|
*block.get(1).unwrap_or(&1) as u32,
|
|
*block.get(2).unwrap_or(&1) as u32,
|
|
];
|
|
|
|
// TODO: this _might_ bite us somewhere, but I think it probably doesn't matter
|
|
match &self.builds.get(dev).unwrap()[NirKernelVariant::Default].nir_or_cso {
|
|
KernelDevStateVariant::Cso(cso) => {
|
|
dev.helper_ctx()
|
|
.compute_state_subgroup_size(cso.cso_ptr, &block) as usize
|
|
}
|
|
_ => {
|
|
panic!()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Clone for Kernel {
|
|
fn clone(&self) -> Self {
|
|
Self {
|
|
base: CLObjectBase::new(RusticlTypes::Kernel),
|
|
prog: Arc::clone(&self.prog),
|
|
name: self.name.clone(),
|
|
values: Mutex::new(self.arg_values().clone()),
|
|
bdas: Mutex::new(self.bdas.lock().unwrap().clone()),
|
|
svms: Mutex::new(self.svms.lock().unwrap().clone()),
|
|
builds: self.builds.clone(),
|
|
kernel_info: Arc::clone(&self.kernel_info),
|
|
}
|
|
}
|
|
}
|