diff --git a/src/nouveau/compiler/nak/ir.rs b/src/nouveau/compiler/nak/ir.rs index 86379c552c2..ab78938e572 100644 --- a/src/nouveau/compiler/nak/ir.rs +++ b/src/nouveau/compiler/nak/ir.rs @@ -374,7 +374,10 @@ pub struct SSAValue { impl SSAValue { /// Returns an SSA value with the given register file and index pub fn new(file: RegFile, idx: u32) -> SSAValue { - assert!(idx > 0 && idx < (1 << 29) - 2); + assert!( + idx > 0 + && idx < (1 << 29) - u32::try_from(SSARef::LARGE_SIZE).unwrap() + ); let mut packed = idx; assert!(u8::from(file) < 8); packed |= u32::from(u8::from(file)) << 29; @@ -407,6 +410,57 @@ impl fmt::Display for SSAValue { } } +#[derive(Clone, Eq, Hash, PartialEq)] +struct SSAValueArray { + v: [SSAValue; SIZE], +} + +impl SSAValueArray { + /// Returns a new SSA reference + #[inline] + fn new(comps: &[SSAValue]) -> Self { + assert!(comps.len() > 0 && comps.len() <= SIZE); + let mut r = Self { + v: [SSAValue { + packed: NonZeroU32::MAX, + }; SIZE], + }; + for i in 0..comps.len() { + r.v[i] = comps[i]; + } + if comps.len() < SIZE { + r.v[SIZE - 1].packed = + (comps.len() as u32).wrapping_neg().try_into().unwrap(); + } + r + } + + fn comps(&self) -> u8 { + let size: u8 = SIZE.try_into().unwrap(); + if self.v[SIZE - 1].packed.get() >= u32::MAX - (u32::from(size) - 1) { + self.v[SIZE - 1].packed.get().wrapping_neg() as u8 + } else { + size + } + } +} + +impl Deref for SSAValueArray { + type Target = [SSAValue]; + + fn deref(&self) -> &[SSAValue] { + let comps = usize::from(self.comps()); + &self.v[..comps] + } +} + +impl DerefMut for SSAValueArray { + fn deref_mut(&mut self) -> &mut [SSAValue] { + let comps = usize::from(self.comps()); + &mut self.v[..comps] + } +} + /// A reference to one or more SSA values /// /// Because each SSA value represents a single 1 or 32-bit scalar, we need a way @@ -416,40 +470,46 @@ impl fmt::Display for SSAValue { /// registers, with the base register aligned to the number of values, aligned /// to the next power of two. /// -/// An SSA reference can reference between 1 and 4 SSA values. It dereferences +/// An SSA reference can reference between 1 and 16 SSA values. It dereferences /// to a slice for easy access to individual SSA values. The structure is /// designed so that is always 16B, regardless of how many SSA values are -/// referenced so it's easy and fairly cheap to copy around and embed in other +/// referenced so it's easy and fairly cheap to clone and embed in other /// structures. #[derive(Clone, Eq, Hash, PartialEq)] +enum SSARefInner { + Small(SSAValueArray<{ SSARef::SMALL_SIZE }>), + Large(Box>), +} +#[derive(Clone, Eq, Hash, PartialEq)] pub struct SSARef { - v: [SSAValue; 4], + v: SSARefInner, } +#[cfg(target_arch = "x86_64")] +const _: () = { + debug_assert!(std::mem::size_of::() == 16); +}; + impl SSARef { + const SMALL_SIZE: usize = 4; + const LARGE_SIZE: usize = 16; + /// Returns a new SSA reference #[inline] - fn new(comps: &[SSAValue]) -> SSARef { - assert!(comps.len() > 0 && comps.len() <= 4); - let mut r = SSARef { - v: [SSAValue { - packed: NonZeroU32::MAX, - }; 4], - }; - for i in 0..comps.len() { - r.v[i] = comps[i]; + pub fn new(comps: &[SSAValue]) -> SSARef { + SSARef { + v: if comps.len() > Self::SMALL_SIZE { + SSARefInner::Large(Box::new(SSAValueArray::new(comps))) + } else { + SSARefInner::Small(SSAValueArray::new(comps)) + }, } - if comps.len() < 4 { - r.v[3].packed = - (comps.len() as u32).wrapping_neg().try_into().unwrap(); - } - r } fn from_iter(mut it: impl ExactSizeIterator) -> Self { let len = it.len(); - assert!(len > 0 && len <= 4); - let v: [SSAValue; 4] = array::from_fn(|_| { + assert!(len > 0 && len <= Self::LARGE_SIZE); + let v: [SSAValue; Self::LARGE_SIZE] = array::from_fn(|_| { it.next().unwrap_or(SSAValue { packed: NonZeroU32::MAX, }) @@ -459,18 +519,17 @@ impl SSARef { /// Returns the number of components in this SSA reference pub fn comps(&self) -> u8 { - if self.v[3].packed.get() >= u32::MAX - 2 { - self.v[3].packed.get().wrapping_neg() as u8 - } else { - 4 + match &self.v { + SSARefInner::Small(x) => x.comps(), + SSARefInner::Large(x) => x.comps(), } } pub fn file(&self) -> Option { let comps = usize::from(self.comps()); - let file = self.v[0].file(); + let file = self[0].file(); for i in 1..comps { - if self.v[i].file() != file { + if self[i].file() != file { return None; } } @@ -496,7 +555,7 @@ impl SSARef { } pub fn is_predicate(&self) -> bool { - if self.v[0].is_predicate() { + if self[0].is_predicate() { true } else { for ssa in &self[..] { @@ -511,15 +570,19 @@ impl Deref for SSARef { type Target = [SSAValue]; fn deref(&self) -> &[SSAValue] { - let comps = usize::from(self.comps()); - &self.v[..comps] + match &self.v { + SSARefInner::Small(x) => x.deref(), + SSARefInner::Large(x) => x.deref(), + } } } impl DerefMut for SSARef { fn deref_mut(&mut self) -> &mut [SSAValue] { - let comps = usize::from(self.comps()); - &mut self.v[..comps] + match &mut self.v { + SSARefInner::Small(x) => x.deref_mut(), + SSARefInner::Large(x) => x.deref_mut(), + } } } @@ -529,7 +592,7 @@ impl TryFrom<&[SSAValue]> for SSARef { fn try_from(comps: &[SSAValue]) -> Result { if comps.len() == 0 { Err("Empty vector") - } else if comps.len() > 4 { + } else if comps.len() > Self::LARGE_SIZE { Err("Too many vector components") } else { Ok(SSARef::new(comps)) diff --git a/src/nouveau/compiler/nak/ir_tests.rs b/src/nouveau/compiler/nak/ir_tests.rs new file mode 100644 index 00000000000..526eb2283c8 --- /dev/null +++ b/src/nouveau/compiler/nak/ir_tests.rs @@ -0,0 +1,15 @@ +// Copyright © 2025 Valve Corporation +// SPDX-License-Identifier: MIT +use crate::ir::*; + +#[test] +fn test_ssa_ref_round_trip() { + for len in 1..16 { + let vec: Vec<_> = (0..len) + .map(|i| SSAValue::new(RegFile::GPR, 1337 ^ i ^ len)) + .collect(); + + let ssa_ref = SSARef::new(&vec); + assert!(&ssa_ref[..] == &vec[..]); + } +} diff --git a/src/nouveau/compiler/nak/lib.rs b/src/nouveau/compiler/nak/lib.rs index fb6237f4a26..dc57a0b4f61 100644 --- a/src/nouveau/compiler/nak/lib.rs +++ b/src/nouveau/compiler/nak/lib.rs @@ -44,5 +44,8 @@ mod hw_tests; #[cfg(test)] mod hw_runner; +#[cfg(test)] +mod ir_tests; + #[cfg(test)] mod nvdisasm_tests;