nak: Replace the guts of Srcs/DstsAsSlice with a new AsSlice trait

This new trait is way more generic and shareable.  It does mean a bit of
gymnastics with traits to keep from retyping the whole compiler but the
result is something we can potentially share with other compilers.

Reviewed-by: Christian Gmeiner <cgmeiner@igalia.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/30443>
This commit is contained in:
Faith Ekstrand
2024-07-31 11:32:05 -05:00
committed by Marge Bot
parent bc58881b9f
commit 3ca448a549
2 changed files with 174 additions and 122 deletions

View File

@@ -1398,6 +1398,30 @@ impl fmt::Display for Src {
}
}
pub enum AttrList<T: 'static> {
Array(&'static [T]),
Uniform(T),
}
impl<T: 'static> Index<usize> for AttrList<T> {
type Output = T;
fn index(&self, idx: usize) -> &T {
match self {
AttrList::Array(arr) => &arr[idx],
AttrList::Uniform(typ) => typ,
}
}
}
pub trait AsSlice<T> {
type Attr;
fn as_slice(&self) -> &[T];
fn as_mut_slice(&mut self) -> &mut [T];
fn attrs(&self) -> AttrList<Self::Attr>;
}
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SrcType {
@@ -1419,28 +1443,20 @@ impl SrcType {
const DEFAULT: SrcType = SrcType::GPR;
}
pub enum TypeList<T: 'static> {
Array(&'static [T]),
Uniform(T),
}
pub type SrcTypeList = AttrList<SrcType>;
impl<T: 'static> Index<usize> for TypeList<T> {
type Output = T;
fn index(&self, idx: usize) -> &T {
match self {
TypeList::Array(arr) => &arr[idx],
TypeList::Uniform(typ) => typ,
}
pub trait SrcsAsSlice: AsSlice<Src, Attr = SrcType> {
fn srcs_as_slice(&self) -> &[Src] {
self.as_slice()
}
}
pub type SrcTypeList = TypeList<SrcType>;
fn srcs_as_mut_slice(&mut self) -> &mut [Src] {
self.as_mut_slice()
}
pub trait SrcsAsSlice {
fn srcs_as_slice(&self) -> &[Src];
fn srcs_as_mut_slice(&mut self) -> &mut [Src];
fn src_types(&self) -> SrcTypeList;
fn src_types(&self) -> SrcTypeList {
self.attrs()
}
fn src_idx(&self, src: &Src) -> usize {
let r = self.srcs_as_slice().as_ptr_range();
@@ -1449,6 +1465,8 @@ pub trait SrcsAsSlice {
}
}
impl<T: AsSlice<Src, Attr = SrcType>> SrcsAsSlice for T {}
fn all_dsts_uniform(dsts: &[Dst]) -> bool {
let mut uniform = None;
for dst in dsts {
@@ -1481,12 +1499,20 @@ impl DstType {
const DEFAULT: DstType = DstType::Vec;
}
pub type DstTypeList = TypeList<DstType>;
pub type DstTypeList = AttrList<DstType>;
pub trait DstsAsSlice {
fn dsts_as_slice(&self) -> &[Dst];
fn dsts_as_mut_slice(&mut self) -> &mut [Dst];
fn dst_types(&self) -> DstTypeList;
pub trait DstsAsSlice: AsSlice<Dst, Attr = DstType> {
fn dsts_as_slice(&self) -> &[Dst] {
self.as_slice()
}
fn dsts_as_mut_slice(&mut self) -> &mut [Dst] {
self.as_mut_slice()
}
fn dst_types(&self) -> DstTypeList {
self.attrs()
}
fn dst_idx(&self, dst: &Dst) -> usize {
let r = self.dsts_as_slice().as_ptr_range();
@@ -1495,6 +1521,8 @@ pub trait DstsAsSlice {
}
}
impl<T: AsSlice<Dst, Attr = DstType>> DstsAsSlice for T {}
pub trait IsUniform {
fn is_uniform(&self) -> bool;
}
@@ -3993,16 +4021,18 @@ pub struct OpF2F {
pub integer_rnd: bool,
}
impl SrcsAsSlice for OpF2F {
fn srcs_as_slice(&self) -> &[Src] {
impl AsSlice<Src> for OpF2F {
type Attr = SrcType;
fn as_slice(&self) -> &[Src] {
std::slice::from_ref(&self.src)
}
fn srcs_as_mut_slice(&mut self) -> &mut [Src] {
fn as_mut_slice(&mut self) -> &mut [Src] {
std::slice::from_mut(&mut self.src)
}
fn src_types(&self) -> SrcTypeList {
fn attrs(&self) -> SrcTypeList {
let src_type = match self.src_type {
FloatType::F16 => SrcType::F16,
FloatType::F32 => SrcType::F32,
@@ -4012,16 +4042,18 @@ impl SrcsAsSlice for OpF2F {
}
}
impl DstsAsSlice for OpF2F {
fn dsts_as_slice(&self) -> &[Dst] {
impl AsSlice<Dst> for OpF2F {
type Attr = DstType;
fn as_slice(&self) -> &[Dst] {
std::slice::from_ref(&self.dst)
}
fn dsts_as_mut_slice(&mut self) -> &mut [Dst] {
fn as_mut_slice(&mut self) -> &mut [Dst] {
std::slice::from_mut(&mut self.dst)
}
fn dst_types(&self) -> DstTypeList {
fn attrs(&self) -> DstTypeList {
let dst_type = match self.dst_type {
FloatType::F16 => DstType::F16,
FloatType::F32 => DstType::F32,
@@ -4063,16 +4095,18 @@ pub struct OpF2I {
pub ftz: bool,
}
impl SrcsAsSlice for OpF2I {
fn srcs_as_slice(&self) -> &[Src] {
impl AsSlice<Src> for OpF2I {
type Attr = SrcType;
fn as_slice(&self) -> &[Src] {
std::slice::from_ref(&self.src)
}
fn srcs_as_mut_slice(&mut self) -> &mut [Src] {
fn as_mut_slice(&mut self) -> &mut [Src] {
std::slice::from_mut(&mut self.src)
}
fn src_types(&self) -> SrcTypeList {
fn attrs(&self) -> SrcTypeList {
let src_type = match self.src_type {
FloatType::F16 => SrcType::F16,
FloatType::F32 => SrcType::F32,
@@ -4104,16 +4138,18 @@ pub struct OpI2F {
pub rnd_mode: FRndMode,
}
impl SrcsAsSlice for OpI2F {
fn srcs_as_slice(&self) -> &[Src] {
impl AsSlice<Src> for OpI2F {
type Attr = SrcType;
fn as_slice(&self) -> &[Src] {
std::slice::from_ref(&self.src)
}
fn srcs_as_mut_slice(&mut self) -> &mut [Src] {
fn as_mut_slice(&mut self) -> &mut [Src] {
std::slice::from_mut(&mut self.src)
}
fn src_types(&self) -> SrcTypeList {
fn attrs(&self) -> SrcTypeList {
if self.src_type.bits() <= 32 {
SrcTypeList::Uniform(SrcType::ALU)
} else {
@@ -4122,16 +4158,18 @@ impl SrcsAsSlice for OpI2F {
}
}
impl DstsAsSlice for OpI2F {
fn dsts_as_slice(&self) -> &[Dst] {
impl AsSlice<Dst> for OpI2F {
type Attr = DstType;
fn as_slice(&self) -> &[Dst] {
std::slice::from_ref(&self.dst)
}
fn dsts_as_mut_slice(&mut self) -> &mut [Dst] {
fn as_mut_slice(&mut self) -> &mut [Dst] {
std::slice::from_mut(&mut self.dst)
}
fn dst_types(&self) -> DstTypeList {
fn attrs(&self) -> DstTypeList {
let dst_type = match self.dst_type {
FloatType::F16 => DstType::F16,
FloatType::F32 => DstType::F32,
@@ -4202,16 +4240,18 @@ pub struct OpFRnd {
pub ftz: bool,
}
impl SrcsAsSlice for OpFRnd {
fn srcs_as_slice(&self) -> &[Src] {
impl AsSlice<Src> for OpFRnd {
type Attr = SrcType;
fn as_slice(&self) -> &[Src] {
std::slice::from_ref(&self.src)
}
fn srcs_as_mut_slice(&mut self) -> &mut [Src] {
fn as_mut_slice(&mut self) -> &mut [Src] {
std::slice::from_mut(&mut self.src)
}
fn src_types(&self) -> SrcTypeList {
fn attrs(&self) -> SrcTypeList {
let src_type = match self.src_type {
FloatType::F16 => SrcType::F16,
FloatType::F32 => SrcType::F32,
@@ -5775,16 +5815,18 @@ impl OpPhiSrcs {
}
}
impl SrcsAsSlice for OpPhiSrcs {
fn srcs_as_slice(&self) -> &[Src] {
impl AsSlice<Src> for OpPhiSrcs {
type Attr = SrcType;
fn as_slice(&self) -> &[Src] {
&self.srcs.b
}
fn srcs_as_mut_slice(&mut self) -> &mut [Src] {
fn as_mut_slice(&mut self) -> &mut [Src] {
&mut self.srcs.b
}
fn src_types(&self) -> SrcTypeList {
fn attrs(&self) -> SrcTypeList {
SrcTypeList::Uniform(SrcType::GPR)
}
}
@@ -5821,16 +5863,18 @@ impl OpPhiDsts {
}
}
impl DstsAsSlice for OpPhiDsts {
fn dsts_as_slice(&self) -> &[Dst] {
impl AsSlice<Dst> for OpPhiDsts {
type Attr = DstType;
fn as_slice(&self) -> &[Dst] {
&self.dsts.b
}
fn dsts_as_mut_slice(&mut self) -> &mut [Dst] {
fn as_mut_slice(&mut self) -> &mut [Dst] {
&mut self.dsts.b
}
fn dst_types(&self) -> DstTypeList {
fn attrs(&self) -> DstTypeList {
DstTypeList::Uniform(DstType::Vec)
}
}
@@ -5936,30 +5980,34 @@ impl OpParCopy {
}
}
impl SrcsAsSlice for OpParCopy {
fn srcs_as_slice(&self) -> &[Src] {
impl AsSlice<Src> for OpParCopy {
type Attr = SrcType;
fn as_slice(&self) -> &[Src] {
&self.dsts_srcs.b
}
fn srcs_as_mut_slice(&mut self) -> &mut [Src] {
fn as_mut_slice(&mut self) -> &mut [Src] {
&mut self.dsts_srcs.b
}
fn src_types(&self) -> SrcTypeList {
fn attrs(&self) -> SrcTypeList {
SrcTypeList::Uniform(SrcType::GPR)
}
}
impl DstsAsSlice for OpParCopy {
fn dsts_as_slice(&self) -> &[Dst] {
impl AsSlice<Dst> for OpParCopy {
type Attr = DstType;
fn as_slice(&self) -> &[Dst] {
&self.dsts_srcs.a
}
fn dsts_as_mut_slice(&mut self) -> &mut [Dst] {
fn as_mut_slice(&mut self) -> &mut [Dst] {
&mut self.dsts_srcs.a
}
fn dst_types(&self) -> DstTypeList {
fn attrs(&self) -> DstTypeList {
DstTypeList::Uniform(DstType::Vec)
}
}
@@ -5988,16 +6036,18 @@ pub struct OpRegOut {
pub srcs: Vec<Src>,
}
impl SrcsAsSlice for OpRegOut {
fn srcs_as_slice(&self) -> &[Src] {
impl AsSlice<Src> for OpRegOut {
type Attr = SrcType;
fn as_slice(&self) -> &[Src] {
&self.srcs
}
fn srcs_as_mut_slice(&mut self) -> &mut [Src] {
fn as_mut_slice(&mut self) -> &mut [Src] {
&mut self.srcs
}
fn src_types(&self) -> SrcTypeList {
fn attrs(&self) -> SrcTypeList {
SrcTypeList::Uniform(SrcType::GPR)
}
}

View File

@@ -26,10 +26,10 @@ fn expr_as_usize(expr: &syn::Expr) -> usize {
.expect("Failed to parse integer literal")
}
fn count_type(ty: &Type, search_type: &str) -> usize {
fn count_type(ty: &Type, slice_type: &str) -> usize {
match ty {
syn::Type::Array(a) => {
let elems = count_type(a.elem.as_ref(), search_type);
let elems = count_type(a.elem.as_ref(), slice_type);
if elems > 0 {
elems * expr_as_usize(&a.len)
} else {
@@ -37,7 +37,7 @@ fn count_type(ty: &Type, search_type: &str) -> usize {
}
}
syn::Type::Path(p) => {
if p.qself.is_none() && p.path.is_ident(search_type) {
if p.qself.is_none() && p.path.is_ident(slice_type) {
1
} else {
0
@@ -47,10 +47,10 @@ fn count_type(ty: &Type, search_type: &str) -> usize {
}
}
fn get_type_attr(field: &Field, ty_attr: &str) -> Option<String> {
fn get_attr(field: &Field, attr_name: &str) -> Option<String> {
for attr in &field.attrs {
if let Meta::List(ml) = &attr.meta {
if ml.path.is_ident(ty_attr) {
if ml.path.is_ident(attr_name) {
return Some(format!("{}", ml.tokens));
}
}
@@ -60,25 +60,14 @@ fn get_type_attr(field: &Field, ty_attr: &str) -> Option<String> {
fn derive_as_slice(
input: TokenStream,
trait_name: &str,
func_prefix: &str,
search_type: &str,
slice_type: &str,
attr_name: &str,
attr_type: &str,
) -> TokenStream {
let DeriveInput {
attrs, ident, data, ..
} = parse_macro_input!(input);
let trait_name = Ident::new(trait_name, Span::call_site());
let elem_type = Ident::new(search_type, Span::call_site());
let as_slice =
Ident::new(&format!("{func_prefix}s_as_slice"), Span::call_site());
let as_mut_slice =
Ident::new(&format!("{func_prefix}s_as_mut_slice"), Span::call_site());
let types_fn =
Ident::new(&format!("{func_prefix}_types"), Span::call_site());
let ty_attr = format!("{func_prefix}_type");
let ty_type = Ident::new(&format!("{search_type}Type"), Span::call_site());
match data {
Data::Struct(s) => {
let mut has_repr_c = false;
@@ -99,35 +88,37 @@ fn derive_as_slice(
let mut first = None;
let mut count = 0_usize;
let mut found_last = false;
let mut types = TokenStream2::new();
let mut attrs = TokenStream2::new();
if let Fields::Named(named) = s.fields {
for f in named.named {
let ty_count = count_type(&f.ty, search_type);
let ty = get_type_attr(&f, &ty_attr);
let f_count = count_type(&f.ty, slice_type);
let f_attr = get_attr(&f, &attr_name);
if ty_count > 0 {
if f_count > 0 {
assert!(
!found_last,
"All fields of type {search_type} must be consecutive",
"All fields of type {slice_type} must be consecutive",
);
let ty = if let Some(s) = ty {
let attr_type =
Ident::new(attr_type, Span::call_site());
let f_attr = if let Some(s) = f_attr {
let s = syn::parse_str::<Ident>(&s).unwrap();
quote! { #ty_type::#s, }
quote! { #attr_type::#s, }
} else {
quote! { #ty_type::DEFAULT, }
quote! { #attr_type::DEFAULT, }
};
first.get_or_insert(f.ident);
for _ in 0..ty_count {
types.extend(ty.clone());
for _ in 0..f_count {
attrs.extend(f_attr.clone());
}
count += ty_count;
count += f_count;
} else {
assert!(
ty.is_none(),
"{ty_attr} attribute is only allowed on {search_type}"
f_attr.is_none(),
"{attr_name} attribute is only allowed on {slice_type}"
);
if !first.is_none() {
found_last = true;
@@ -138,42 +129,49 @@ fn derive_as_slice(
panic!("Fields are not named");
}
if let Some(name) = first {
let slice_type = Ident::new(slice_type, Span::call_site());
let attr_type = Ident::new(attr_type, Span::call_site());
if let Some(first) = first {
quote! {
impl #trait_name for #ident {
fn #as_slice(&self) -> &[#elem_type] {
impl AsSlice<#slice_type> for #ident {
type Attr = #attr_type;
fn as_slice(&self) -> &[#slice_type] {
unsafe {
let first = &self.#name as *const #elem_type;
let first = &self.#first as *const #slice_type;
std::slice::from_raw_parts(first, #count)
}
}
fn #as_mut_slice(&mut self) -> &mut [#elem_type] {
fn as_mut_slice(&mut self) -> &mut [#slice_type] {
unsafe {
let first = &mut self.#name as *mut #elem_type;
let first =
&mut self.#first as *mut #slice_type;
std::slice::from_raw_parts_mut(first, #count)
}
}
fn #types_fn(&self) -> TypeList<#ty_type> {
static TYPES: [#ty_type; #count] = [#types];
TypeList::Array(&TYPES)
fn attrs(&self) -> AttrList<Self::Attr> {
static ATTRS: [#attr_type; #count] = [#attrs];
AttrList::Array(&ATTRS)
}
}
}
} else {
quote! {
impl #trait_name for #ident {
fn #as_slice(&self) -> &[#elem_type] {
impl AsSlice<#slice_type> for #ident {
type Attr = #attr_type;
fn as_slice(&self) -> &[#slice_type] {
&[]
}
fn #as_mut_slice(&mut self) -> &mut [#elem_type] {
fn as_mut_slice(&mut self) -> &mut [#slice_type] {
&mut []
}
fn #types_fn(&self) -> TypeList<#ty_type> {
TypeList::Uniform(#ty_type::DEFAULT)
fn attrs(&self) -> AttrList<Self::Attr> {
AttrList::Uniform(#attr_type::DEFAULT)
}
}
}
@@ -184,33 +182,37 @@ fn derive_as_slice(
let mut as_slice_cases = TokenStream2::new();
let mut as_mut_slice_cases = TokenStream2::new();
let mut types_cases = TokenStream2::new();
let slice_type = Ident::new(slice_type, Span::call_site());
let attr_type = Ident::new(attr_type, Span::call_site());
for v in e.variants {
let case = v.ident;
as_slice_cases.extend(quote! {
#ident::#case(x) => x.#as_slice(),
#ident::#case(x) => AsSlice::<#slice_type>::as_slice(x),
});
as_mut_slice_cases.extend(quote! {
#ident::#case(x) => x.#as_mut_slice(),
#ident::#case(x) => AsSlice::<#slice_type>::as_mut_slice(x),
});
types_cases.extend(quote! {
#ident::#case(x) => x.#types_fn(),
#ident::#case(x) => AsSlice::<#slice_type>::attrs(x),
});
}
quote! {
impl #trait_name for #ident {
fn #as_slice(&self) -> &[#elem_type] {
impl AsSlice<#slice_type> for #ident {
type Attr = #attr_type;
fn as_slice(&self) -> &[#slice_type] {
match self {
#as_slice_cases
}
}
fn #as_mut_slice(&mut self) -> &mut [#elem_type] {
fn as_mut_slice(&mut self) -> &mut [#slice_type] {
match self {
#as_mut_slice_cases
}
}
fn #types_fn(&self) -> TypeList<#ty_type> {
fn attrs(&self) -> AttrList<Self::Attr> {
match self {
#types_cases
}
@@ -225,12 +227,12 @@ fn derive_as_slice(
#[proc_macro_derive(SrcsAsSlice, attributes(src_type))]
pub fn derive_srcs_as_slice(input: TokenStream) -> TokenStream {
derive_as_slice(input, "SrcsAsSlice", "src", "Src")
derive_as_slice(input, "Src", "src_type", "SrcType")
}
#[proc_macro_derive(DstsAsSlice, attributes(dst_type))]
pub fn derive_dsts_as_slice(input: TokenStream) -> TokenStream {
derive_as_slice(input, "DstsAsSlice", "dst", "Dst")
derive_as_slice(input, "Dst", "dst_type", "DstType")
}
#[proc_macro_derive(DisplayOp)]