rusticl: iterate subgroup sizes only as needed

Making subgroup sizes an iterator avoids collecting (and thus
allocation) in cases where the values are unneeded or only the first is
needed.

v2: fix calculation of `SetBitIndices<u32>` iterator length

Reviewed-by: @LingMan
Reviewed-by: Karol Herbst <kherbst@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34389>
This commit is contained in:
Seán de Búrca
2025-04-03 14:26:57 -07:00
committed by Marge Bot
parent 0980ba8595
commit 3a16c9ab43
4 changed files with 19 additions and 14 deletions
+2 -2
View File
@@ -246,7 +246,7 @@ unsafe impl CLInfo<cl_device_info> for cl_device_id {
CL_DEVICE_PREFERRED_VECTOR_WIDTH_LONG => v.write::<cl_uint>(1),
CL_DEVICE_PREFERRED_VECTOR_WIDTH_SHORT => v.write::<cl_uint>(1),
CL_DEVICE_PREFERRED_WORK_GROUP_SIZE_MULTIPLE => {
v.write::<usize>(dev.subgroup_sizes()[0])
v.write::<usize>(dev.subgroup_sizes().next().unwrap())
}
CL_DEVICE_PRINTF_BUFFER_SIZE => v.write::<usize>(dev.printf_buffer_size()),
CL_DEVICE_PROFILE => v.write::<&CStr>(if dev.embedded {
@@ -276,7 +276,7 @@ unsafe impl CLInfo<cl_device_info> for cl_device_id {
CL_DEVICE_SUB_GROUP_INDEPENDENT_FORWARD_PROGRESS => v.write::<bool>(false),
CL_DEVICE_SUB_GROUP_SIZES_INTEL => {
v.write::<Vec<usize>>(if dev.subgroups_supported() {
dev.subgroup_sizes()
dev.subgroup_sizes().collect()
} else {
vec![0; 1]
})
+2 -4
View File
@@ -1104,12 +1104,10 @@ impl Device {
}
}
pub fn subgroup_sizes(&self) -> Vec<usize> {
pub fn subgroup_sizes(&self) -> impl ExactSizeIterator<Item = usize> {
let subgroup_size = self.screen.compute_caps().subgroup_sizes;
SetBitIndices::from_msb(subgroup_size)
.map(|bit| 1 << bit)
.collect()
SetBitIndices::from_msb(subgroup_size).map(|bit| 1 << bit)
}
pub fn max_subgroups(&self) -> u32 {
+8 -7
View File
@@ -1765,10 +1765,8 @@ impl Kernel {
self.prog.devs.iter().any(|dev| dev.svm_supported())
}
pub fn subgroup_sizes(&self, dev: &Device) -> Vec<usize> {
SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes)
.map(|bit| 1 << bit)
.collect()
pub fn subgroup_sizes(&self, dev: &Device) -> impl ExactSizeIterator<Item = usize> {
SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes).map(|bit| 1 << bit)
}
pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
@@ -1782,13 +1780,16 @@ impl Kernel {
}
pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize {
let subgroup_sizes = self.subgroup_sizes(dev);
if subgroup_sizes.is_empty() {
let mut subgroup_sizes = self.subgroup_sizes(dev);
// Replace with `ExactSizeIterator::is_empty()` when stable.
// See https://github.com/rust-lang/rust/issues/35428
if subgroup_sizes.len() == 0 {
return 0;
}
if subgroup_sizes.len() == 1 {
return subgroup_sizes[0];
return subgroup_sizes.next().unwrap();
}
let block = [
+7 -1
View File
@@ -21,7 +21,7 @@ pub struct SetBitIndices<T> {
impl<T> SetBitIndices<T> {
pub fn from_msb(val: T) -> Self {
Self { val: val }
Self { val }
}
}
@@ -38,3 +38,9 @@ impl Iterator for SetBitIndices<u32> {
}
}
}
impl ExactSizeIterator for SetBitIndices<u32> {
fn len(&self) -> usize {
self.val.count_ones() as usize
}
}