diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp index 846e6ae6f30..cccf41b0313 100644 --- a/src/amd/compiler/aco_register_allocation.cpp +++ b/src/amd/compiler/aco_register_allocation.cpp @@ -208,12 +208,16 @@ struct DefInfo { PhysRegInterval bounds; uint8_t size; uint8_t stride; + /* Even if stride=4, we might be able to write to the high half instead without preserving the + * low half. In that case, data_stride=2. */ + uint8_t data_stride; RegClass rc; DefInfo(ra_ctx& ctx, aco_ptr<Instruction>& instr, RegClass rc_, int operand) : rc(rc_) { size = rc.size(); stride = get_stride(rc); + data_stride = 0; bounds = get_reg_bounds(ctx, rc); @@ -226,9 +230,8 @@ struct DefInfo { if (info.second > rc.bytes()) { rc = RegClass::get(rc.type(), info.second); size = rc.size(); - /* we might still be able to put the definition in the high half, - * but that's only useful for affinities and this information isn't - * used for them */ + if (info.second > stride) + data_stride = stride; stride = align(stride, info.second); if (!rc.is_subdword()) stride = DIV_ROUND_UP(stride, 4); @@ -249,6 +252,9 @@ struct DefInfo { if (imageGather4D16Bug) bounds.size -= MAX2(rc.bytes() / 4 - ctx.num_linear_vgprs, 0); } + + if (!data_stride) + data_stride = rc.is_subdword() ? 
stride : (stride * 4); } }; @@ -670,7 +676,6 @@ get_subdword_definition_info(Program* program, const aco_ptr<Instruction>& instr return std::make_pair(4u, 6u); break; } - default: break; } @@ -1379,45 +1384,33 @@ get_reg_impl(ra_ctx& ctx, const RegisterFile& reg_file, bool get_reg_specified(ra_ctx& ctx, const RegisterFile& reg_file, RegClass rc, - aco_ptr<Instruction>& instr, PhysReg reg) + aco_ptr<Instruction>& instr, PhysReg reg, int operand) { /* catch out-of-range registers */ if (reg >= PhysReg{512}) return false; - std::pair<unsigned, unsigned> sdw_def_info; - if (rc.is_subdword()) - sdw_def_info = get_subdword_definition_info(ctx.program, instr, rc); + DefInfo info(ctx, instr, rc, operand); - if (rc.is_subdword() && reg.byte() % sdw_def_info.first) - return false; - if (!rc.is_subdword() && reg.byte()) + if (reg.reg_b % info.data_stride) return false; - if (rc.type() == RegType::sgpr && reg % get_stride(rc) != 0) - return false; + assert(util_is_power_of_two_nonzero(info.stride)); + reg.reg_b &= ~(info.stride - 1); - PhysRegInterval reg_win = {reg, rc.size()}; - PhysRegInterval bounds = get_reg_bounds(ctx, rc); + PhysRegInterval reg_win = {PhysReg(reg.reg()), info.rc.size()}; PhysRegInterval vcc_win = {vcc, 2}; /* VCC is outside the bounds */ - bool is_vcc = rc.type() == RegType::sgpr && vcc_win.contains(reg_win) && ctx.program->needs_vcc; - bool is_m0 = rc == s1 && reg == m0 && can_write_m0(instr); - if (!bounds.contains(reg_win) && !is_vcc && !is_m0) + bool is_vcc = + info.rc.type() == RegType::sgpr && vcc_win.contains(reg_win) && ctx.program->needs_vcc; + bool is_m0 = info.rc == s1 && reg == m0 && can_write_m0(instr); + if (!info.bounds.contains(reg_win) && !is_vcc && !is_m0) return false; - if (rc.is_subdword()) { - PhysReg test_reg = reg; - if (sdw_def_info.second > rc.bytes()) - test_reg.reg_b &= ~(align(sdw_def_info.first, sdw_def_info.second) - 1); - if (reg_file.test(test_reg, sdw_def_info.second)) - return false; - } else { - if (reg_file.test(reg, rc.bytes())) - return false; - } + if (reg_file.test(reg, 
info.rc.bytes())) + return false; - adjust_max_used_regs(ctx, rc, reg_win.lo()); + adjust_max_used_regs(ctx, info.rc, reg_win.lo()); return true; } @@ -1562,7 +1555,8 @@ is_mimg_vaddr_intact(ra_ctx& ctx, const RegisterFile& reg_file, Instruction* ins } std::optional<PhysReg> -get_reg_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, aco_ptr<Instruction>& instr) +get_reg_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, aco_ptr<Instruction>& instr, + int operand) { Instruction* vec = ctx.vectors[temp.id()]; unsigned first_operand = vec->format == Format::MIMG ? 3 : 0; @@ -1587,7 +1581,7 @@ get_reg_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, aco_ptr<Ins if (reg) { reg->reg_b += our_offset; /* make sure to only use byte offset if the instruction supports it */ - if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, *reg)) + if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, *reg, operand)) return reg; } } @@ -1740,7 +1734,7 @@ get_reg(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, if (affinity.assigned) { PhysReg reg = affinity.reg; reg.reg_b -= offset; - if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, reg)) + if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, reg, operand_index)) return reg; } } @@ -1751,23 +1745,23 @@ get_reg(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, if (ctx.assignments[temp.id()].affinity) { assignment& affinity = ctx.assignments[ctx.assignments[temp.id()].affinity]; if (affinity.assigned) { - if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, affinity.reg)) + if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, affinity.reg, operand_index)) return affinity.reg; } } if (ctx.assignments[temp.id()].vcc) { - if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, vcc)) + if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, vcc, operand_index)) return vcc; } if (ctx.assignments[temp.id()].m0) { - if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, m0)) + if 
(get_reg_specified(ctx, reg_file, temp.regClass(), instr, m0, operand_index)) return m0; } std::optional<PhysReg> res; if (ctx.vectors.find(temp.id()) != ctx.vectors.end()) { - res = get_reg_vector(ctx, reg_file, temp, instr); + res = get_reg_vector(ctx, reg_file, temp, instr, operand_index); if (res) return *res; } @@ -1776,7 +1770,8 @@ get_reg(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, for (const Operand& op : instr->operands) { if (op.isTemp() && op.isFirstKillBeforeDef() && op.regClass() == temp.regClass()) { assert(op.isFixed()); - if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, op.physReg())) + if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, op.physReg(), + operand_index)) return op.physReg(); } } @@ -2242,7 +2237,7 @@ get_regs_for_phis(ra_ctx& ctx, Block& block, RegisterFile& register_file, if (!all_same) continue; - if (!get_reg_specified(ctx, register_file, definition.regClass(), phi, reg)) + if (!get_reg_specified(ctx, register_file, definition.regClass(), phi, reg, -1)) continue; definition.setFixed(reg); @@ -2261,7 +2256,7 @@ get_regs_for_phis(ra_ctx& ctx, Block& block, RegisterFile& register_file, ctx.assignments[ctx.assignments[definition.tempId()].affinity].assigned) { assignment& affinity = ctx.assignments[ctx.assignments[definition.tempId()].affinity]; assert(affinity.rc == definition.regClass()); - if (get_reg_specified(ctx, register_file, definition.regClass(), phi, affinity.reg)) { + if (get_reg_specified(ctx, register_file, definition.regClass(), phi, affinity.reg, -1)) { definition.setFixed(affinity.reg); register_file.fill(definition); ctx.assignments[definition.tempId()].set(definition); @@ -2276,7 +2271,7 @@ get_regs_for_phis(ra_ctx& ctx, Block& block, RegisterFile& register_file, continue; PhysReg reg = op.physReg(); - if (get_reg_specified(ctx, register_file, definition.regClass(), phi, reg)) { + if (get_reg_specified(ctx, register_file, definition.regClass(), phi, reg, -1)) { definition.setFixed(reg); 
register_file.fill(definition); ctx.assignments[definition.tempId()].set(definition); @@ -3147,18 +3142,18 @@ register_allocation(Program* program, ra_test_policy policy) RegClass rc = definition->regClass(); for (unsigned j = 0; j < i; j++) reg.reg_b += instr->definitions[j].bytes(); - if (get_reg_specified(ctx, register_file, rc, instr, reg)) { + if (get_reg_specified(ctx, register_file, rc, instr, reg, -1)) { definition->setFixed(reg); } else if (i == 0) { RegClass vec_rc = RegClass::get(rc.type(), instr->operands[0].bytes()); DefInfo info(ctx, ctx.pseudo_dummy, vec_rc, -1); std::optional<PhysReg> res = get_reg_simple(ctx, register_file, info); - if (res && get_reg_specified(ctx, register_file, rc, instr, *res)) + if (res && get_reg_specified(ctx, register_file, rc, instr, *res, -1)) definition->setFixed(*res); } else if (instr->definitions[i - 1].isFixed()) { reg = instr->definitions[i - 1].physReg(); reg.reg_b += instr->definitions[i - 1].bytes(); - if (get_reg_specified(ctx, register_file, rc, instr, reg)) + if (get_reg_specified(ctx, register_file, rc, instr, reg, -1)) definition->setFixed(reg); } } else if (instr->opcode == aco_opcode::p_parallelcopy) { @@ -3170,7 +3165,7 @@ register_allocation(Program* program, ra_test_policy policy) } else if (instr->opcode == aco_opcode::p_extract_vector) { PhysReg reg = instr->operands[0].physReg(); reg.reg_b += definition->bytes() * instr->operands[1].constantValue(); - if (get_reg_specified(ctx, register_file, definition->regClass(), instr, reg)) + if (get_reg_specified(ctx, register_file, definition->regClass(), instr, reg, -1)) definition->setFixed(reg); } else if (instr->opcode == aco_opcode::p_create_vector) { PhysReg reg = get_reg_create_vector(ctx, register_file, definition->getTemp(),