aco: add missing conversion operations for small bitsizes
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-By: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4002>
This commit is contained in:
@@ -1900,8 +1900,27 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
|
||||
}
|
||||
break;
|
||||
}
|
||||
case nir_op_f2f16:
|
||||
case nir_op_f2f16_rtne: {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 64)
|
||||
src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src);
|
||||
src = bld.vop1(aco_opcode::v_cvt_f16_f32, bld.def(v1), src);
|
||||
bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src);
|
||||
break;
|
||||
}
|
||||
case nir_op_f2f16_rtz: {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 64)
|
||||
src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src);
|
||||
src = bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32, bld.def(v1), src, Operand(0u));
|
||||
bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src);
|
||||
break;
|
||||
}
|
||||
case nir_op_f2f32: {
|
||||
if (instr->src[0].src.ssa->bit_size == 64) {
|
||||
if (instr->src[0].src.ssa->bit_size == 16) {
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f16, dst);
|
||||
} else if (instr->src[0].src.ssa->bit_size == 64) {
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f64, dst);
|
||||
} else {
|
||||
fprintf(stderr, "Unimplemented NIR instr bit size: ");
|
||||
@@ -1911,13 +1930,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
|
||||
break;
|
||||
}
|
||||
case nir_op_f2f64: {
|
||||
if (instr->src[0].src.ssa->bit_size == 32) {
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f64_f32, dst);
|
||||
} else {
|
||||
fprintf(stderr, "Unimplemented NIR instr bit size: ");
|
||||
nir_print_instr(&instr->instr, stderr);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 16)
|
||||
src = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src);
|
||||
bld.vop1(aco_opcode::v_cvt_f64_f32, Definition(dst), src);
|
||||
break;
|
||||
}
|
||||
case nir_op_i2f32: {
|
||||
@@ -1969,6 +1985,36 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
|
||||
}
|
||||
break;
|
||||
}
|
||||
case nir_op_f2i16: {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 16)
|
||||
src = bld.vop1(aco_opcode::v_cvt_i16_f16, bld.def(v1), src);
|
||||
else if (instr->src[0].src.ssa->bit_size == 32)
|
||||
src = bld.vop1(aco_opcode::v_cvt_i32_f32, bld.def(v1), src);
|
||||
else
|
||||
src = bld.vop1(aco_opcode::v_cvt_i32_f64, bld.def(v1), src);
|
||||
|
||||
if (dst.type() == RegType::vgpr)
|
||||
bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src);
|
||||
else
|
||||
bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), src);
|
||||
break;
|
||||
}
|
||||
case nir_op_f2u16: {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 16)
|
||||
src = bld.vop1(aco_opcode::v_cvt_u16_f16, bld.def(v1), src);
|
||||
else if (instr->src[0].src.ssa->bit_size == 32)
|
||||
src = bld.vop1(aco_opcode::v_cvt_u32_f32, bld.def(v1), src);
|
||||
else
|
||||
src = bld.vop1(aco_opcode::v_cvt_u32_f64, bld.def(v1), src);
|
||||
|
||||
if (dst.type() == RegType::vgpr)
|
||||
bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src);
|
||||
else
|
||||
bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), src);
|
||||
break;
|
||||
}
|
||||
case nir_op_f2i32: {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 32) {
|
||||
@@ -2190,9 +2236,91 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
|
||||
}
|
||||
break;
|
||||
}
|
||||
case nir_op_i2i8:
|
||||
case nir_op_u2u8: {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
/* we can actually just say dst = src */
|
||||
if (src.regClass() == s1)
|
||||
bld.copy(Definition(dst), src);
|
||||
else
|
||||
emit_extract_vector(ctx, src, 0, dst);
|
||||
break;
|
||||
}
|
||||
case nir_op_i2i16: {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 8) {
|
||||
if (dst.regClass() == s1) {
|
||||
bld.sop1(aco_opcode::s_sext_i32_i8, Definition(dst), Operand(src));
|
||||
} else {
|
||||
assert(src.regClass() == v1b);
|
||||
aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
|
||||
sdwa->operands[0] = Operand(src);
|
||||
sdwa->definitions[0] = Definition(dst);
|
||||
sdwa->sel[0] = sdwa_sbyte;
|
||||
sdwa->dst_sel = sdwa_sword;
|
||||
ctx->block->instructions.emplace_back(std::move(sdwa));
|
||||
}
|
||||
} else {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
/* we can actually just say dst = src */
|
||||
if (src.regClass() == s1)
|
||||
bld.copy(Definition(dst), src);
|
||||
else
|
||||
emit_extract_vector(ctx, src, 0, dst);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case nir_op_u2u16: {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 8) {
|
||||
if (dst.regClass() == s1)
|
||||
bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFu), src);
|
||||
else {
|
||||
assert(src.regClass() == v1b);
|
||||
aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
|
||||
sdwa->operands[0] = Operand(src);
|
||||
sdwa->definitions[0] = Definition(dst);
|
||||
sdwa->sel[0] = sdwa_ubyte;
|
||||
sdwa->dst_sel = sdwa_uword;
|
||||
ctx->block->instructions.emplace_back(std::move(sdwa));
|
||||
}
|
||||
} else {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
/* we can actually just say dst = src */
|
||||
if (src.regClass() == s1)
|
||||
bld.copy(Definition(dst), src);
|
||||
else
|
||||
emit_extract_vector(ctx, src, 0, dst);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case nir_op_i2i32: {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 64) {
|
||||
if (instr->src[0].src.ssa->bit_size == 8) {
|
||||
if (dst.regClass() == s1) {
|
||||
bld.sop1(aco_opcode::s_sext_i32_i8, Definition(dst), Operand(src));
|
||||
} else {
|
||||
assert(src.regClass() == v1b);
|
||||
aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
|
||||
sdwa->operands[0] = Operand(src);
|
||||
sdwa->definitions[0] = Definition(dst);
|
||||
sdwa->sel[0] = sdwa_sbyte;
|
||||
sdwa->dst_sel = sdwa_sdword;
|
||||
ctx->block->instructions.emplace_back(std::move(sdwa));
|
||||
}
|
||||
} else if (instr->src[0].src.ssa->bit_size == 16) {
|
||||
if (dst.regClass() == s1) {
|
||||
bld.sop1(aco_opcode::s_sext_i32_i16, Definition(dst), Operand(src));
|
||||
} else {
|
||||
assert(src.regClass() == v2b);
|
||||
aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
|
||||
sdwa->operands[0] = Operand(src);
|
||||
sdwa->definitions[0] = Definition(dst);
|
||||
sdwa->sel[0] = sdwa_sword;
|
||||
sdwa->dst_sel = sdwa_udword;
|
||||
ctx->block->instructions.emplace_back(std::move(sdwa));
|
||||
}
|
||||
} else if (instr->src[0].src.ssa->bit_size == 64) {
|
||||
/* we can actually just say dst = src, as it would map the lower register */
|
||||
emit_extract_vector(ctx, src, 0, dst);
|
||||
} else {
|
||||
@@ -2204,12 +2332,29 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
|
||||
}
|
||||
case nir_op_u2u32: {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 16) {
|
||||
if (instr->src[0].src.ssa->bit_size == 8) {
|
||||
if (dst.regClass() == s1)
|
||||
bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFu), src);
|
||||
else {
|
||||
assert(src.regClass() == v1b);
|
||||
aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
|
||||
sdwa->operands[0] = Operand(src);
|
||||
sdwa->definitions[0] = Definition(dst);
|
||||
sdwa->sel[0] = sdwa_ubyte;
|
||||
sdwa->dst_sel = sdwa_udword;
|
||||
ctx->block->instructions.emplace_back(std::move(sdwa));
|
||||
}
|
||||
} else if (instr->src[0].src.ssa->bit_size == 16) {
|
||||
if (dst.regClass() == s1) {
|
||||
bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFFFu), src);
|
||||
} else {
|
||||
// TODO: do better with SDWA
|
||||
bld.vop2(aco_opcode::v_and_b32, Definition(dst), Operand(0xFFFFu), src);
|
||||
assert(src.regClass() == v2b);
|
||||
aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
|
||||
sdwa->operands[0] = Operand(src);
|
||||
sdwa->definitions[0] = Definition(dst);
|
||||
sdwa->sel[0] = sdwa_uword;
|
||||
sdwa->dst_sel = sdwa_udword;
|
||||
ctx->block->instructions.emplace_back(std::move(sdwa));
|
||||
}
|
||||
} else if (instr->src[0].src.ssa->bit_size == 64) {
|
||||
/* we can actually just say dst = src, as it would map the lower register */
|
||||
@@ -2298,6 +2443,32 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
|
||||
case nir_op_unpack_64_2x32_split_y:
|
||||
bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
break;
|
||||
case nir_op_unpack_32_2x16_split_x:
|
||||
if (dst.type() == RegType::vgpr) {
|
||||
bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(dst.regClass()), get_alu_src(ctx, instr->src[0]));
|
||||
} else {
|
||||
bld.copy(Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
}
|
||||
break;
|
||||
case nir_op_unpack_32_2x16_split_y:
|
||||
if (dst.type() == RegType::vgpr) {
|
||||
bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
} else {
|
||||
bld.sop2(aco_opcode::s_bfe_u32, Definition(dst), get_alu_src(ctx, instr->src[0]), Operand(uint32_t(16 << 16 | 16)));
|
||||
}
|
||||
break;
|
||||
case nir_op_pack_32_2x16_split: {
|
||||
Temp src0 = get_alu_src(ctx, instr->src[0]);
|
||||
Temp src1 = get_alu_src(ctx, instr->src[1]);
|
||||
if (dst.regClass() == v1) {
|
||||
bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src0, src1);
|
||||
} else {
|
||||
src0 = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), src0, Operand(0xFFFFu));
|
||||
src1 = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), src1, Operand(16u));
|
||||
bld.sop2(aco_opcode::s_or_b32, Definition(dst), bld.def(s1, scc), src0, src1);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case nir_op_pack_half_2x16: {
|
||||
Temp src = get_alu_src(ctx, instr->src[0], 2);
|
||||
|
||||
|
||||
@@ -305,6 +305,9 @@ void init_context(isel_context *ctx, nir_shader *shader)
|
||||
case nir_op_fround_even:
|
||||
case nir_op_fsin:
|
||||
case nir_op_fcos:
|
||||
case nir_op_f2f16:
|
||||
case nir_op_f2f16_rtz:
|
||||
case nir_op_f2f16_rtne:
|
||||
case nir_op_f2f32:
|
||||
case nir_op_f2f64:
|
||||
case nir_op_u2f32:
|
||||
@@ -328,13 +331,15 @@ void init_context(isel_context *ctx, nir_shader *shader)
|
||||
case nir_op_cube_face_coord:
|
||||
type = RegType::vgpr;
|
||||
break;
|
||||
case nir_op_f2i16:
|
||||
case nir_op_f2u16:
|
||||
case nir_op_f2i32:
|
||||
case nir_op_f2u32:
|
||||
case nir_op_f2i64:
|
||||
case nir_op_f2u64:
|
||||
case nir_op_b2i32:
|
||||
case nir_op_b2b32:
|
||||
case nir_op_b2f32:
|
||||
case nir_op_f2i32:
|
||||
case nir_op_f2u32:
|
||||
case nir_op_mov:
|
||||
type = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? RegType::vgpr : RegType::sgpr;
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user