diff --git a/src/asahi/compiler/agx_compiler.h b/src/asahi/compiler/agx_compiler.h index ea9531c09dc..e306aa2c8be 100644 --- a/src/asahi/compiler/agx_compiler.h +++ b/src/asahi/compiler/agx_compiler.h @@ -321,6 +321,7 @@ typedef struct { enum agx_icond icond; enum agx_fcond fcond; enum agx_round round; + enum agx_atomic_opc atomic_opc; enum agx_lod_mode lod_mode; struct agx_block *target; }; diff --git a/src/asahi/compiler/agx_lower_uniform_sources.c b/src/asahi/compiler/agx_lower_uniform_sources.c index 5558151c224..7bbe564304b 100644 --- a/src/asahi/compiler/agx_lower_uniform_sources.c +++ b/src/asahi/compiler/agx_lower_uniform_sources.c @@ -28,6 +28,7 @@ should_lower(enum agx_opcode op, agx_index uniform, unsigned src_index) case AGX_OPCODE_DEVICE_LOAD: return src_index != 0 || high; case AGX_OPCODE_DEVICE_STORE: + case AGX_OPCODE_ATOMIC: return src_index != 1 || high; case AGX_OPCODE_ZS_EMIT: case AGX_OPCODE_ST_TILE: diff --git a/src/asahi/compiler/agx_opcodes.py b/src/asahi/compiler/agx_opcodes.py index 2614c947565..cff9d00ae0d 100644 --- a/src/asahi/compiler/agx_opcodes.py +++ b/src/asahi/compiler/agx_opcodes.py @@ -152,6 +152,20 @@ SR = enum("sr", { 82: 'thread_position_in_grid.z', }) +ATOMIC_OPC = enum("atomic_opc", { + 0: 'add', + 1: 'sub', + 2: 'xchg', + 3: 'cmpxchg', + 4: 'umin', + 5: 'imin', + 6: 'umax', + 7: 'imax', + 8: 'and', + 9: 'or', + 10: 'xor', +}) + FUNOP = lambda x: (x << 28) FUNOP_MASK = FUNOP((1 << 14) - 1) @@ -265,6 +279,16 @@ op("uniform_store", encoding_32 = ((0b111 << 27) | 0b1000101 | (1 << 47), 0, 8, _), dests = 0, srcs = 2, can_eliminate = False) +# sources are value, base, index +op("atomic", + encoding_32 = (0x15 | (1 << 26) | (1 << 31) | (5 << 44), 0x3F | (1 << 26) | (1 << 31) | (5 << 44), 8, _), + dests = 1, srcs = 3, imms = [ATOMIC_OPC, SCOREBOARD], can_eliminate = False) + +# XXX: stop hardcoding the long form +op("local_atomic", + encoding_32 = (0x19 | (1 << 15) | (1 << 36) | (1 << 47), 0x3F | (1 << 36) | (1 << 47), 10, _), + dests = 1, srcs = 3, imms = [ATOMIC_OPC], can_eliminate = False) + op("wait", (0x38, 0xFF, 2, _), dests = 0, can_eliminate = False, imms = [SCOREBOARD]) diff --git a/src/asahi/compiler/agx_optimizer.c b/src/asahi/compiler/agx_optimizer.c index c9d7248b4e6..bc0a2db025c 100644 --- a/src/asahi/compiler/agx_optimizer.c +++ b/src/asahi/compiler/agx_optimizer.c @@ -132,7 +132,9 @@ agx_optimizer_inline_imm(agx_instr **defs, agx_instr *I, unsigned srcs, continue; if (I->op == AGX_OPCODE_ZS_EMIT && s != 0) continue; - if (I->op == AGX_OPCODE_DEVICE_STORE && s != 2) + if ((I->op == AGX_OPCODE_DEVICE_STORE || I->op == AGX_OPCODE_ATOMIC || + I->op == AGX_OPCODE_LOCAL_ATOMIC) && + s != 2) continue; if (float_src) { @@ -190,7 +192,8 @@ agx_optimizer_copyprop(agx_instr **defs, agx_instr *I) /* ALU instructions cannot take 64-bit */ if (def->src[0].size == AGX_SIZE_64 && !(I->op == AGX_OPCODE_DEVICE_LOAD && s == 0) && - !(I->op == AGX_OPCODE_DEVICE_STORE && s == 1)) + !(I->op == AGX_OPCODE_DEVICE_STORE && s == 1) && + !(I->op == AGX_OPCODE_ATOMIC && s == 1)) continue; agx_replace_src(I, s, def->src[0]); diff --git a/src/asahi/compiler/agx_register_allocate.c b/src/asahi/compiler/agx_register_allocate.c index e801de099f7..8f2c7b45831 100644 --- a/src/asahi/compiler/agx_register_allocate.c +++ b/src/asahi/compiler/agx_register_allocate.c @@ -169,6 +169,13 @@ agx_read_registers(agx_instr *I, unsigned s) return size; } + case AGX_OPCODE_ATOMIC: + case AGX_OPCODE_LOCAL_ATOMIC: + if (s == 0 && I->atomic_opc == AGX_ATOMIC_OPC_CMPXCHG) + return size * 2; + else + return size; + default: return size; } diff --git a/src/asahi/compiler/agx_validate.c b/src/asahi/compiler/agx_validate.c index 6a09f1f189c..d57fb97ce97 100644 --- a/src/asahi/compiler/agx_validate.c +++ b/src/asahi/compiler/agx_validate.c @@ -104,7 +104,8 @@ agx_validate_sources(agx_instr *I) agx_validate_assert(!src.discard); bool ldst = (I->op == AGX_OPCODE_DEVICE_LOAD) || - (I->op == AGX_OPCODE_UNIFORM_STORE); + (I->op == AGX_OPCODE_UNIFORM_STORE) || + (I->op == AGX_OPCODE_ATOMIC); /* Immediates are encoded as 8-bit (16-bit for memory load/store). For * integers, they extend to 16-bit. For floating point, they are 8-bit