diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c index 7b69c1ca013..b0410b594b0 100644 --- a/src/amd/llvm/ac_nir_to_llvm.c +++ b/src/amd/llvm/ac_nir_to_llvm.c @@ -713,33 +713,45 @@ static bool visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr) case nir_op_ixor: result = LLVMBuildXor(ctx->ac.builder, src[0], src[1], ""); break; - case nir_op_ishl: + case nir_op_ishl: { if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); + LLVMTypeRef type = LLVMTypeOf(src[0]); + src[1] = LLVMBuildAnd(ctx->ac.builder, src[1], + LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), ""); result = LLVMBuildShl(ctx->ac.builder, src[0], src[1], ""); break; - case nir_op_ishr: + } + case nir_op_ishr: { if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); + LLVMTypeRef type = LLVMTypeOf(src[0]); + src[1] = LLVMBuildAnd(ctx->ac.builder, src[1], + LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), ""); result = LLVMBuildAShr(ctx->ac.builder, src[0], src[1], ""); break; - case nir_op_ushr: + } + case nir_op_ushr: { if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildZExt(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0]))) src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), ""); + LLVMTypeRef type = LLVMTypeOf(src[0]); + src[1] = LLVMBuildAnd(ctx->ac.builder, src[1], + LLVMConstInt(type, LLVMGetIntTypeWidth(type) - 1, false), ""); result = LLVMBuildLShr(ctx->ac.builder, src[0], src[1], ""); break; + } case nir_op_ilt: result = emit_int_cmp(&ctx->ac, LLVMIntSLT, src[0], src[1]); break;