aco/spill: Handle calls
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38281>
This commit is contained in:
@@ -89,16 +89,18 @@ struct spill_ctx {
|
||||
unsigned vgpr_spill_slots;
|
||||
Temp scratch_rsrc;
|
||||
|
||||
unsigned extra_vgprs;
|
||||
|
||||
unsigned resume_idx;
|
||||
|
||||
spill_ctx(const RegisterDemand target_pressure_, Program* program_)
|
||||
spill_ctx(const RegisterDemand target_pressure_, Program* program_, unsigned extra_vgprs_)
|
||||
: target_pressure(target_pressure_), program(program_), memory(),
|
||||
renames(program->blocks.size(), aco::map<Temp, Temp>(memory)),
|
||||
spills_entry(program->blocks.size(), aco::unordered_map<Temp, uint32_t>(memory)),
|
||||
spills_exit(program->blocks.size(), aco::unordered_map<Temp, uint32_t>(memory)),
|
||||
processed(program->blocks.size(), false), ssa_infos(program->peekAllocationId()),
|
||||
remat(memory), wave_size(program->wave_size), sgpr_spill_slots(0), vgpr_spill_slots(0),
|
||||
resume_idx(0)
|
||||
extra_vgprs(extra_vgprs_), resume_idx(0)
|
||||
{}
|
||||
|
||||
void add_affinity(uint32_t first, uint32_t second)
|
||||
@@ -322,9 +324,13 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx)
|
||||
/* check how many live-through variables should be spilled */
|
||||
RegisterDemand reg_pressure = block->live_in_demand;
|
||||
RegisterDemand loop_demand = reg_pressure;
|
||||
RegisterDemand loop_call_spills = RegisterDemand();
|
||||
unsigned i = block_idx;
|
||||
while (ctx.program->blocks[i].loop_nest_depth >= block->loop_nest_depth)
|
||||
loop_demand.update(ctx.program->blocks[i++].register_demand);
|
||||
while (ctx.program->blocks[i].loop_nest_depth >= block->loop_nest_depth) {
|
||||
loop_demand.update(ctx.program->blocks[i].register_demand);
|
||||
loop_call_spills.update(ctx.program->blocks[i].call_spills);
|
||||
++i;
|
||||
}
|
||||
|
||||
for (auto spilled : ctx.spills_exit[block_idx - 1]) {
|
||||
/* variable is not live at loop entry: probably a phi operand */
|
||||
@@ -352,12 +358,15 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx)
|
||||
|
||||
/* select more live-through variables and constants */
|
||||
RegType type = RegType::vgpr;
|
||||
while (loop_demand.exceeds(ctx.target_pressure)) {
|
||||
while (loop_demand.exceeds(ctx.target_pressure) ||
|
||||
loop_call_spills.exceeds(spilled_registers)) {
|
||||
/* if VGPR demand is low enough, select SGPRs */
|
||||
if (type == RegType::vgpr && loop_demand.vgpr <= ctx.target_pressure.vgpr)
|
||||
if (type == RegType::vgpr && loop_demand.vgpr <= ctx.target_pressure.vgpr &&
|
||||
loop_call_spills.vgpr <= spilled_registers.vgpr)
|
||||
type = RegType::sgpr;
|
||||
/* if SGPR demand is low enough, break */
|
||||
if (type == RegType::sgpr && loop_demand.sgpr <= ctx.target_pressure.sgpr)
|
||||
if (type == RegType::sgpr && loop_demand.sgpr <= ctx.target_pressure.sgpr &&
|
||||
loop_call_spills.sgpr <= spilled_registers.sgpr)
|
||||
break;
|
||||
|
||||
float score = 0.0;
|
||||
@@ -932,12 +941,28 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s
|
||||
}
|
||||
|
||||
/* check if register demand is low enough during and after the current instruction */
|
||||
if (block->register_demand.exceeds(ctx.target_pressure)) {
|
||||
if (block->register_demand.exceeds(ctx.target_pressure) || instr->isCall()) {
|
||||
RegisterDemand new_demand = instr->register_demand;
|
||||
std::optional<RegisterDemand> live_changes;
|
||||
|
||||
/* if reg pressure is too high, spill variable with furthest next use */
|
||||
while ((new_demand - spilled_registers).exceeds(ctx.target_pressure)) {
|
||||
while (true) {
|
||||
bool needs_spill = (new_demand - spilled_registers).exceeds(ctx.target_pressure);
|
||||
if (instr->isCall()) {
|
||||
RegisterDemand call_preserved_limit = instr->call().callee_preserved_limit;
|
||||
|
||||
/* Exclude the linear VGPRs created for spilling SGPRs from the limit,
|
||||
* if they are placed in clobbered register ranges (i.e. the preserved limit
|
||||
* can't fit all of them). The preserved spiller will take care of those. */
|
||||
call_preserved_limit.vgpr =
|
||||
MAX2(call_preserved_limit.vgpr - (int16_t)ctx.extra_vgprs, 0);
|
||||
|
||||
needs_spill |= (instr->call().caller_preserved_demand - spilled_registers)
|
||||
.exceeds(call_preserved_limit);
|
||||
}
|
||||
if (!needs_spill)
|
||||
break;
|
||||
|
||||
float score = 0.0;
|
||||
Temp to_spill = Temp();
|
||||
bool spill_is_operand = false;
|
||||
@@ -947,7 +972,15 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s
|
||||
unsigned avoid_respill = 0;
|
||||
|
||||
RegType type = RegType::sgpr;
|
||||
if (new_demand.vgpr - spilled_registers.vgpr > ctx.target_pressure.vgpr)
|
||||
bool spill_vgpr = new_demand.vgpr - spilled_registers.vgpr > ctx.target_pressure.vgpr;
|
||||
if (instr->isCall()) {
|
||||
RegisterDemand call_preserved_limit = instr->call().callee_preserved_limit;
|
||||
call_preserved_limit.vgpr =
|
||||
MAX2(call_preserved_limit.vgpr - (int16_t)ctx.extra_vgprs, 0);
|
||||
spill_vgpr |= instr->call().caller_preserved_demand.vgpr - spilled_registers.vgpr >
|
||||
call_preserved_limit.vgpr;
|
||||
}
|
||||
if (spill_vgpr)
|
||||
type = RegType::vgpr;
|
||||
|
||||
for (unsigned t : ctx.program->live.live_in[block_idx]) {
|
||||
@@ -970,19 +1003,32 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s
|
||||
for (auto& op : instr->operands) {
|
||||
if (!op.isTemp() || op.getTemp() != var)
|
||||
continue;
|
||||
/* Spilling vector operands causes us to emit a split_vector, increasing live
|
||||
* state temporarily.
|
||||
*/
|
||||
if (op.isLateKill() || op.isKill() || op.size() > 1) {
|
||||
if (op.isLateKill() || op.isKill()) {
|
||||
can_spill = false;
|
||||
break;
|
||||
}
|
||||
/* Spilling vector operands causes us to emit a split_vector, increasing live
|
||||
* state temporarily. This is ok when we have enough register headroom (when
|
||||
* spilling calls), but causes spilling to fail otherwise.
|
||||
*/
|
||||
if (op.size() > 1) {
|
||||
RegisterDemand before = new_demand - spilled_registers;
|
||||
before -= get_temp_registers(instr.get()) - get_live_changes(instr.get());
|
||||
if ((before + op.getTemp()).exceeds(ctx.target_pressure)) {
|
||||
can_spill = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!live_changes)
|
||||
live_changes = get_temp_reg_changes(instr.get());
|
||||
|
||||
/* Don't spill operands if killing operands won't help with register pressure */
|
||||
if (!op.isClobbered() && RegisterDemand(op.getTemp()).exceeds(*live_changes)) {
|
||||
if (!instr->isCall() && !op.isClobbered() && RegisterDemand(op.getTemp()).exceeds(*live_changes)) {
|
||||
can_spill = false;
|
||||
break;
|
||||
}
|
||||
if (instr->isCall() && !op.isClobbered()) {
|
||||
can_spill = false;
|
||||
break;
|
||||
}
|
||||
@@ -1653,20 +1699,23 @@ spill(Program* program)
|
||||
uint16_t extra_sgprs = 0;
|
||||
|
||||
/* calculate extra VGPRs required for spilling SGPRs */
|
||||
if (demand.sgpr > limit.sgpr) {
|
||||
unsigned sgpr_spills = demand.sgpr - limit.sgpr;
|
||||
unsigned sgpr_spills = demand.sgpr - std::min((uint16_t)demand.sgpr, (uint16_t)limit.sgpr);
|
||||
sgpr_spills += program->max_call_spills.sgpr;
|
||||
|
||||
if (sgpr_spills)
|
||||
extra_vgprs = DIV_ROUND_UP(sgpr_spills * 2, program->wave_size) + 1;
|
||||
}
|
||||
/* add extra SGPRs required for spilling VGPRs */
|
||||
if (demand.vgpr + extra_vgprs > limit.vgpr) {
|
||||
if (demand.vgpr + extra_vgprs > limit.vgpr || program->max_call_spills.vgpr) {
|
||||
if (program->gfx_level >= GFX9)
|
||||
extra_sgprs =
|
||||
program->stack_ptr.id() ? 2 : 1; /* SADDR + scc for stack pointer additions */
|
||||
else
|
||||
extra_sgprs = 5; /* scratch_resource (s4) + scratch_offset (s1) */
|
||||
if (demand.sgpr + extra_sgprs > limit.sgpr) {
|
||||
if (demand.sgpr + extra_sgprs > limit.sgpr || program->max_call_spills.sgpr) {
|
||||
/* re-calculate in case something has changed */
|
||||
unsigned sgpr_spills = demand.sgpr + extra_sgprs - limit.sgpr;
|
||||
sgpr_spills = program->max_call_spills.sgpr;
|
||||
if (demand.sgpr + extra_sgprs > limit.sgpr)
|
||||
sgpr_spills += demand.sgpr + extra_sgprs - limit.sgpr;
|
||||
extra_vgprs = DIV_ROUND_UP(sgpr_spills * 2, program->wave_size) + 1;
|
||||
}
|
||||
}
|
||||
@@ -1674,7 +1723,7 @@ spill(Program* program)
|
||||
const RegisterDemand target(limit.vgpr - extra_vgprs, limit.sgpr - extra_sgprs);
|
||||
|
||||
/* initialize ctx */
|
||||
spill_ctx ctx(target, program);
|
||||
spill_ctx ctx(target, program, extra_vgprs);
|
||||
gather_ssa_use_info(ctx);
|
||||
get_rematerialize_info(ctx);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user