From 143167f2a05b35287c349fd52393a945f26bc4f5 Mon Sep 17 00:00:00 2001 From: Dave Airlie Date: Tue, 7 Sep 2021 14:51:48 +1000 Subject: [PATCH] gallivm/nir: handle subgroup reduction across all types Reviewed-by: Roland Scheidegger Part-of: --- .../auxiliary/gallivm/lp_bld_nir_soa.c | 86 +++++++++++++++++-- 1 file changed, 78 insertions(+), 8 deletions(-) diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c index a861f071ec7..dde6831be3b 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c +++ b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c @@ -2030,36 +2030,106 @@ static void emit_reduce(struct lp_build_nir_context *bld_base, LLVMValueRef src, switch (reduction_op) { case nir_op_fmin: { LLVMValueRef flt_max = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), INFINITY) : - lp_build_const_float(gallivm, INFINITY); + (bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), INFINITY) : lp_build_const_float(gallivm, INFINITY)); store_val = LLVMBuildBitCast(builder, flt_max, int_bld->elem_type, ""); break; } case nir_op_fmax: { LLVMValueRef flt_min = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), -INFINITY) : - lp_build_const_float(gallivm, -INFINITY); + (bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), -INFINITY) : lp_build_const_float(gallivm, -INFINITY)); store_val = LLVMBuildBitCast(builder, flt_min, int_bld->elem_type, ""); break; } case nir_op_fmul: { LLVMValueRef flt_one = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), 1.0) : - lp_build_const_float(gallivm, 1.0); + (bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), 1.0) : lp_build_const_float(gallivm, 1.0)); store_val = LLVMBuildBitCast(builder, flt_one, int_bld->elem_type, ""); break; } case nir_op_umin: - store_val = lp_build_const_int32(gallivm, UINT_MAX); + switch (bit_size) { + case 8: + store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), UINT8_MAX, 0); + break; + case 16: + store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), UINT16_MAX, 0); + break; + case 32: + default: + store_val = lp_build_const_int32(gallivm, UINT_MAX); + break; + case 64: + store_val = lp_build_const_int64(gallivm, UINT64_MAX); + break; + } break; case nir_op_imin: - store_val = lp_build_const_int32(gallivm, INT_MAX); + switch (bit_size) { + case 8: + store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), INT8_MAX, 0); + break; + case 16: + store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), INT16_MAX, 0); + break; + case 32: + default: + store_val = lp_build_const_int32(gallivm, INT_MAX); + break; + case 64: + store_val = lp_build_const_int64(gallivm, INT64_MAX); + break; + } break; case nir_op_imax: - store_val = lp_build_const_int32(gallivm, INT_MIN); + switch (bit_size) { + case 8: + store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), INT8_MIN, 0); + break; + case 16: + store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), INT16_MIN, 0); + break; + case 32: + default: + store_val = lp_build_const_int32(gallivm, INT_MIN); + break; + case 64: + store_val = lp_build_const_int64(gallivm, INT64_MIN); + break; + } break; case nir_op_imul: - store_val = lp_build_const_int32(gallivm, 1); + switch (bit_size) { + case 8: + store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), 1, 0); + break; + case 16: + store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), 1, 0); + break; + case 32: + default: + store_val = lp_build_const_int32(gallivm, 1); + break; + case 64: + store_val = lp_build_const_int64(gallivm, 1); + break; + } break; case nir_op_iand: - store_val = lp_build_const_int32(gallivm, 0xffffffff); + switch (bit_size) { + case 8: + store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), 0xff, 0); + break; + case 16: + store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), 0xffff, 0); + break; + case 32: + default: + store_val = lp_build_const_int32(gallivm, 0xffffffff); + break; + case 64: + store_val = lp_build_const_int64(gallivm, 0xffffffffffffffffLL); + break; + } break; default: break;