From 369a3c0dca82e12f104b4a3c1d2869aa4e764b5b Mon Sep 17 00:00:00 2001 From: Natalie Vock Date: Mon, 17 Feb 2025 18:42:49 +0100 Subject: [PATCH] aco/spill: Handle calls Part-of: --- src/amd/compiler/aco_spill.cpp | 93 ++++++++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 22 deletions(-) diff --git a/src/amd/compiler/aco_spill.cpp b/src/amd/compiler/aco_spill.cpp index 0f9e06bc615..89bbca03326 100644 --- a/src/amd/compiler/aco_spill.cpp +++ b/src/amd/compiler/aco_spill.cpp @@ -89,16 +89,18 @@ struct spill_ctx { unsigned vgpr_spill_slots; Temp scratch_rsrc; + unsigned extra_vgprs; + unsigned resume_idx; - spill_ctx(const RegisterDemand target_pressure_, Program* program_) + spill_ctx(const RegisterDemand target_pressure_, Program* program_, unsigned extra_vgprs_) : target_pressure(target_pressure_), program(program_), memory(), renames(program->blocks.size(), aco::map(memory)), spills_entry(program->blocks.size(), aco::unordered_map(memory)), spills_exit(program->blocks.size(), aco::unordered_map(memory)), processed(program->blocks.size(), false), ssa_infos(program->peekAllocationId()), remat(memory), wave_size(program->wave_size), sgpr_spill_slots(0), vgpr_spill_slots(0), - resume_idx(0) + extra_vgprs(extra_vgprs_), resume_idx(0) {} void add_affinity(uint32_t first, uint32_t second) @@ -322,9 +324,13 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx) /* check how many live-through variables should be spilled */ RegisterDemand reg_pressure = block->live_in_demand; RegisterDemand loop_demand = reg_pressure; + RegisterDemand loop_call_spills = RegisterDemand(); unsigned i = block_idx; - while (ctx.program->blocks[i].loop_nest_depth >= block->loop_nest_depth) - loop_demand.update(ctx.program->blocks[i++].register_demand); + while (ctx.program->blocks[i].loop_nest_depth >= block->loop_nest_depth) { + loop_demand.update(ctx.program->blocks[i].register_demand); + loop_call_spills.update(ctx.program->blocks[i].call_spills); + ++i; + } for (auto spilled : ctx.spills_exit[block_idx - 1]) { /* variable is not live at loop entry: probably a phi operand */ @@ -352,12 +358,15 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx) /* select more live-through variables and constants */ RegType type = RegType::vgpr; - while (loop_demand.exceeds(ctx.target_pressure)) { + while (loop_demand.exceeds(ctx.target_pressure) || + loop_call_spills.exceeds(spilled_registers)) { /* if VGPR demand is low enough, select SGPRs */ - if (type == RegType::vgpr && loop_demand.vgpr <= ctx.target_pressure.vgpr) + if (type == RegType::vgpr && loop_demand.vgpr <= ctx.target_pressure.vgpr && + loop_call_spills.vgpr <= spilled_registers.vgpr) type = RegType::sgpr; /* if SGPR demand is low enough, break */ - if (type == RegType::sgpr && loop_demand.sgpr <= ctx.target_pressure.sgpr) + if (type == RegType::sgpr && loop_demand.sgpr <= ctx.target_pressure.sgpr && + loop_call_spills.sgpr <= spilled_registers.sgpr) break; float score = 0.0; @@ -932,12 +941,28 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s } /* check if register demand is low enough during and after the current instruction */ - if (block->register_demand.exceeds(ctx.target_pressure)) { + if (block->register_demand.exceeds(ctx.target_pressure) || instr->isCall()) { RegisterDemand new_demand = instr->register_demand; std::optional live_changes; /* if reg pressure is too high, spill variable with furthest next use */ - while ((new_demand - spilled_registers).exceeds(ctx.target_pressure)) { + while (true) { + bool needs_spill = (new_demand - spilled_registers).exceeds(ctx.target_pressure); + if (instr->isCall()) { + RegisterDemand call_preserved_limit = instr->call().callee_preserved_limit; + + /* Exclude the linear VGPRs created for spilling SGPRs from the limit, + * if they are placed in clobbered register ranges (i.e. the preserved limit + * can't fit all of them). The preserved spiller will take care of those. */ + call_preserved_limit.vgpr = + MAX2(call_preserved_limit.vgpr - (int16_t)ctx.extra_vgprs, 0); + + needs_spill |= (instr->call().caller_preserved_demand - spilled_registers) + .exceeds(call_preserved_limit); + } + if (!needs_spill) + break; + float score = 0.0; Temp to_spill = Temp(); bool spill_is_operand = false; @@ -947,7 +972,15 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s unsigned avoid_respill = 0; RegType type = RegType::sgpr; - if (new_demand.vgpr - spilled_registers.vgpr > ctx.target_pressure.vgpr) + bool spill_vgpr = new_demand.vgpr - spilled_registers.vgpr > ctx.target_pressure.vgpr; + if (instr->isCall()) { + RegisterDemand call_preserved_limit = instr->call().callee_preserved_limit; + call_preserved_limit.vgpr = + MAX2(call_preserved_limit.vgpr - (int16_t)ctx.extra_vgprs, 0); + spill_vgpr |= instr->call().caller_preserved_demand.vgpr - spilled_registers.vgpr > + call_preserved_limit.vgpr; + } + if (spill_vgpr) type = RegType::vgpr; for (unsigned t : ctx.program->live.live_in[block_idx]) { @@ -970,19 +1003,32 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s for (auto& op : instr->operands) { if (!op.isTemp() || op.getTemp() != var) continue; - /* Spilling vector operands causes us to emit a split_vector, increasing live - * state temporarily. - */ - if (op.isLateKill() || op.isKill() || op.size() > 1) { + if (op.isLateKill() || op.isKill()) { can_spill = false; break; } + /* Spilling vector operands causes us to emit a split_vector, increasing live + * state temporarily. This is ok when we have enough register headroom (when + * spilling calls), but causes spilling to fail otherwise. + */ + if (op.size() > 1) { + RegisterDemand before = new_demand - spilled_registers; + before -= get_temp_registers(instr.get()) - get_live_changes(instr.get()); + if ((before + op.getTemp()).exceeds(ctx.target_pressure)) { + can_spill = false; + break; + } + } if (!live_changes) live_changes = get_temp_reg_changes(instr.get()); /* Don't spill operands if killing operands won't help with register pressure */ - if (!op.isClobbered() && RegisterDemand(op.getTemp()).exceeds(*live_changes)) { + if (!instr->isCall() && !op.isClobbered() && RegisterDemand(op.getTemp()).exceeds(*live_changes)) { + can_spill = false; + break; + } + if (instr->isCall() && !op.isClobbered()) { can_spill = false; break; } @@ -1653,20 +1699,23 @@ spill(Program* program) uint16_t extra_sgprs = 0; /* calculate extra VGPRs required for spilling SGPRs */ - if (demand.sgpr > limit.sgpr) { - unsigned sgpr_spills = demand.sgpr - limit.sgpr; + unsigned sgpr_spills = demand.sgpr - std::min((uint16_t)demand.sgpr, (uint16_t)limit.sgpr); + sgpr_spills += program->max_call_spills.sgpr; + + if (sgpr_spills) extra_vgprs = DIV_ROUND_UP(sgpr_spills * 2, program->wave_size) + 1; - } /* add extra SGPRs required for spilling VGPRs */ - if (demand.vgpr + extra_vgprs > limit.vgpr) { + if (demand.vgpr + extra_vgprs > limit.vgpr || program->max_call_spills.vgpr) { if (program->gfx_level >= GFX9) extra_sgprs = program->stack_ptr.id() ? 2 : 1; /* SADDR + scc for stack pointer additions */ else extra_sgprs = 5; /* scratch_resource (s4) + scratch_offset (s1) */ - if (demand.sgpr + extra_sgprs > limit.sgpr) { + if (demand.sgpr + extra_sgprs > limit.sgpr || program->max_call_spills.sgpr) { /* re-calculate in case something has changed */ - unsigned sgpr_spills = demand.sgpr + extra_sgprs - limit.sgpr; + sgpr_spills = program->max_call_spills.sgpr; + if (demand.sgpr + extra_sgprs > limit.sgpr) + sgpr_spills += demand.sgpr + extra_sgprs - limit.sgpr; extra_vgprs = DIV_ROUND_UP(sgpr_spills * 2, program->wave_size) + 1; } } @@ -1674,7 +1723,7 @@ spill(Program* program) const RegisterDemand target(limit.vgpr - extra_vgprs, limit.sgpr - extra_sgprs); /* initialize ctx */ - spill_ctx ctx(target, program); + spill_ctx ctx(target, program, extra_vgprs); gather_ssa_use_info(ctx); get_rematerialize_info(ctx);