When AMD’s RDNA4 GPUs launched, FSR 4 was released alongside them to much fanfare. It’s a massive leap in quality over FSR 3.1, and with FSR 4 moving to a machine learning model instead of an analytical one, FSR 3.1 marks the farewell to fully open and grok-able upscalers. It’s truly a shame, but the graphics world cares little about such sentimentality when pretty pixels are at stake.
FSR 4 has not actually been released to developers for integration in engines yet; it exists as a driver toggle that can be enabled for certain FSR 3.1 games. The FSR 3.1 SDK had the foresight to let the runtime peek into the driver DLLs and fish out a replacement implementation. This is of course a horrible mess from any compatibility perspective, but we just have to make it work.
From day 1 of RDNA4’s release, there’s been a lot of pressure from the Linux gaming community to get this working on Linux as well. Buying an RDNA4 GPU would be far less enticing if there were no way to get FSR 4 working, after all …
First battle – how is FSR 4 even invoked?
Given the (currently) proprietary nature of FSR 4, we could easily have been in the DLSS situation, where it’s literally impossible to re-implement it. DLSS uses interop with CUDA, for example, and there’s a 0% chance anyone outside of NVIDIA can deal with that. Fortunately, NVIDIA provides the shims required to make DLSS work on Proton these days, but for FSR 4 we’re on our own for now.
It all started with an issue made on vkd3d-proton’s tracker asking for FSR 4 support.
Somehow, via the OptiScaler project, they had gotten to the point where vkd3d-proton failed to compile compute shaders. This was very encouraging, because it meant that FSR 4 goes through D3D12 APIs somehow. Of course, D3D12’s way of doing vendor extensions is the most disgusting thing ever devised, but we’ll get to that later …
The flow of things seems to be:
- The FSR 3.1 DLL tries to open the AMD D3D12 driver DLL (amdxc64.dll) and queries for a COM interface.
- Presumably, after checking that FSR 4 is enabled in control panel and checking if the .exe is allowed to promote, it loads amdxcffx64.dll, which contains the actual implementation.
- amdxcffx64.dll creates normal D3D12 compute shaders against the supplied ID3D12Device (phew!), with a metric ton of undocumented AGS magic opcodes. Until games start shipping their own FSR 4 implementation, I’d expect that users need to copy over that DLL from an AMD driver install somehow, but that’s outside the scope of my work.
- amdxcffx64.dll also seems to call back into the driver DLL to ask for which configuration to use. Someone else managed to patch this check out and eventually figure out how to implement this part in a custom driver shim.
With undocumented opcodes, I really didn’t feel like trying anything; attempting to reverse something like that is just too much work, and there was a high risk of spending weeks on something that didn’t work out. But then someone found a very handy file checked into AMD’s open source LLPC repo. This file is far from complete, but it’s a solid start.
At this point it was clear that FSR 4 is based on the WMMA (wave matrix multiply accumulate) instructions found on RDNA3 and RDNA4 GPUs. This is encouraging, since Vulkan has VK_KHR_cooperative_matrix, which maps directly to WMMA. D3D12 was supposed to get this feature in SM 6.8, but it was dropped for some inexplicable reason. It’s unknown whether FSR 4 could have used standard WMMA opcodes in DXIL had we gone down that timeline. It certainly would have saved me a lot of pain and suffering …
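For reference, this is roughly what the Vulkan-side equivalent looks like with GL_KHR_cooperative_matrix: a 16×16×16 FP16 multiply-accumulate that the driver can lower to a single WMMA instruction (a minimal standalone sketch, not code from FSR 4):

#version 460
#extension GL_KHR_cooperative_matrix : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

layout(local_size_x = 32) in;
layout(set = 0, binding = 0) buffer Buf { float16_t data[]; };

void main()
{
    coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseA> A;
    coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseB> B;
    coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseAccumulator> C =
        coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseAccumulator>(0.0);

    coopMatLoad(A, data, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
    coopMatLoad(B, data, 256u, 16u, gl_CooperativeMatrixLayoutColumnMajor);
    C = coopMatMulAdd(A, B, C); // lowers to one WMMA instruction on RDNA3/4
    coopMatStore(C, data, 512u, 16u, gl_CooperativeMatrixLayoutRowMajor);
}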
Dumping shaders
The next step was to capture the shaders in question. Ratchet & Clank: Rift Apart was on AMD’s list of supported games and the smallest install I had readily available, so I fired it up with RGP on Windows and managed to observe WMMA opcodes. Encouraging!
Then came dumping the actual DXIL shaders. However, it seems the driver blocks FSR 4 if RenderDoc is attached (or some other unknown weirdness occurs), so a different strategy was necessary.
The first attempt was to take an FSR 3.1 application without anti-tampering and hook it somehow. The FSR 3.1 demo app in the SDK was suitable for this task. The driver refused to use FSR 4 here, but when I renamed the demo .exe to RiftApart.exe, it worked. quack.exe lives on!
Now that I had DXIL and some example ISA to go along with it, it was looking possible to slowly piece together how it all works.
Deciphering AGS opcodes
Shader extensions in D3D are a disgusting mess, and this is roughly how they work:
- Declare a magic UAV at some magic register + space combo
- Emit a ton of oddly encoded atomic compare exchanges to that UAV
- DXC has no idea what any of this means, but it must emit the DXIL as is
- Compare exchange is likely chosen because it’s never okay to reorder, merge or eliminate those operations
- Driver compiler recognizes these magic back-doors, consumes the magic stream of atomic compare exchanges, and translates that into something that makes sense
dxil-spirv in vkd3d-proton already had some code to deal with this, as Mortal Kombat 11 has some AGS shaders for 64-bit atomics (in DXIL, before SM 6.6 added them natively).
RWByteAddressBuffer MAGIC : register(u0, space2147420894);

uint Code(uint opcode, uint opcodePhase, uint immediateData)
{
    return (MagicCode << MagicCodeShift) |
           ((immediateData & DataMask) << DataShift) |
           ((opcodePhase & OpcodePhaseMask) << OpcodePhaseShift) |
           ((opcode & OpcodeMask) << OpcodeShift);
}

uint AGSMagic(uint code, uint arg0, uint arg1)
{
    uint ret;
    MAGIC.InterlockedCompareExchange(code, arg0, arg1, ret);
    return ret;
}

uint AGSMagic(uint opcode, uint phase, uint imm, uint arg0, uint arg1)
{
    return AGSMagic(Code(opcode, phase, imm), arg0, arg1);
}
Every WMMA opcode is translated to a ton of these magic instructions back-to-back, a maximum of 21 (!) in fact. dxil-spirv more or less has to pattern match them.
The exact details of how WMMA is represented with these opcodes aren’t super exciting, but to test this functionality, I implemented a header.
It seems like a wave matrix is represented with 8 uints. This tracks: a 16×16 FP32 matrix is 256 × 32 bits = 8192 bits, which spread across a wave32 comes out to 256 bits per lane in the worst case, i.e. exactly 8 uints.
struct WMMA_Matrix { uint v[8]; };
Here’s how WMMA matmul can be represented for example:
WMMA_Matrix WMMA_MatMulAcc(WaveMatrixOpcode op, WMMA_Matrix A, WMMA_Matrix B, WMMA_Matrix C)
{
    // A matrix
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(0, 0, WaveMatrixRegType_A_TempReg), A.v[0], A.v[1]);
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(1, 0, WaveMatrixRegType_A_TempReg), A.v[2], A.v[3]);
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(0, 1, WaveMatrixRegType_A_TempReg), A.v[4], A.v[5]);
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(1, 1, WaveMatrixRegType_A_TempReg), A.v[6], A.v[7]);

    // B matrix
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(0, 0, WaveMatrixRegType_B_TempReg), B.v[0], B.v[1]);
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(1, 0, WaveMatrixRegType_B_TempReg), B.v[2], B.v[3]);
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(0, 1, WaveMatrixRegType_B_TempReg), B.v[4], B.v[5]);
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(1, 1, WaveMatrixRegType_B_TempReg), B.v[6], B.v[7]);

    // C matrix
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(0, 0, WaveMatrixRegType_Accumulator_TempReg), C.v[0], C.v[1]);
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(1, 0, WaveMatrixRegType_Accumulator_TempReg), C.v[2], C.v[3]);
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(0, 1, WaveMatrixRegType_Accumulator_TempReg), C.v[4], C.v[5]);
    AGSMagic(WaveMatrixMulAcc, 0, MatrixIO(1, 1, WaveMatrixRegType_Accumulator_TempReg), C.v[6], C.v[7]);

    // Configure type
    AGSMagic(WaveMatrixMulAcc, 1, int(op) << int(WaveMatrixOpcode_OpsShift), 0, 0);

    // Read output
    WMMA_Matrix ret;
    ret.v[0] = AGSMagic(WaveMatrixMulAcc, 2, MatrixIO(0, 0, WaveMatrixRegType_RetVal_Reg), 0, 0);
    ret.v[1] = AGSMagic(WaveMatrixMulAcc, 2, MatrixIO(1, 0, WaveMatrixRegType_RetVal_Reg), 0, 0);
    ret.v[2] = AGSMagic(WaveMatrixMulAcc, 2, MatrixIO(2, 0, WaveMatrixRegType_RetVal_Reg), 0, 0);
    ret.v[3] = AGSMagic(WaveMatrixMulAcc, 2, MatrixIO(3, 0, WaveMatrixRegType_RetVal_Reg), 0, 0);
    ret.v[4] = AGSMagic(WaveMatrixMulAcc, 2, MatrixIO(0, 1, WaveMatrixRegType_RetVal_Reg), 0, 0);
    ret.v[5] = AGSMagic(WaveMatrixMulAcc, 2, MatrixIO(1, 1, WaveMatrixRegType_RetVal_Reg), 0, 0);
    ret.v[6] = AGSMagic(WaveMatrixMulAcc, 2, MatrixIO(2, 1, WaveMatrixRegType_RetVal_Reg), 0, 0);
    ret.v[7] = AGSMagic(WaveMatrixMulAcc, 2, MatrixIO(3, 1, WaveMatrixRegType_RetVal_Reg), 0, 0);
    return ret;
}
Fortunately, FSR 4 never tries to be clever with the WMMA_Matrix elements. It seems possible to pass in whatever uints you want, but the shaders follow a strict pattern, so dxil-spirv can do type promotion from uint to OpTypeCooperativeMatrixKHR on the first element and ignore the rest.
It took many days of agonizing trial and error, but eventually I managed to put together a test suite that exercises all the opcodes I found in the dumped DXIL files.
Tough corner cases
Ideally, there would be a straightforward mapping to KHR_cooperative_matrix, but that’s not really the case.
FP8 (E4M3)
FSR 4 is heavily reliant on FP8 WMMA, an RDNA4-exclusive feature. RDNA3 has WMMA, but only FP16. There is currently no SPIR-V support for Float8 either (but given that BFloat16 just shipped, it’s not a stretch to assume something will happen in this area at some point).
To make something that is actually compliant with Vulkan at this point, I implemented emulation of FP8.
Converting FP8 to FP16 is fairly straightforward. While we don’t have float8 yet, we do have uint8. Doing element-wise conversions like this is not strictly well defined, since the wave layout of different types can change, but it works fine in practice.
coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseA> CoopMatFP8toFP16(
    coopmat<uint8_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseA> coop_input)
{
    coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseA> coop_output;
    for (int i = 0; i < coop_input.length(); i++)
    {
        int16_t v = (int16_t(int8_t(coop_input[i])) << 7s) & (-16385s);
        coop_output[i] = int16BitsToFloat16(v);
    }

    // Handles denorm correctly.
    // Ignores NaN, but who cares.
    // There is no Inf in E4M3.
    return coop_output * float16_t(256.0);
}
I’m quite happy with this bit-hackery: a sign-extend, a shift, a bit-and, and an fmul.
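Why this works: shifting the sign-extended E4M3 bits left by 7 lines the exponent and mantissa fields up with FP16’s, and the & -16385 (i.e. & ~0x4000) clears the duplicate sign bit that the arithmetic shift leaves in the top exponent bit. E4M3 has exponent bias 7 while FP16 has bias 15, so the reinterpreted FP16 value is the original scaled by 2^(7 - 15) = 1/256, hence the final multiply by 256. The same scale factor maps E4M3 denormals onto FP16 denormals, which is why denorms fall out correctly for free.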
FSR 4 also relies on FP32 -> FP8 conversions to quantize accumulation results back to FP8 for storage between stages. This is … significantly more terrible to emulate. Doing accurate RTE (round-to-nearest-even) with denorm handling in soft-float is GPU sadism. It explodes driver compile times, and runtime performance is borderline unusable as a result.
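To give a sense of the pain, here is a minimal scalar sketch of what a compliant FP32 -> E4M3 conversion has to do (my own illustration, not the actual vkd3d-proton code), keeping in mind this runs for every matrix element:

#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

uint8_t FP32ToFP8_E4M3(float f)
{
    uint bits = floatBitsToUint(f);
    uint sign = (bits >> 24u) & 0x80u;
    uint exp_f = (bits >> 23u) & 0xffu;
    uint mant_f = bits & 0x7fffffu;

    // NaN maps to the only NaN encoding E4M3 has (S.1111.111).
    if (exp_f == 0xffu && mant_f != 0u)
        return uint8_t(sign | 0x7fu);

    int e = int(exp_f) - 127; // unbiased exponent
    // Inf / overflow: saturate to the max finite value, 448 (one possible convention).
    if (e > 8)
        return uint8_t(sign | 0x7eu);
    // Anything below half the minimum denormal (2^-10) rounds to +/- 0.
    if (exp_f == 0u || e < -10)
        return uint8_t(sign);

    uint v, s;
    if (e >= -6)
    {
        // Normal range. Compose biased exponent and mantissa so the rounding
        // carry propagates from mantissa into exponent for free.
        v = (uint(e + 7) << 23u) | mant_f;
        s = 20u; // 24-bit significand -> 4 bits (implicit 1 + 3-bit mantissa)
    }
    else
    {
        // Denormal range: keep the implicit 1 and shift further right.
        v = mant_f | 0x800000u;
        s = uint(20 - 6 - e);
    }

    // Round to nearest, ties to even: add (half - 1) plus the LSB of the
    // kept part, then shift out the discarded bits.
    uint r = (v + ((1u << (s - 1u)) - 1u) + ((v >> s) & 1u)) >> s;
    return uint8_t(sign | min(r, 0x7eu));
}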
8-bit Accumulation matrix conversions
In many places, we need to handle loading 8-bit matrices, which are converted to FP32 accumulators and back. Vulkan can support this, but it relies on drivers exposing it via the physical device query, and no driver I know of exposes 8-bit accumulator support in any operation, which means we’re forced to go out of spec. However, with some light tweaks to RADV, it works as expected. A driver should be able to expose 8-bit accumulator types and do the trunc/extend as needed. It’s somewhat awkward that there is no way to expose format support separately, but it is what it is.
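In GLSL terms, the out-of-spec usage looks something like this (a sketch with made-up names, assuming a driver that tolerates 8-bit accumulator types, which none officially advertise today):

#extension GL_KHR_cooperative_matrix : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

layout(set = 0, binding = 0, std430) buffer Buf8 { int8_t data8[]; };

void load_8bit_accumulator()
{
    // No driver reports this type/Use combination via
    // VkCooperativeMatrixPropertiesKHR, so loading it is out of spec.
    coopmat<int8_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseAccumulator> acc8;
    coopMatLoad(acc8, data8, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);

    // The element-wise widening itself (same shape and Use) is a
    // conversion the spec does define.
    coopmat<float, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseAccumulator> acc =
        coopmat<float, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseAccumulator>(acc8);
}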
Converting types with Use conversion
In several places, the shaders need to convert e.g. Accumulator matrices to B matrices. This is a use case not covered by KHR_cooperative_matrix. The universal workaround is to roundtrip store/load via LDS, but that’s kind of horrible for perf. I ended up with some hacky code paths for now (see the sketch after this list):
- On RDNA4, the Accumulator layout is basically the same as the B layout, so I can abuse implementation-specific behavior by copying over elements one by one into a new coopmat with a different type.
- On RDNA3, this doesn’t work since len(B) != len(Accum).
- NV_cooperative_matrix2 actually has a feature bit to support this exact use case without hackery, so I can take advantage of that when RADV implements support for it.
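The element-copy hack looks roughly like this (a sketch that deliberately relies on RDNA4’s matching register layouts; this is implementation-defined behavior, not portable KHR_cooperative_matrix code):

coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseB>
AccumToB(coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseAccumulator> acc)
{
    coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseB> b;
    // Only valid when len(B) == len(Accum) and the per-lane register layouts
    // happen to agree, which holds on RDNA4 and breaks on RDNA3.
    for (int i = 0; i < b.length(); i++)
        b[i] = acc[i];
    return b;
}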
The debugging process
After implementing all the opcodes and making all my tests pass, it was time to throw this at the wall. (Screenshots from TkG, pardon the French.)
… fun. The first step was to figure out if there was something I may have missed about opcode behavior. For SSBOs and LDS, I simply assumed that 4-byte alignment would be fine, then I found:
coopMatLoad(_86, _39._m0, ((30976u * _76) + _79) + gl_WorkGroupID.z, 16u, gl_CooperativeMatrixLayoutColumnMajor);
A 1-byte aligned coopmat (8-bit matrix), yikes … Technically out of spec for cooperative_matrix, but nothing we can do about that, and AMD deals with it just fine. In the original code, I had used indexing into a u32 SSBO and just divided the byte offset by 4, but that obviously breaks. I added support for 8-bit SSBO aliases in dxil-spirv (sketched below), updated the test suite, and tried again.
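The idea of the fix in GLSL form (a simplified sketch of what dxil-spirv now emits, with made-up names):

#extension GL_KHR_cooperative_matrix : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

// The same raw buffer bound twice with different element types; the u8 alias
// makes arbitrary byte offsets expressible.
layout(set = 0, binding = 0, std430) buffer SSBO32 { uint data[]; };
layout(set = 0, binding = 0, std430) buffer SSBO8 { uint8_t data8[]; };

void load_matrix(uint byte_offset, uint byte_stride)
{
    coopmat<uint8_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseA> mat;
    // Before: coopMatLoad(mat, data, byte_offset / 4u, byte_stride / 4u, ...),
    // which silently breaks for 1-byte aligned matrices.
    coopMatLoad(mat, data8, byte_offset, byte_stride, gl_CooperativeMatrixLayoutColumnMajor);
}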
Still not right. Eventually I found some questionable-looking code around an LDS load-store: the strides couldn’t possibly be right. It turns out that offset/stride for LDS is in terms of u32 words, not bytes (?!). This detail wasn’t caught in the test suite because the bug cancels itself out across a matching store and load. Fortunately, this quirk made my life easier in dxil-spirv, since there’s no easy way of emitting 8-bit LDS aliases.
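In other words, where the SSBO path needs true byte offsets, the groupshared path takes word offsets straight from the AGS stream (my reading of the behavior, illustrated with hypothetical names):

shared uint lds_data[2048];

void store_matrix(coopmat<uint8_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseA> mat,
                  uint ags_offset, uint ags_stride)
{
    // ags_offset / ags_stride already count u32 words, so they index the
    // uint LDS array directly; no 8-bit alias needed even for 8-bit matrices.
    coopMatStore(mat, lds_data, ags_offset, ags_stride, gl_CooperativeMatrixLayoutRowMajor);
}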
Things were looking good, but it was far from done. There was still intense shimmering and ghosting, which couldn’t be observed from screenshots and I couldn’t figure it out by simply staring at code. It was time to bring out the big guns.
Build a tiny FSR 4 test bench from scratch
I needed a way to directly compare outputs between native and vkd3d-proton, to isolate exactly where the implementations diverged. Fortunately, since the driver no longer blocks us when trying to capture, I could take a RenderDoc frame capture of the 3.1 demo.
The inputs and outputs are luckily very simple:
- A pre-pass that reads textures, and does very light WMMA work at the end
- A bunch of raw passes that spam WMMA like no tomorrow, likely implementing the ML network
- A final post-pass that synthesizes the final image with more WMMA work of course
The middle passes were very simple from a resource binding standpoint:
- One big weight buffer
- One big scratch buffer
In RenderDoc, I dumped the buffer contents to disk and built a small D3D12 test app that invokes the shaders with the buffers in question. After every dispatch, it dumps the scratch buffer contents out to disk, and by running it against both the native D3D12 driver and vkd3d-proton, I could figure out where things diverge.
Surely the FSR 4 shaders cannot be bugged, right?
Turns out, yes, they can be 🙂 Many of the shaders only allocate 256 bytes of LDS, yet actually need 512. Classic undefined behavior. The reason this “happens” to work on native is that AMD allocates LDS space with 512-byte granularity. However, dxil-spirv also emits some LDS of its own to deal with matrix transpositions, and that ended up clobbering the AGS shader’s LDS space …
One disgusting workaround later …
// Workaround for bugged WMMA shaders.
// The shaders rely on AMD aligning LDS size to 512 bytes.
// This avoids overflow spilling into LDSTranspose area by mistake, which breaks some shaders.
if (address_space == DXIL::AddressSpace::GroupShared &&
    shader_analysis.require_wmma)
{
    // ... Pad groupshared array to 512 bytes.
}
and games were rendering correctly. The FP16 path on RDNA4 was conquered.
Performance?
Absolute garbage, as expected. At 1440p on a 9070 XT, native runs at about 0.85 ms, while my implementation took about 3 ms.
Going beyond
Can we implement FP8?
RADV obviously cannot ship FP8 before there is a Vulkan / SPIR-V spec for it, but with the power of open source and courage, why not experiment? Nothing is stopping us from just emitting:
OpTypeFloat 8
and seeing what happens. Georg Lehmann brought up FP8 support in NIR and ACO enough to support FP8 WMMA. Hacking FP8 support into dxil-spirv was quite straightforward and done in an hour. Getting the test suite to pass was smooth and easy, but … the real battle was yet to come.
vkd3d-proton bug or RADV?
Games were completely broken in the FP8 path. Fortunately, the difference reproduced in my test bench. The real issue now was bisecting the shader itself to figure out where the outputs diverge. These shaders are not small, and they’re full of incomprehensible ML gibberish, so the only solution I could come up with was capturing both FP16 and FP8 paths and debug-printing them side by side. Fortunately, RenderDoc makes this super easy.
First, I had to hack together FP8 support in SPIRV-Cross and glslang so that roundtripping could work.
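Since GLSL has no official float8 type, the hacked-up branches had to invent one. Purely as an illustration (this syntax is hypothetical, not a shipping extension), the roundtripped shaders end up looking something like:

// Entirely hypothetical syntax; GLSL defines no float8 type, so the branches
// had to invent something along these lines.
coopmat<floate4m3_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseA> a8;
coopmat<floate4m3_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseB> b8;
coopmat<float, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseAccumulator> acc;
// FP8 inputs accumulating into FP32, matching RDNA4's FP8 WMMA.
acc = coopMatMulAdd(a8, b8, acc);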
Eventually, I found the divergent spot, and Georg narrowed it down to broken FP8 vectorization in ACO. Once this was fixed, FP8 was up and running. Runtime was now down to 1.3 ms.
Get rid of silly LDS roundtrips
FSR 4 really likes to convert Accumulator to B matrices, but on RDNA4 the layouts match (at least for 8-bit), so until we have an NV_cooperative_matrix2 implementation, I pretended it works by copying elements instead, and runtime went down to about 1 ms. RADV codegen for coopmat is currently very naive (the buffer loading code in particular is extremely inefficient), but despite that, we’re pretty close to native here. Now that there is a good use case for cooperative matrix, I’m sure it will get optimized eventually.
At this point, the FP8 path is fully functional and performant enough, but of course it requires building random Mesa branches and enabling hacked-up code paths in vkd3d-proton.
RDNA3?
RDNA3 is not officially supported at the moment, but given that I already went through the pain of emulating FP8, there’s no reason it cannot work on RDNA3. Given the terrible performance I got with FP16 emulation, I can understand why RDNA3 is not supported, though … FSR 4 requires a lot of WMMA brute force, and RDNA3’s lesser WMMA grunt is simply not strong enough. Maybe it would work better if a dedicated FP16 model were designed, but that’s not for me to figure out.
It took a good while to debug this, since once again the test suite was running fine …
LDS roundtrip required
Unfortunately, RDNA3 is quite strange when it comes to WMMA layouts. The Accumulator has 8 elements per lane, but A and B matrices have 16 for some reason. NV_cooperative_matrix2 will help here for sure.
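So on RDNA3, the Accumulator-to-B conversion has to take the portable LDS roundtrip after all. A minimal sketch, assuming a single subgroup per workgroup and a hypothetical scratch array:

#extension GL_KHR_cooperative_matrix : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

shared float16_t scratch[16 * 16];

coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseB>
AccumToB_LDS(coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseAccumulator> acc)
{
    // Store in a well-defined layout, then reload with the target Use.
    coopMatStore(acc, scratch, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
    barrier();
    coopmat<float16_t, gl_ScopeSubgroup, 16u, 16u, gl_MatrixUseB> b;
    coopMatLoad(b, scratch, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
    return b;
}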
Final shader bug
After fixing the LDS roundtrip, the test bench passed, but games still looked completely broken on RDNA3. This narrowed the problem down to either the pre-pass or the post-pass. Going dual-GPU and opening captures of the FSR 3.1 demo side by side on RDNA3 and RDNA4, I finally narrowed it down to questionable shader code that unnecessarily relies on implementation-defined behavior.
// Rewritten for clarity
float tmp[8];
tmp[0u] = uintBitsToFloat(_30._m0[384u].x);
tmp[1u] = uintBitsToFloat(_30._m0[384u].y);
tmp[2u] = uintBitsToFloat(_30._m0[384u].z);
tmp[3u] = uintBitsToFloat(_30._m0[384u].w);
tmp[4u] = uintBitsToFloat(_30._m0[385u].x);
tmp[5u] = uintBitsToFloat(_30._m0[385u].y);
tmp[6u] = uintBitsToFloat(_30._m0[385u].z);
tmp[7u] = uintBitsToFloat(_30._m0[385u].w);

for (int i = 0; i < mat.length(); i++)
    mat[i] += tmp[i];

storeToLDS(mat);

// Read LDS and write to image.
// LDS[8 * SubgroupInvocationID + {4, 5, 6, 7}]
This is not well-behaved cooperative matrix code, since it relies on the register layout, and RDNA3 and RDNA4 actually differ: the columns are interleaved in very different ways.
I found a workaround which can be applied to RDNA3 wave32:
for (int i = 0; i < mat.length(); i++)
    mat[i] += tmp[((i << 1u) & 7u) + (gl_SubgroupInvocationID >> 4u)];
That was the best I could do without resorting to full shader replacement. The actual fix would be for the shader to just perform this addition after loading from LDS, like:
storeToLDS(mat);

float tmp[8];
tmp[0u] = uintBitsToFloat(_30._m0[384u].x);
tmp[1u] = uintBitsToFloat(_30._m0[384u].y);
tmp[2u] = uintBitsToFloat(_30._m0[384u].z);
tmp[3u] = uintBitsToFloat(_30._m0[384u].w);
tmp[4u] = uintBitsToFloat(_30._m0[385u].x);
tmp[5u] = uintBitsToFloat(_30._m0[385u].y);
tmp[6u] = uintBitsToFloat(_30._m0[385u].z);
tmp[7u] = uintBitsToFloat(_30._m0[385u].w);

// Read LDS and write to image.
// LDS[8 * SubgroupInvocationID + {4, 5, 6, 7}] + tmp[{4, 5, 6, 7}]
This would at least be portable.
With this, RDNA3 can do FSR 4 on vkd3d-proton, if you’re willing to take a massive dump on performance.
Conclusion
I was fearing that getting FSR 4 up and running would take a year, but here we are. Lots of different people in the community ended up contributing in smaller ways to unblock the debugging process.
There probably won’t be a straightforward way to make use of this work until FSR 4 is released in an official SDK, FP8 actually lands in Vulkan, etc., etc., so I’ll leave the end-user side of things out of this blog.