From 3ca448a5495ca861d6fe6b6d65fd91ff20da8ee1 Mon Sep 17 00:00:00 2001 From: Faith Ekstrand Date: Wed, 31 Jul 2024 11:32:05 -0500 Subject: [PATCH] nak: Replace the guts of Srcs/DstsAsSlice with a new AsSlice trait This new trait is way more generic and shareable. It does mean a bit of gymnastics with traits to keep from retyping the whole compiler but the result is something we can potentially share with other compilers. Reviewed-by: Christian Gmeiner Part-of: --- src/nouveau/compiler/nak/ir.rs | 184 ++++++++++++++++++---------- src/nouveau/compiler/nak/ir_proc.rs | 112 ++++++++--------- 2 files changed, 174 insertions(+), 122 deletions(-) diff --git a/src/nouveau/compiler/nak/ir.rs b/src/nouveau/compiler/nak/ir.rs index 6f9e2dc7bf7..04857c00133 100644 --- a/src/nouveau/compiler/nak/ir.rs +++ b/src/nouveau/compiler/nak/ir.rs @@ -1398,6 +1398,30 @@ impl fmt::Display for Src { } } +pub enum AttrList { + Array(&'static [T]), + Uniform(T), +} + +impl Index for AttrList { + type Output = T; + + fn index(&self, idx: usize) -> &T { + match self { + AttrList::Array(arr) => &arr[idx], + AttrList::Uniform(typ) => typ, + } + } +} + +pub trait AsSlice { + type Attr; + + fn as_slice(&self) -> &[T]; + fn as_mut_slice(&mut self) -> &mut [T]; + fn attrs(&self) -> AttrList; +} + #[repr(u8)] #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum SrcType { @@ -1419,28 +1443,20 @@ impl SrcType { const DEFAULT: SrcType = SrcType::GPR; } -pub enum TypeList { - Array(&'static [T]), - Uniform(T), -} +pub type SrcTypeList = AttrList; -impl Index for TypeList { - type Output = T; - - fn index(&self, idx: usize) -> &T { - match self { - TypeList::Array(arr) => &arr[idx], - TypeList::Uniform(typ) => typ, - } +pub trait SrcsAsSlice: AsSlice { + fn srcs_as_slice(&self) -> &[Src] { + self.as_slice() } -} -pub type SrcTypeList = TypeList; + fn srcs_as_mut_slice(&mut self) -> &mut [Src] { + self.as_mut_slice() + } -pub trait SrcsAsSlice { - fn srcs_as_slice(&self) -> &[Src]; - fn srcs_as_mut_slice(&mut self) -> &mut [Src]; - fn src_types(&self) -> SrcTypeList; + fn src_types(&self) -> SrcTypeList { + self.attrs() + } fn src_idx(&self, src: &Src) -> usize { let r = self.srcs_as_slice().as_ptr_range(); @@ -1449,6 +1465,8 @@ pub trait SrcsAsSlice { } } +impl> SrcsAsSlice for T {} + fn all_dsts_uniform(dsts: &[Dst]) -> bool { let mut uniform = None; for dst in dsts { @@ -1481,12 +1499,20 @@ impl DstType { const DEFAULT: DstType = DstType::Vec; } -pub type DstTypeList = TypeList; +pub type DstTypeList = AttrList; -pub trait DstsAsSlice { - fn dsts_as_slice(&self) -> &[Dst]; - fn dsts_as_mut_slice(&mut self) -> &mut [Dst]; - fn dst_types(&self) -> DstTypeList; +pub trait DstsAsSlice: AsSlice { + fn dsts_as_slice(&self) -> &[Dst] { + self.as_slice() + } + + fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { + self.as_mut_slice() + } + + fn dst_types(&self) -> DstTypeList { + self.attrs() + } fn dst_idx(&self, dst: &Dst) -> usize { let r = self.dsts_as_slice().as_ptr_range(); @@ -1495,6 +1521,8 @@ pub trait DstsAsSlice { } } +impl> DstsAsSlice for T {} + pub trait IsUniform { fn is_uniform(&self) -> bool; } @@ -3993,16 +4021,18 @@ pub struct OpF2F { pub integer_rnd: bool, } -impl SrcsAsSlice for OpF2F { - fn srcs_as_slice(&self) -> &[Src] { +impl AsSlice for OpF2F { + type Attr = SrcType; + + fn as_slice(&self) -> &[Src] { std::slice::from_ref(&self.src) } - fn srcs_as_mut_slice(&mut self) -> &mut [Src] { + fn as_mut_slice(&mut self) -> &mut [Src] { std::slice::from_mut(&mut self.src) } - fn src_types(&self) -> SrcTypeList { + fn attrs(&self) -> SrcTypeList { let src_type = match self.src_type { FloatType::F16 => SrcType::F16, FloatType::F32 => SrcType::F32, @@ -4012,16 +4042,18 @@ impl SrcsAsSlice for OpF2F { } } -impl DstsAsSlice for OpF2F { - fn dsts_as_slice(&self) -> &[Dst] { +impl AsSlice for OpF2F { + type Attr = DstType; + + fn as_slice(&self) -> &[Dst] { std::slice::from_ref(&self.dst) } - fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { + fn as_mut_slice(&mut self) -> &mut [Dst] { std::slice::from_mut(&mut self.dst) } - fn dst_types(&self) -> DstTypeList { + fn attrs(&self) -> DstTypeList { let dst_type = match self.dst_type { FloatType::F16 => DstType::F16, FloatType::F32 => DstType::F32, @@ -4063,16 +4095,18 @@ pub struct OpF2I { pub ftz: bool, } -impl SrcsAsSlice for OpF2I { - fn srcs_as_slice(&self) -> &[Src] { +impl AsSlice for OpF2I { + type Attr = SrcType; + + fn as_slice(&self) -> &[Src] { std::slice::from_ref(&self.src) } - fn srcs_as_mut_slice(&mut self) -> &mut [Src] { + fn as_mut_slice(&mut self) -> &mut [Src] { std::slice::from_mut(&mut self.src) } - fn src_types(&self) -> SrcTypeList { + fn attrs(&self) -> SrcTypeList { let src_type = match self.src_type { FloatType::F16 => SrcType::F16, FloatType::F32 => SrcType::F32, @@ -4104,16 +4138,18 @@ pub struct OpI2F { pub rnd_mode: FRndMode, } -impl SrcsAsSlice for OpI2F { - fn srcs_as_slice(&self) -> &[Src] { +impl AsSlice for OpI2F { + type Attr = SrcType; + + fn as_slice(&self) -> &[Src] { std::slice::from_ref(&self.src) } - fn srcs_as_mut_slice(&mut self) -> &mut [Src] { + fn as_mut_slice(&mut self) -> &mut [Src] { std::slice::from_mut(&mut self.src) } - fn src_types(&self) -> SrcTypeList { + fn attrs(&self) -> SrcTypeList { if self.src_type.bits() <= 32 { SrcTypeList::Uniform(SrcType::ALU) } else { @@ -4122,16 +4158,18 @@ impl SrcsAsSlice for OpI2F { } } -impl DstsAsSlice for OpI2F { - fn dsts_as_slice(&self) -> &[Dst] { +impl AsSlice for OpI2F { + type Attr = DstType; + + fn as_slice(&self) -> &[Dst] { std::slice::from_ref(&self.dst) } - fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { + fn as_mut_slice(&mut self) -> &mut [Dst] { std::slice::from_mut(&mut self.dst) } - fn dst_types(&self) -> DstTypeList { + fn attrs(&self) -> DstTypeList { let dst_type = match self.dst_type { FloatType::F16 => DstType::F16, FloatType::F32 => DstType::F32, @@ -4202,16 +4240,18 @@ pub struct OpFRnd { pub ftz: bool, } -impl SrcsAsSlice for OpFRnd { - fn srcs_as_slice(&self) -> &[Src] { +impl AsSlice for OpFRnd { + type Attr = SrcType; + + fn as_slice(&self) -> &[Src] { std::slice::from_ref(&self.src) } - fn srcs_as_mut_slice(&mut self) -> &mut [Src] { + fn as_mut_slice(&mut self) -> &mut [Src] { std::slice::from_mut(&mut self.src) } - fn src_types(&self) -> SrcTypeList { + fn attrs(&self) -> SrcTypeList { let src_type = match self.src_type { FloatType::F16 => SrcType::F16, FloatType::F32 => SrcType::F32, @@ -5775,16 +5815,18 @@ impl OpPhiSrcs { } } -impl SrcsAsSlice for OpPhiSrcs { - fn srcs_as_slice(&self) -> &[Src] { +impl AsSlice for OpPhiSrcs { + type Attr = SrcType; + + fn as_slice(&self) -> &[Src] { &self.srcs.b } - fn srcs_as_mut_slice(&mut self) -> &mut [Src] { + fn as_mut_slice(&mut self) -> &mut [Src] { &mut self.srcs.b } - fn src_types(&self) -> SrcTypeList { + fn attrs(&self) -> SrcTypeList { SrcTypeList::Uniform(SrcType::GPR) } } @@ -5821,16 +5863,18 @@ impl OpPhiDsts { } } -impl DstsAsSlice for OpPhiDsts { - fn dsts_as_slice(&self) -> &[Dst] { +impl AsSlice for OpPhiDsts { + type Attr = DstType; + + fn as_slice(&self) -> &[Dst] { &self.dsts.b } - fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { + fn as_mut_slice(&mut self) -> &mut [Dst] { &mut self.dsts.b } - fn dst_types(&self) -> DstTypeList { + fn attrs(&self) -> DstTypeList { DstTypeList::Uniform(DstType::Vec) } } @@ -5936,30 +5980,34 @@ impl OpParCopy { } } -impl SrcsAsSlice for OpParCopy { - fn srcs_as_slice(&self) -> &[Src] { +impl AsSlice for OpParCopy { + type Attr = SrcType; + + fn as_slice(&self) -> &[Src] { &self.dsts_srcs.b } - fn srcs_as_mut_slice(&mut self) -> &mut [Src] { + fn as_mut_slice(&mut self) -> &mut [Src] { &mut self.dsts_srcs.b } - fn src_types(&self) -> SrcTypeList { + fn attrs(&self) -> SrcTypeList { SrcTypeList::Uniform(SrcType::GPR) } } -impl DstsAsSlice for OpParCopy { - fn dsts_as_slice(&self) -> &[Dst] { +impl AsSlice for OpParCopy { + type Attr = DstType; + + fn as_slice(&self) -> &[Dst] { &self.dsts_srcs.a } - fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { + fn as_mut_slice(&mut self) -> &mut [Dst] { &mut self.dsts_srcs.a } - fn dst_types(&self) -> DstTypeList { + fn attrs(&self) -> DstTypeList { DstTypeList::Uniform(DstType::Vec) } } @@ -5988,16 +6036,18 @@ pub struct OpRegOut { pub srcs: Vec, } -impl SrcsAsSlice for OpRegOut { - fn srcs_as_slice(&self) -> &[Src] { +impl AsSlice for OpRegOut { + type Attr = SrcType; + + fn as_slice(&self) -> &[Src] { &self.srcs } - fn srcs_as_mut_slice(&mut self) -> &mut [Src] { + fn as_mut_slice(&mut self) -> &mut [Src] { &mut self.srcs } - fn src_types(&self) -> SrcTypeList { + fn attrs(&self) -> SrcTypeList { SrcTypeList::Uniform(SrcType::GPR) } } diff --git a/src/nouveau/compiler/nak/ir_proc.rs b/src/nouveau/compiler/nak/ir_proc.rs index 470547af4d2..3835d0a4498 100644 --- a/src/nouveau/compiler/nak/ir_proc.rs +++ b/src/nouveau/compiler/nak/ir_proc.rs @@ -26,10 +26,10 @@ fn expr_as_usize(expr: &syn::Expr) -> usize { .expect("Failed to parse integer literal") } -fn count_type(ty: &Type, search_type: &str) -> usize { +fn count_type(ty: &Type, slice_type: &str) -> usize { match ty { syn::Type::Array(a) => { - let elems = count_type(a.elem.as_ref(), search_type); + let elems = count_type(a.elem.as_ref(), slice_type); if elems > 0 { elems * expr_as_usize(&a.len) } else { @@ -37,7 +37,7 @@ fn count_type(ty: &Type, search_type: &str) -> usize { } } syn::Type::Path(p) => { - if p.qself.is_none() && p.path.is_ident(search_type) { + if p.qself.is_none() && p.path.is_ident(slice_type) { 1 } else { 0 @@ -47,10 +47,10 @@ fn count_type(ty: &Type, search_type: &str) -> usize { } } -fn get_type_attr(field: &Field, ty_attr: &str) -> Option { +fn get_attr(field: &Field, attr_name: &str) -> Option { for attr in &field.attrs { if let Meta::List(ml) = &attr.meta { - if ml.path.is_ident(ty_attr) { + if ml.path.is_ident(attr_name) { return Some(format!("{}", ml.tokens)); } } @@ -60,25 +60,14 @@ fn get_type_attr(field: &Field, ty_attr: &str) -> Option { fn derive_as_slice( input: TokenStream, - trait_name: &str, - func_prefix: &str, - search_type: &str, + slice_type: &str, + attr_name: &str, + attr_type: &str, ) -> TokenStream { let DeriveInput { attrs, ident, data, .. } = parse_macro_input!(input); - let trait_name = Ident::new(trait_name, Span::call_site()); - let elem_type = Ident::new(search_type, Span::call_site()); - let as_slice = - Ident::new(&format!("{func_prefix}s_as_slice"), Span::call_site()); - let as_mut_slice = - Ident::new(&format!("{func_prefix}s_as_mut_slice"), Span::call_site()); - let types_fn = - Ident::new(&format!("{func_prefix}_types"), Span::call_site()); - let ty_attr = format!("{func_prefix}_type"); - let ty_type = Ident::new(&format!("{search_type}Type"), Span::call_site()); - match data { Data::Struct(s) => { let mut has_repr_c = false; @@ -99,35 +88,37 @@ fn derive_as_slice( let mut first = None; let mut count = 0_usize; let mut found_last = false; - let mut types = TokenStream2::new(); + let mut attrs = TokenStream2::new(); if let Fields::Named(named) = s.fields { for f in named.named { - let ty_count = count_type(&f.ty, search_type); - let ty = get_type_attr(&f, &ty_attr); + let f_count = count_type(&f.ty, slice_type); + let f_attr = get_attr(&f, &attr_name); - if ty_count > 0 { + if f_count > 0 { assert!( !found_last, - "All fields of type {search_type} must be consecutive", + "All fields of type {slice_type} must be consecutive", ); - let ty = if let Some(s) = ty { + let attr_type = + Ident::new(attr_type, Span::call_site()); + let f_attr = if let Some(s) = f_attr { let s = syn::parse_str::(&s).unwrap(); - quote! { #ty_type::#s, } + quote! { #attr_type::#s, } } else { - quote! { #ty_type::DEFAULT, } + quote! { #attr_type::DEFAULT, } }; first.get_or_insert(f.ident); - for _ in 0..ty_count { - types.extend(ty.clone()); + for _ in 0..f_count { + attrs.extend(f_attr.clone()); } - count += ty_count; + count += f_count; } else { assert!( - ty.is_none(), - "{ty_attr} attribute is only allowed on {search_type}" + f_attr.is_none(), + "{attr_name} attribute is only allowed on {slice_type}" ); if !first.is_none() { found_last = true; @@ -138,42 +129,49 @@ fn derive_as_slice( panic!("Fields are not named"); } - if let Some(name) = first { + let slice_type = Ident::new(slice_type, Span::call_site()); + let attr_type = Ident::new(attr_type, Span::call_site()); + if let Some(first) = first { quote! { - impl #trait_name for #ident { - fn #as_slice(&self) -> &[#elem_type] { + impl AsSlice<#slice_type> for #ident { + type Attr = #attr_type; + + fn as_slice(&self) -> &[#slice_type] { unsafe { - let first = &self.#name as *const #elem_type; + let first = &self.#first as *const #slice_type; std::slice::from_raw_parts(first, #count) } } - fn #as_mut_slice(&mut self) -> &mut [#elem_type] { + fn as_mut_slice(&mut self) -> &mut [#slice_type] { unsafe { - let first = &mut self.#name as *mut #elem_type; + let first = + &mut self.#first as *mut #slice_type; std::slice::from_raw_parts_mut(first, #count) } } - fn #types_fn(&self) -> TypeList<#ty_type> { - static TYPES: [#ty_type; #count] = [#types]; - TypeList::Array(&TYPES) + fn attrs(&self) -> AttrList { + static ATTRS: [#attr_type; #count] = [#attrs]; + AttrList::Array(&ATTRS) } } } } else { quote! { - impl #trait_name for #ident { - fn #as_slice(&self) -> &[#elem_type] { + impl AsSlice<#slice_type> for #ident { + type Attr = #attr_type; + + fn as_slice(&self) -> &[#slice_type] { &[] } - fn #as_mut_slice(&mut self) -> &mut [#elem_type] { + fn as_mut_slice(&mut self) -> &mut [#slice_type] { &mut [] } - fn #types_fn(&self) -> TypeList<#ty_type> { - TypeList::Uniform(#ty_type::DEFAULT) + fn attrs(&self) -> AttrList { + AttrList::Uniform(#attr_type::DEFAULT) } } } @@ -184,33 +182,37 @@ fn derive_as_slice( let mut as_slice_cases = TokenStream2::new(); let mut as_mut_slice_cases = TokenStream2::new(); let mut types_cases = TokenStream2::new(); + let slice_type = Ident::new(slice_type, Span::call_site()); + let attr_type = Ident::new(attr_type, Span::call_site()); for v in e.variants { let case = v.ident; as_slice_cases.extend(quote! { - #ident::#case(x) => x.#as_slice(), + #ident::#case(x) => AsSlice::<#slice_type>::as_slice(x), }); as_mut_slice_cases.extend(quote! { - #ident::#case(x) => x.#as_mut_slice(), + #ident::#case(x) => AsSlice::<#slice_type>::as_mut_slice(x), }); types_cases.extend(quote! { - #ident::#case(x) => x.#types_fn(), + #ident::#case(x) => AsSlice::<#slice_type>::attrs(x), }); } quote! { - impl #trait_name for #ident { - fn #as_slice(&self) -> &[#elem_type] { + impl AsSlice<#slice_type> for #ident { + type Attr = #attr_type; + + fn as_slice(&self) -> &[#slice_type] { match self { #as_slice_cases } } - fn #as_mut_slice(&mut self) -> &mut [#elem_type] { + fn as_mut_slice(&mut self) -> &mut [#slice_type] { match self { #as_mut_slice_cases } } - fn #types_fn(&self) -> TypeList<#ty_type> { + fn attrs(&self) -> AttrList { match self { #types_cases } @@ -225,12 +227,12 @@ fn derive_as_slice( #[proc_macro_derive(SrcsAsSlice, attributes(src_type))] pub fn derive_srcs_as_slice(input: TokenStream) -> TokenStream { - derive_as_slice(input, "SrcsAsSlice", "src", "Src") + derive_as_slice(input, "Src", "src_type", "SrcType") } #[proc_macro_derive(DstsAsSlice, attributes(dst_type))] pub fn derive_dsts_as_slice(input: TokenStream) -> TokenStream { - derive_as_slice(input, "DstsAsSlice", "dst", "Dst") + derive_as_slice(input, "Dst", "dst_type", "DstType") } #[proc_macro_derive(DisplayOp)]