From 3a16c9ab43edeed7ea9cde5bd813aee33cef7cc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Se=C3=A1n=20de=20B=C3=BArca?= Date: Thu, 3 Apr 2025 14:26:57 -0700 Subject: [PATCH] rusticl: iterate subgroup sizes only as needed Making subgroup sizes an iterator avoids collecting (and thus allocation) in cases where the values are unneeded or only the first is needed. v2: fix calculation of `SetBitIndices` iterator length Reviewed-by: @LingMan Reviewed-by: Karol Herbst Part-of: --- src/gallium/frontends/rusticl/api/device.rs | 4 ++-- src/gallium/frontends/rusticl/core/device.rs | 6 ++---- src/gallium/frontends/rusticl/core/kernel.rs | 15 ++++++++------- src/gallium/frontends/rusticl/util/math.rs | 8 +++++++- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/gallium/frontends/rusticl/api/device.rs b/src/gallium/frontends/rusticl/api/device.rs index e4433812157..feb6c9607af 100644 --- a/src/gallium/frontends/rusticl/api/device.rs +++ b/src/gallium/frontends/rusticl/api/device.rs @@ -246,7 +246,7 @@ unsafe impl CLInfo for cl_device_id { CL_DEVICE_PREFERRED_VECTOR_WIDTH_LONG => v.write::(1), CL_DEVICE_PREFERRED_VECTOR_WIDTH_SHORT => v.write::(1), CL_DEVICE_PREFERRED_WORK_GROUP_SIZE_MULTIPLE => { - v.write::(dev.subgroup_sizes()[0]) + v.write::(dev.subgroup_sizes().next().unwrap()) } CL_DEVICE_PRINTF_BUFFER_SIZE => v.write::(dev.printf_buffer_size()), CL_DEVICE_PROFILE => v.write::<&CStr>(if dev.embedded { @@ -276,7 +276,7 @@ unsafe impl CLInfo for cl_device_id { CL_DEVICE_SUB_GROUP_INDEPENDENT_FORWARD_PROGRESS => v.write::(false), CL_DEVICE_SUB_GROUP_SIZES_INTEL => { v.write::>(if dev.subgroups_supported() { - dev.subgroup_sizes() + dev.subgroup_sizes().collect() } else { vec![0; 1] }) diff --git a/src/gallium/frontends/rusticl/core/device.rs b/src/gallium/frontends/rusticl/core/device.rs index 260e162ab30..ceedb901719 100644 --- a/src/gallium/frontends/rusticl/core/device.rs +++ b/src/gallium/frontends/rusticl/core/device.rs @@ -1104,12 +1104,10 @@ impl Device { } } - pub fn subgroup_sizes(&self) -> Vec { + pub fn subgroup_sizes(&self) -> impl ExactSizeIterator { let subgroup_size = self.screen.compute_caps().subgroup_sizes; - SetBitIndices::from_msb(subgroup_size) - .map(|bit| 1 << bit) - .collect() + SetBitIndices::from_msb(subgroup_size).map(|bit| 1 << bit) } pub fn max_subgroups(&self) -> u32 { diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index 53211dfdd5d..1f34233f254 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -1765,10 +1765,8 @@ impl Kernel { self.prog.devs.iter().any(|dev| dev.svm_supported()) } - pub fn subgroup_sizes(&self, dev: &Device) -> Vec { - SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes) - .map(|bit| 1 << bit) - .collect() + pub fn subgroup_sizes(&self, dev: &Device) -> impl ExactSizeIterator { + SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes).map(|bit| 1 << bit) } pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize { @@ -1782,13 +1780,16 @@ impl Kernel { } pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize { - let subgroup_sizes = self.subgroup_sizes(dev); - if subgroup_sizes.is_empty() { + let mut subgroup_sizes = self.subgroup_sizes(dev); + + // Replace with `ExactSizeIterator::is_empty()` when stable. + // See https://github.com/rust-lang/rust/issues/35428 + if subgroup_sizes.len() == 0 { return 0; } if subgroup_sizes.len() == 1 { - return subgroup_sizes[0]; + return subgroup_sizes.next().unwrap(); } let block = [ diff --git a/src/gallium/frontends/rusticl/util/math.rs b/src/gallium/frontends/rusticl/util/math.rs index c419310864c..d37efc65153 100644 --- a/src/gallium/frontends/rusticl/util/math.rs +++ b/src/gallium/frontends/rusticl/util/math.rs @@ -21,7 +21,7 @@ pub struct SetBitIndices { impl SetBitIndices { pub fn from_msb(val: T) -> Self { - Self { val: val } + Self { val } } } @@ -38,3 +38,9 @@ impl Iterator for SetBitIndices { } } } + +impl ExactSizeIterator for SetBitIndices { + fn len(&self) -> usize { + self.val.count_ones() as usize + } +}