diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 114de22b508..d392b57e527 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -467,7 +467,11 @@ can_apply_sgprs(opt_ctx& ctx, aco_ptr& instr) instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 && instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 && instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 && - instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4; + instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4 && + instr->opcode != aco_opcode::v_wmma_f32_16x16x16_fp8_fp8 && + instr->opcode != aco_opcode::v_wmma_f32_16x16x16_fp8_bf8 && + instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf8_fp8 && + instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf8_bf8; } /* only covers special cases */ @@ -529,6 +533,10 @@ alu_can_accept_constant(const aco_ptr& instr, unsigned operand) case aco_opcode::v_dot2_bf16_bf16: /* TODO */ case aco_opcode::v_wmma_f32_16x16x16_f16: case aco_opcode::v_wmma_f32_16x16x16_bf16: + case aco_opcode::v_wmma_f32_16x16x16_fp8_fp8: + case aco_opcode::v_wmma_f32_16x16x16_fp8_bf8: + case aco_opcode::v_wmma_f32_16x16x16_bf8_fp8: + case aco_opcode::v_wmma_f32_16x16x16_bf8_bf8: case aco_opcode::v_wmma_f16_16x16x16_f16: case aco_opcode::v_wmma_bf16_16x16x16_bf16: case aco_opcode::v_wmma_i32_16x16x16_iu8: diff --git a/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp b/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp index 53c8e840814..e396741ac1f 100644 --- a/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp +++ b/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp @@ -3762,6 +3762,20 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr) neg_lo[0] = type_a == GLSL_TYPE_INT8; neg_lo[1] = type_b == GLSL_TYPE_INT8; break; + case GLSL_TYPE_FLOAT_E4M3FN: + switch (type_b) { + case GLSL_TYPE_FLOAT_E4M3FN: opcode = aco_opcode::v_wmma_f32_16x16x16_fp8_fp8; break; + case GLSL_TYPE_FLOAT_E5M2: opcode = aco_opcode::v_wmma_f32_16x16x16_fp8_bf8; break; + default: unreachable("invalid cmat_muladd_amd type"); + } + break; + case GLSL_TYPE_FLOAT_E5M2: + switch (type_b) { + case GLSL_TYPE_FLOAT_E4M3FN: opcode = aco_opcode::v_wmma_f32_16x16x16_bf8_fp8; break; + case GLSL_TYPE_FLOAT_E5M2: opcode = aco_opcode::v_wmma_f32_16x16x16_bf8_bf8; break; + default: unreachable("invalid cmat_muladd_amd type"); + } + break; } default: unreachable("invalid cmat_muladd_amd type"); }