From b92afdbd2824fc0ca86454fc54e0e1843b45e7c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Sch=C3=BCrmann?= Date: Thu, 7 Nov 2024 23:56:30 +0100 Subject: [PATCH] aco/assembler: constify assembly functions Ensure that instruction formats and special operands are not manipulated during assembly. Part-of: --- src/amd/compiler/aco_assembler.cpp | 107 ++++++++++++++++------------- 1 file changed, 59 insertions(+), 48 deletions(-) diff --git a/src/amd/compiler/aco_assembler.cpp b/src/amd/compiler/aco_assembler.cpp index 7256687198f..8ca3711c8ee 100644 --- a/src/amd/compiler/aco_assembler.cpp +++ b/src/amd/compiler/aco_assembler.cpp @@ -132,7 +132,7 @@ get_gfx12_cpol(const T& instr) } void -emit_sop2_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_sop2_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; @@ -145,11 +145,12 @@ emit_sop2_instruction(asm_context& ctx, std::vector& out, Instruction* } void -emit_sopk_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_sopk_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - SALU_instruction& sopk = instr->salu(); + const SALU_instruction& sopk = instr->salu(); assert(sopk.imm <= UINT16_MAX); + uint16_t imm = sopk.imm; if (instr->opcode == aco_opcode::s_subvector_loop_begin) { assert(ctx.gfx_level >= GFX10); @@ -161,7 +162,7 @@ emit_sopk_instruction(asm_context& ctx, std::vector& out, Instruction* /* Adjust s_subvector_loop_begin instruction to the address after the end */ out[ctx.subvector_begin_pos] |= (out.size() - ctx.subvector_begin_pos); /* Adjust s_subvector_loop_end instruction to the address after the beginning */ - sopk.imm = (uint16_t)(ctx.subvector_begin_pos - (int)out.size()); + imm = (uint16_t)(ctx.subvector_begin_pos - (int)out.size()); ctx.subvector_begin_pos = -1; } @@ -172,12 +173,12 @@ emit_sopk_instruction(asm_context& ctx, std::vector& out, Instruction* : !instr->operands.empty() && instr->operands[0].physReg() <= 127 ? reg(ctx, instr->operands[0]) << 16 : 0; - encoding |= sopk.imm; + encoding |= imm; out.push_back(encoding); } void -emit_sop1_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_sop1_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; @@ -189,7 +190,7 @@ emit_sop1_instruction(asm_context& ctx, std::vector& out, Instruction* } void -emit_sopc_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_sopc_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; @@ -221,10 +222,10 @@ emit_sopp_instruction(asm_context& ctx, std::vector& out, Instruction* } void -emit_smem_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_smem_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - SMEM_instruction& smem = instr->smem(); + const SMEM_instruction& smem = instr->smem(); bool glc = smem.cache.value & ac_glc; bool dlc = smem.cache.value & ac_dlc; @@ -327,10 +328,10 @@ emit_smem_instruction(asm_context& ctx, std::vector& out, Instruction* } void -emit_vop2_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_vop2_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - VALU_instruction& valu = instr->valu(); + const VALU_instruction& valu = instr->valu(); uint32_t encoding = 0; encoding |= opcode << 25; @@ -344,10 +345,10 @@ emit_vop2_instruction(asm_context& ctx, std::vector& out, Instruction* } void -emit_vop1_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_vop1_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - VALU_instruction& valu = instr->valu(); + const VALU_instruction& valu = instr->valu(); uint32_t encoding = (0b0111111 << 25); if (!instr->definitions.empty()) { @@ -363,10 +364,10 @@ emit_vop1_instruction(asm_context& ctx, std::vector& out, Instruction* } void -emit_vopc_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_vopc_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - VALU_instruction& valu = instr->valu(); + const VALU_instruction& valu = instr->valu(); uint32_t encoding = (0b0111110 << 25); encoding |= opcode << 17; @@ -378,10 +379,10 @@ emit_vopc_instruction(asm_context& ctx, std::vector& out, Instruction* } void -emit_vintrp_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_vintrp_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - VINTRP_instruction& interp = instr->vintrp(); + const VINTRP_instruction& interp = instr->vintrp(); uint32_t encoding = 0; if (instr->opcode == aco_opcode::v_interp_p1ll_f16 || @@ -437,10 +438,11 @@ emit_vintrp_instruction(asm_context& ctx, std::vector& out, Instructio } void -emit_vinterp_inreg_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_vinterp_inreg_instruction(asm_context& ctx, std::vector& out, + const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - VINTERP_inreg_instruction& interp = instr->vinterp_inreg(); + const VINTERP_inreg_instruction& interp = instr->vinterp_inreg(); uint32_t encoding = (0b11001101 << 24); encoding |= reg(ctx, instr->definitions[0], 8); @@ -459,10 +461,10 @@ emit_vinterp_inreg_instruction(asm_context& ctx, std::vector& out, Ins } void -emit_vopd_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_vopd_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - VOPD_instruction& vopd = instr->vopd(); + const VOPD_instruction& vopd = instr->vopd(); uint32_t encoding = (0b110010 << 26); encoding |= reg(ctx, instr->operands[0]); @@ -483,10 +485,10 @@ emit_vopd_instruction(asm_context& ctx, std::vector& out, Instruction* } void -emit_ds_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_ds_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - DS_instruction& ds = instr->ds(); + const DS_instruction& ds = instr->ds(); uint32_t encoding = (0b110110 << 26); if (ctx.gfx_level == GFX8 || ctx.gfx_level == GFX9) { @@ -503,7 +505,7 @@ emit_ds_instruction(asm_context& ctx, std::vector& out, Instruction* i if (!instr->definitions.empty()) encoding |= reg(ctx, instr->definitions[0], 8) << 24; for (unsigned i = 0; i < MIN2(instr->operands.size(), 3); i++) { - Operand& op = instr->operands[i]; + const Operand& op = instr->operands[i]; if (op.physReg() != m0 && !op.isUndefined()) encoding |= reg(ctx, op, 8) << (8 * i); } @@ -511,10 +513,10 @@ emit_ds_instruction(asm_context& ctx, std::vector& out, Instruction* i } void -emit_ldsdir_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_ldsdir_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - LDSDIR_instruction& dir = instr->ldsdir(); + const LDSDIR_instruction& dir = instr->ldsdir(); uint32_t encoding = (0b11001110 << 24); encoding |= opcode << 20; @@ -528,10 +530,10 @@ emit_ldsdir_instruction(asm_context& ctx, std::vector& out, Instructio } void -emit_mubuf_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_mubuf_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - MUBUF_instruction& mubuf = instr->mubuf(); + const MUBUF_instruction& mubuf = instr->mubuf(); bool glc = mubuf.cache.value & ac_glc; bool slc = mubuf.cache.value & ac_slc; bool dlc = mubuf.cache.value & ac_dlc; @@ -583,10 +585,10 @@ emit_mubuf_instruction(asm_context& ctx, std::vector& out, Instruction } void -emit_mubuf_instruction_gfx12(asm_context& ctx, std::vector& out, Instruction* instr) +emit_mubuf_instruction_gfx12(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - MUBUF_instruction& mubuf = instr->mubuf(); + const MUBUF_instruction& mubuf = instr->mubuf(); assert(!mubuf.lds); uint32_t encoding = 0b110001 << 26; @@ -620,10 +622,10 @@ emit_mubuf_instruction_gfx12(asm_context& ctx, std::vector& out, Instr } void -emit_mtbuf_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_mtbuf_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - MTBUF_instruction& mtbuf = instr->mtbuf(); + const MTBUF_instruction& mtbuf = instr->mtbuf(); bool glc = mtbuf.cache.value & ac_glc; bool slc = mtbuf.cache.value & ac_slc; bool dlc = mtbuf.cache.value & ac_dlc; @@ -676,10 +678,10 @@ emit_mtbuf_instruction(asm_context& ctx, std::vector& out, Instruction } void -emit_mtbuf_instruction_gfx12(asm_context& ctx, std::vector& out, Instruction* instr) +emit_mtbuf_instruction_gfx12(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - MTBUF_instruction& mtbuf = instr->mtbuf(); + const MTBUF_instruction& mtbuf = instr->mtbuf(); uint32_t img_format = ac_get_tbuffer_format(ctx.gfx_level, mtbuf.dfmt, mtbuf.nfmt); @@ -714,10 +716,10 @@ emit_mtbuf_instruction_gfx12(asm_context& ctx, std::vector& out, Instr } void -emit_mimg_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_mimg_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - MIMG_instruction& mimg = instr->mimg(); + const MIMG_instruction& mimg = instr->mimg(); bool glc = mimg.cache.value & ac_glc; bool slc = mimg.cache.value & ac_slc; bool dlc = mimg.cache.value & ac_dlc; @@ -800,10 +802,10 @@ emit_mimg_instruction(asm_context& ctx, std::vector& out, Instruction* } void -emit_mimg_instruction_gfx12(asm_context& ctx, std::vector& out, Instruction* instr) +emit_mimg_instruction_gfx12(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - MIMG_instruction& mimg = instr->mimg(); + const MIMG_instruction& mimg = instr->mimg(); bool vsample = !instr->operands[1].isUndefined() || instr->opcode == aco_opcode::image_msaa_load; uint32_t encoding = opcode << 14; @@ -852,10 +854,10 @@ emit_mimg_instruction_gfx12(asm_context& ctx, std::vector& out, Instru } void -emit_flatlike_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_flatlike_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - FLAT_instruction& flat = instr->flatlike(); + const FLAT_instruction& flat = instr->flatlike(); bool glc = flat.cache.value & ac_glc; bool slc = flat.cache.value & ac_slc; bool dlc = flat.cache.value & ac_dlc; @@ -919,10 +921,11 @@ emit_flatlike_instruction(asm_context& ctx, std::vector& out, Instruct } void -emit_flatlike_instruction_gfx12(asm_context& ctx, std::vector& out, Instruction* instr) +emit_flatlike_instruction_gfx12(asm_context& ctx, std::vector& out, + const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - FLAT_instruction& flat = instr->flatlike(); + const FLAT_instruction& flat = instr->flatlike(); assert(!flat.lds); uint32_t encoding = opcode << 14; @@ -957,9 +960,9 @@ emit_flatlike_instruction_gfx12(asm_context& ctx, std::vector& out, In } void -emit_exp_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_exp_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { - Export_instruction& exp = instr->exp(); + const Export_instruction& exp = instr->exp(); uint32_t encoding; if (ctx.gfx_level == GFX8 || ctx.gfx_level == GFX9) { encoding = (0b110001 << 26); @@ -997,6 +1000,9 @@ emit_dpp16_instruction(asm_context& ctx, std::vector& out, Instruction instr->operands[0] = Operand(PhysReg{250}, v1); instr->format = (Format)((uint16_t)instr->format & ~(uint16_t)Format::DPP16); emit_instruction(ctx, out, instr); + instr->format = (Format)((uint16_t)instr->format | (uint16_t)Format::DPP16); + instr->operands[0] = dpp_op; + uint32_t encoding = (0xF & dpp.row_mask) << 28; encoding |= (0xF & dpp.bank_mask) << 24; encoding |= dpp.abs[1] << 23; @@ -1022,6 +1028,9 @@ emit_dpp8_instruction(asm_context& ctx, std::vector& out, Instruction* instr->operands[0] = Operand(PhysReg{233u + dpp.fetch_inactive}, v1); instr->format = (Format)((uint16_t)instr->format & ~(uint16_t)Format::DPP8); emit_instruction(ctx, out, instr); + instr->format = (Format)((uint16_t)instr->format | (uint16_t)Format::DPP8); + instr->operands[0] = dpp_op; + uint32_t encoding = reg(ctx, dpp_op, 8); encoding |= dpp.opsel[0] && !instr->isVOP3() ? 128 : 0; encoding |= dpp.lane_sel << 8; @@ -1029,10 +1038,10 @@ emit_dpp8_instruction(asm_context& ctx, std::vector& out, Instruction* } void -emit_vop3_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_vop3_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - VALU_instruction& vop3 = instr->valu(); + const VALU_instruction& vop3 = instr->valu(); if (instr->isVOP2()) { opcode = opcode + 0x100; @@ -1093,10 +1102,10 @@ emit_vop3_instruction(asm_context& ctx, std::vector& out, Instruction* } void -emit_vop3p_instruction(asm_context& ctx, std::vector& out, Instruction* instr) +emit_vop3p_instruction(asm_context& ctx, std::vector& out, const Instruction* instr) { uint32_t opcode = ctx.opcode[(int)instr->opcode]; - VALU_instruction& vop3 = instr->valu(); + const VALU_instruction& vop3 = instr->valu(); uint32_t encoding; if (ctx.gfx_level == GFX9) { @@ -1135,6 +1144,8 @@ emit_sdwa_instruction(asm_context& ctx, std::vector& out, Instruction* instr->operands[0] = Operand(PhysReg{249}, v1); instr->format = (Format)((uint16_t)instr->format & ~(uint16_t)Format::SDWA); emit_instruction(ctx, out, instr); + instr->format = (Format)((uint16_t)instr->format | (uint16_t)Format::SDWA); + instr->operands[0] = sdwa_op; uint32_t encoding = 0;