From 064336d35977abd0d5b6ed37784c6cc42cf4f66f Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Wed, 23 Nov 2022 20:41:29 +0000 Subject: [PATCH] ac/nir: mask shift operands MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit NIR shifts are defined to truncate the shift amount to the number of bits needed to represent the bit-size of the value shifted. LLVM treats a shift by an amount greater than or equal to the bit width as poison. This fix achieves NIR semantics for shifts. As an example, a|(b << 32), where "a" is 32 bits, should produce a|b according to NIR (because 32&31 == 0). This caused LLVM to incorrectly optimize "(a >> c) | (b << (32 - c))" to u2u32(pack_64_2x32(a, b) >> c) (v_alignbit_b32), when the original NIR should have returned "a | b" if c==0. Signed-off-by: Rhys Perry Reviewed-by: Mihai Preda Reviewed-by: Qiang Yu Reviewed-by: Marek Olšák Cc: mesa-stable Part-of: --- src/amd/llvm/ac_nir_to_llvm.c | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c index 7b69c1ca013..b0410b594b0 100644 --- a/src/amd/llvm/ac_nir_to_llvm.c +++ b/src/amd/llvm/ac_nir_to_llvm.c @@ -713,33 +713,45 @@ static bool visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr) case nir_op_ixor: result = LLVMBuildXor(ctx->ac.builder, src[0], src[1], ""); break; - case nir_op_ishl: + case nir_op_ishl: { if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); + LLVMTypeRef type = LLVMTypeOf(src[0]); + src[1] = LLVMBuildAnd(ctx->ac.builder, src[1], + LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), ""); result = LLVMBuildShl(ctx->ac.builder, src[0], src[1], ""); break; - case nir_op_ishr: 
+ } + case nir_op_ishr: { if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); + LLVMTypeRef type = LLVMTypeOf(src[0]); + src[1] = LLVMBuildAnd(ctx->ac.builder, src[1], + LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), ""); result = LLVMBuildAShr(ctx->ac.builder, src[0], src[1], ""); break; - case nir_op_ushr: + } + case nir_op_ushr: { if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); + LLVMTypeRef type = LLVMTypeOf(src[0]); + src[1] = LLVMBuildAnd(ctx->ac.builder, src[1], + LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), ""); result = LLVMBuildLShr(ctx->ac.builder, src[0], src[1], ""); break; + } case nir_op_ilt: result = emit_int_cmp(&ctx->ac, LLVMIntSLT, src[0], src[1]); break;