From eb95f7cc0e94c3c8202dcc850b1644ee5a8a7a09 Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Tue, 11 Mar 2025 15:43:49 +0000 Subject: [PATCH] aco: support sign extension in apply_load_extract fossil-db (gfx1201): Totals from 10 (0.01% of 79377) affected shaders: Instrs: 28954 -> 28938 (-0.06%) CodeSize: 164552 -> 164472 (-0.05%) Latency: 1249341 -> 1247037 (-0.18%) InvThroughput: 297077 -> 296618 (-0.15%) VALU: 15951 -> 15941 (-0.06%) Signed-off-by: Rhys Perry Reviewed-by: Georg Lehmann Part-of: --- src/amd/compiler/aco_optimizer.cpp | 54 ++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index ecc2b7ea3cb..5fc69b1df41 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -3235,14 +3235,9 @@ apply_load_extract(opt_ctx& ctx, aco_ptr& extract) unsigned extract_idx = extract->operands[1].constantValue(); unsigned bits_extracted = extract->operands[2].constantValue(); - unsigned sign_ext = extract->operands[3].constantValue(); + bool sign_ext = extract->operands[3].constantValue(); unsigned dst_bitsize = extract->definitions[0].bytes() * 8u; - /* TODO: These are doable, but probably don't occur too often. */ - if (extract_idx || sign_ext || dst_bitsize != 32 || - (load->definitions[0].regClass().type() != extract->definitions[0].regClass().type())) - return false; - unsigned bits_loaded = 0; bool can_shrink = false; switch (load->opcode) { @@ -3273,6 +3268,11 @@ apply_load_extract(opt_ctx& ctx, aco_ptr& extract) default: return false; } + /* TODO: These are doable, but probably don't occur too often. */ + if (extract_idx || bits_extracted > bits_loaded || dst_bitsize != 32 || + (load->definitions[0].regClass().type() != extract->definitions[0].regClass().type())) + return false; + /* We can't shrink some loads because that would remove zeroing of the offset/address LSBs. */ if (!can_shrink && bits_extracted < bits_loaded) return false; @@ -3281,40 +3281,60 @@ apply_load_extract(opt_ctx& ctx, aco_ptr& extract) bits_loaded = MIN2(bits_loaded, bits_extracted); /* Change the opcode so it writes the full register. */ + bool is_s_buffer = load->opcode == aco_opcode::s_buffer_load_ubyte || + load->opcode == aco_opcode::s_buffer_load_ushort; if (bits_loaded == 8 && load->isDS()) - load->opcode = aco_opcode::ds_read_u8; + load->opcode = sign_ext ? aco_opcode::ds_read_i8 : aco_opcode::ds_read_u8; else if (bits_loaded == 16 && load->isDS()) - load->opcode = aco_opcode::ds_read_u16; + load->opcode = sign_ext ? aco_opcode::ds_read_i16 : aco_opcode::ds_read_u16; else if (bits_loaded == 8 && load->isMUBUF()) - load->opcode = aco_opcode::buffer_load_ubyte; + load->opcode = sign_ext ? aco_opcode::buffer_load_sbyte : aco_opcode::buffer_load_ubyte; else if (bits_loaded == 16 && load->isMUBUF()) - load->opcode = aco_opcode::buffer_load_ushort; + load->opcode = sign_ext ? aco_opcode::buffer_load_sshort : aco_opcode::buffer_load_ushort; else if (bits_loaded == 8 && load->isFlat()) - load->opcode = aco_opcode::flat_load_ubyte; + load->opcode = sign_ext ? aco_opcode::flat_load_sbyte : aco_opcode::flat_load_ubyte; else if (bits_loaded == 16 && load->isFlat()) - load->opcode = aco_opcode::flat_load_ushort; + load->opcode = sign_ext ? aco_opcode::flat_load_sshort : aco_opcode::flat_load_ushort; else if (bits_loaded == 8 && load->isGlobal()) - load->opcode = aco_opcode::global_load_ubyte; + load->opcode = sign_ext ? aco_opcode::global_load_sbyte : aco_opcode::global_load_ubyte; else if (bits_loaded == 16 && load->isGlobal()) - load->opcode = aco_opcode::global_load_ushort; + load->opcode = sign_ext ? aco_opcode::global_load_sshort : aco_opcode::global_load_ushort; else if (bits_loaded == 8 && load->isScratch()) - load->opcode = aco_opcode::scratch_load_ubyte; + load->opcode = sign_ext ? aco_opcode::scratch_load_sbyte : aco_opcode::scratch_load_ubyte; else if (bits_loaded == 16 && load->isScratch()) - load->opcode = aco_opcode::scratch_load_ushort; - else if (!load->isSMEM()) + load->opcode = sign_ext ? aco_opcode::scratch_load_sshort : aco_opcode::scratch_load_ushort; + else if (bits_loaded == 8 && load->isSMEM() && is_s_buffer) + load->opcode = sign_ext ? aco_opcode::s_buffer_load_sbyte : aco_opcode::s_buffer_load_ubyte; + else if (bits_loaded == 8 && load->isSMEM() && !is_s_buffer) + load->opcode = sign_ext ? aco_opcode::s_load_sbyte : aco_opcode::s_load_ubyte; + else if (bits_loaded == 16 && load->isSMEM() && is_s_buffer) + load->opcode = sign_ext ? aco_opcode::s_buffer_load_sshort : aco_opcode::s_buffer_load_ushort; + else if (bits_loaded == 16 && load->isSMEM() && !is_s_buffer) + load->opcode = sign_ext ? aco_opcode::s_load_sshort : aco_opcode::s_load_ushort; + else unreachable("Forgot to add opcode above."); if (dst_bitsize <= 16 && ctx.program->gfx_level >= GFX9) { switch (load->opcode) { + case aco_opcode::ds_read_i8: load->opcode = aco_opcode::ds_read_i8_d16; break; case aco_opcode::ds_read_u8: load->opcode = aco_opcode::ds_read_u8_d16; break; + case aco_opcode::ds_read_i16: load->opcode = aco_opcode::ds_read_u16_d16; break; case aco_opcode::ds_read_u16: load->opcode = aco_opcode::ds_read_u16_d16; break; + case aco_opcode::buffer_load_sbyte: load->opcode = aco_opcode::buffer_load_sbyte_d16; break; case aco_opcode::buffer_load_ubyte: load->opcode = aco_opcode::buffer_load_ubyte_d16; break; + case aco_opcode::buffer_load_sshort: load->opcode = aco_opcode::buffer_load_short_d16; break; case aco_opcode::buffer_load_ushort: load->opcode = aco_opcode::buffer_load_short_d16; break; + case aco_opcode::flat_load_sbyte: load->opcode = aco_opcode::flat_load_sbyte_d16; break; case aco_opcode::flat_load_ubyte: load->opcode = aco_opcode::flat_load_ubyte_d16; break; + case aco_opcode::flat_load_sshort: load->opcode = aco_opcode::flat_load_short_d16; break; case aco_opcode::flat_load_ushort: load->opcode = aco_opcode::flat_load_short_d16; break; + case aco_opcode::global_load_sbyte: load->opcode = aco_opcode::global_load_sbyte_d16; break; case aco_opcode::global_load_ubyte: load->opcode = aco_opcode::global_load_ubyte_d16; break; + case aco_opcode::global_load_sshort: load->opcode = aco_opcode::global_load_short_d16; break; case aco_opcode::global_load_ushort: load->opcode = aco_opcode::global_load_short_d16; break; + case aco_opcode::scratch_load_sbyte: load->opcode = aco_opcode::scratch_load_sbyte_d16; break; case aco_opcode::scratch_load_ubyte: load->opcode = aco_opcode::scratch_load_ubyte_d16; break; + case aco_opcode::scratch_load_sshort: load->opcode = aco_opcode::scratch_load_short_d16; break; case aco_opcode::scratch_load_ushort: load->opcode = aco_opcode::scratch_load_short_d16; break; default: break; }