FlashAttention-4 is a GPU-optimized attention implementation designed specifically for NVIDIA's Blackwell architecture (B200). It addresses the bottleneck shift in attention computation on that hardware: the forward pass becomes limited by exponential computation (SFU) and shared memory bandwidth, while the backward pass is primarily constrained by shared memory bandwidth. The key innovations include: (1) ping-pong computation that overlaps MMA operations with softmax calculations, (2) software emulation of exponential functions using polynomial approximation, run in parallel with the hardware SFUs, (3) conditional online softmax rescaling to minimize non-matmul operations, (4) 2CTA MMA mode that doubles the M dimension to 256 and reduces shared memory footprint, and (5) distributed shared memory for efficient inter-CTA communication. These optimizations leverage Blackwell's Tensor Memory and 2CTA MMA instructions to achieve significant performance improvements over previous FlashAttention versions and competing implementations like cuDNN.
Deep Dive
FlashAttention-4 by Ted Zadouri x GPU MODE
All right, welcome everyone. This is our very first in-person GPU MODE lecture, and I couldn't think of a better guest to celebrate it with than the first author of the FlashAttention-4 paper, Ted Zadouri. Ted's a PhD student working under Tri at Princeton. He's been doing lots of interesting new work here, so he's going to talk to us about how he co-designed this, with war stories, CuTe DSL, a bunch of math, and some code. So it should be a fun lecture. So yeah, Ted, please take it from here.
>> Yeah, thank you Mark for the kind introduction. Today I'm going to present the new FlashAttention, FlashAttention-4, and this is a joint effort with Marcus, Jay, Timmy, Vijay, and Tri.
Okay, just to give an outline of the talk: I'll give a brief motivation, the reasons for there being a FlashAttention-4, and we'll talk about the new hardware and some background. The hardware in question is Blackwell, specifically B200. Then I'll talk about how FlashAttention-4 redesigns the forward and backward pass, then talk about some benchmark results and say a few words about the CuTe DSL implementation of FlashAttention, and that should hopefully conclude the talk. So in terms of motivation, if you take a previous version of FlashAttention, it will either not be compatible with Blackwell, for instance FA3, which depends on warp-group-level collective instructions, or, if you take FA2 and port it to Blackwell, it will run but it will be really, really slow. You're leaving a lot of performance on the table, partly because it's not using the Blackwell instructions. But on the other hand, the kernel and algorithmic co-design was not taking into account the bottlenecks on Blackwell. So if you do an analytical roofline model, and we take a 128 x 128 tile for BF16, we can see how the bottleneck shifts for attention on Blackwell. This figure represents asymmetric hardware scaling, and by asymmetric I mean the tensor core throughputs are becoming so fast relative to other hardware units, which scale more slowly or just don't scale at all. For example, the exponential throughput, which comes from the special function unit, or the shared memory bandwidth.
So in terms of the forward pass, it's bottlenecked by the exponential. What that really means is that in the forward pass you have two GEMMs, computing the score matrix and the output tile, and for, let's say, a 128 x 128 tile that effectively takes the same number of clock cycles as the softmax. For the backward pass it's a different story, because you have five GEMMs, and there the bottleneck shifts to shared memory bandwidth, not so much the MMA or the exponential.
So these observations collectively tell us how to co-design the algorithm and the kernel of FlashAttention-4. All right. So, the new hardware in question.
So there are two new features. One is tensor memory and the other is the fifth-generation tensor cores. These are in some sense interlinked, because the tensor cores on Blackwell accumulate their outputs in tensor memory, whereas on Hopper the accumulators lived in registers. As you accumulate in tensor memory, a parallel warp can load from TMEM and do the post-processing. It doesn't really stall the MMA. You can just immediately issue the next MMA to compute a different output that's in a different column of TMEM. Another difference is that on Blackwell a single thread is enough to issue the MMA. So it's going to allow us to build much, much deeper pipelines, and we'll see how FlashAttention-4 takes advantage of this. >> Could you help us quickly understand the trade-offs between accumulating in registers versus tensor memory? What are the sizes of these, respectively, and what are the latencies? Because I imagine you can reduce register pressure by just using fewer registers, that's fine, but are you becoming slower in some other unexpected ways by using tensor memory?
>> So loading from tensor memory is fairly fast. The main reason tensor memory was introduced is to deal with register pressure. But also, for instance, if you have to do some rescaling in registers, that stalls the next MMA, because it's a warp-group primitive and all the outputs need to participate. So tensor memory helps in that case. The next feature that Blackwell introduces is the 2CTA MMA mode. On Hopper the M dimension of the MMA for BF16 was up to 64, whereas on Blackwell the support is 128. With 2CTA MMA the support increases up to 256. So if you look at this example of the GEMM, on the left-hand side we have the A and the B matrix and the output C, and on the right-hand side we have the 2CTA MMA mode, where we partition the A tile across the M dimension and the B matrix across the N dimension, essentially the non-reduction axes, and these are split across the two CTAs. It's worth mentioning that only one of the CTAs needs to issue the MMA, not both.
>> Oh, I see. Okay.
>> Yeah. So, for instance, let's say CTA 0 is the leader CTA; in CTA 1 the MMA warp is doing nothing. It's empty. So one benefit of this is, let's say in this example, the M tile increases but the N tile is still 128. So now each CTA needs to store half and load half for the MMA. So this reduces the shared memory footprint, and it also reduces the SMEM bandwidth.
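The footprint argument above can be sketched as simple arithmetic. This is only an illustration: the helper name `smem_operand_elems_per_cta`, the K tile of 64, and the assumption that A is split evenly along M and B evenly along N are hypothetical and not taken from the actual kernel.

```python
def smem_operand_elems_per_cta(m_tile, n_tile, k_tile, num_ctas):
    """Operand elements each CTA keeps in shared memory for one MMA.

    Assumption for illustration: operand A (m_tile x k_tile) is split along M
    and operand B (n_tile x k_tile) along N, evenly across the cooperating CTAs.
    """
    a_elems = (m_tile // num_ctas) * k_tile
    b_elems = (n_tile // num_ctas) * k_tile
    return a_elems, b_elems


# 1-CTA MMA: M = 128, N = 128 (k_tile = 64 is an arbitrary example value).
one_cta = smem_operand_elems_per_cta(128, 128, 64, num_ctas=1)
# 2-CTA MMA: M = 256, N = 128, shared by a pair of CTAs.
two_cta = smem_operand_elems_per_cta(256, 128, 64, num_ctas=2)

print(one_cta)  # (8192, 8192)
print(two_cta)  # (8192, 4096)
```

Under these assumptions each CTA still holds a 128-row slice of A, but only half of B, while the pair produces twice as many output rows.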
Yeah. So before I dive deep into the three key optimizations of FlashAttention-4, let me give a very quick recap.
So in FlashAttention's forward pass, for a given batch and given head, we partition the query sequence length into M blocks, and we map each M block onto a CTA. So each CTA gets a unique query tile, and inside that it iterates across key-value blocks, computes the score matrix, then the attention probabilities using online softmax, and then accumulates the output, until at the end each CTA writes its unique output tile.
So what's new in the FlashAttention-4 forward pass? The very first optimization is computing the attention in a ping-pong fashion. Instead of assigning a single query tile, we assign two query tiles. Let's call them Q high and Q low. The idea is that as we compute the output tile for the current iteration and the score matrix for the next iteration for one tile, in parallel we compute the softmax for the other tile, let's say Q low. This tries to fully overlap MMA and softmax. And as mentioned earlier, we know that two GEMMs cost about the same time as a softmax.
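As a rough illustration of the recap above, the per-CTA loop can be written in plain Python. This is a minimal sketch of the standard online-softmax recurrence, not the kernel itself: the function names are made up for this example, the usual 1/sqrt(d) scaling is omitted, and each query row is processed serially rather than as a tile.

```python
import math


def naive_attention(q, k, v):
    """Reference: softmax(q k^T) v, one query row at a time."""
    out = []
    for qi in q:
        s = [sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        l = sum(p)
        out.append([sum(p[j] * v[j][c] for j in range(len(v))) / l
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q_tile, k, v, block_n=2):
    """One 'CTA': a fixed query tile iterating over key/value blocks."""
    d = len(v[0])
    out = []
    for qi in q_tile:
        m = -math.inf      # running row max
        l = 0.0            # running normalizer
        acc = [0.0] * d    # unnormalized output accumulator
        for start in range(0, len(k), block_n):
            kb = k[start:start + block_n]
            vb = v[start:start + block_n]
            s = [sum(a * b for a, b in zip(qi, kj)) for kj in kb]
            m_new = max(m, max(s))
            scale = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first block
            p = [math.exp(x - m_new) for x in s]
            l = l * scale + sum(p)
            acc = [a * scale + sum(p[j] * vb[j][c] for j in range(len(vb)))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # final normalization
    return out
```

The tiled version never materializes a full row of scores, which is the property the real kernel relies on.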
So kind of some implementation details.
So there's the MMA warp, and there's a correction warp that does the rescaling for the online softmax and runs in parallel. We want to make sure the rescaling is not on the critical path. And then we have two softmax warp groups, each one with four warps, for the high and the low.
Okay, but like in reality like the overlap is obviously not this perfect.
There can be a scenario where these two softmaxes fight over the same exponential unit. They are explicitly synchronized so they don't call the SFU at the same time. There can also be a situation where the tensor cores outpace the softmax, so the softmax is not being computed in parallel and the MMA has to wait for the softmax output to be completed before it can start using P. So this takes us to the next point, where we introduce software emulation of the exponential. The high-level idea is that instead of relying solely on the special function unit, we also have a software-emulated version which computes some portion of the exponentials in parallel. If you look at the softmax warp group picture on the left-hand side, we have four warps, and one nice thing about tensor memory is that the layout is fairly simple. For a 128 x 128 tile, each row maps cleanly to one thread. So thread zero gets logical row zero of the tile, and then, based on some tunable parameter, we just say, okay, 75% goes to the special function unit and the other 25% goes to the software-emulated version. >> And this was a heuristic, the 75/25? >> So yeah, that's a good point, because these require computing many intermediate values and there are not that many registers to spare. So the 25 is kind of depending on the sequence length; this is to avoid register pressure.
>> Yeah. And the way the software emulation works is, for this base-2 exponential, you split the input into a fractional portion and an integer portion. For the fractional portion you approximate the exponential using a degree-3 polynomial, and for the integer portion it's essentially multiplying the fractional result by two to the integer portion, which is essentially updating the exponent bits.
>> This seems counterintuitive. Instead of just calling an exp, you're breaking it apart into different chunks, you're doing a polynomial fit, you're doing bit shifts. Help us understand how this could be faster. >> Faster than if you were using 100% of the... >> Yeah, just the regular hardware unit.
>> So you have so many fused multiply-add units that you can run this in parallel, and this will still be faster than calling the exponential, because there are not that many SFUs on each SM.
>> So roughly how many SASS instructions is it to do the... so MUFU.EX2, that's a single SASS instruction, right?
>> Yeah. So in SASS it's called MUFU. Yeah.
>> Oh, MUFU. Okay.
>> Yeah. Yeah.
What does the M stand for by the way?
Oh, what does MUFU stand for?
>> Multifunction unit.
>> Multifunction unit.
>> Sometimes they call it multifunction operation, but that's just the SASS name.
>> That that one wouldn't be PG though, you know.
>> Okay.
Okay.
Yeah, to give a quick rundown of the algorithm. I'm leaving some implementation details out, but first you clamp x to avoid underflow, and then we do a range reduction where we have this magic constant that we add to the input. We add it, round down, and then subtract the constant again to get the integer portion, and then we subtract that from the input to get the fractional portion. So now we have the integer and the fractional portion. For the fractional portion we can approximate the exponential with the polynomial, and the idea of Horner's method is that we put this polynomial in a form where we can use fewer instructions to evaluate it. So we get the fractional result. Now we want to combine the integer and the fractional portions, and essentially we just need to update the exponent bits of the fractional result. So it's literally a shift-and-add instruction. This pushes the integer part into the exponent field of the bit representation. It's literally like bit surgery. >> And nine SASS instructions for this seems remarkable. For steps one, two, and three I can see this happening, but a polynomial approximation fitting in a single-digit number of SASS instructions seems surprising to me. >> So this is like three instructions.
Um this is another three instructions.
This is one instruction and this is like two instructions. So the total is nine.
So the SASS shows nine. The PTX might show 10.
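The steps just described (clamp, split into integer and fractional parts, degree-3 polynomial via Horner's method, then an add into the exponent bits) can be sketched in Python. Several things here are stand-ins and not from the talk: `math.floor` replaces the magic-constant rounding trick, the coefficients are plain Taylor coefficients of 2^f rather than tuned ones (so the error is larger than a fitted polynomial would give), and the input is assumed to be at most 0, as it is after the max subtraction in softmax.

```python
import math
import struct

LN2 = math.log(2.0)
# Stand-in coefficients: Taylor expansion of 2**f = exp(f * ln 2).
# A real kernel would use coefficients fitted for the interval [0, 1).
C1 = LN2
C2 = LN2 ** 2 / 2.0
C3 = LN2 ** 3 / 6.0


def f32_bits(x):
    """Bit pattern of x rounded to float32, as an unsigned 32-bit int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(b):
    """Float32 value for an unsigned 32-bit pattern."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def exp2_emulated(x):
    """Approximate 2**x for x <= 0 without calling a hardware exp."""
    x = max(x, -126.0)            # clamp so the result stays a normal float32
    n = math.floor(x)             # integer portion (stand-in for the magic-constant trick)
    f = x - n                     # fractional portion in [0, 1)
    p = 1.0 + f * (C1 + f * (C2 + f * C3))   # Horner form of the cubic, p in [1, 2)
    # "Bit surgery": adding n to the exponent field multiplies p by 2**n.
    return bits_f32((f32_bits(p) + (n << 23)) & 0xFFFFFFFF)


for x in (0.0, -0.5, -3.25):
    print(x, exp2_emulated(x), 2.0 ** x)
```

With these Taylor coefficients the relative error stays below about 1% on the tested range; the point of the sketch is the structure, where everything except the final integer add maps to multiply-add style arithmetic.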
>> And was this an old trick? If you, I don't know, dig through the BLAS codebase, would you find this trick, or where does the intuition for this come from?
>> So these are different ideas that you combine >> together, and you essentially try to map the algorithm to the SASS instructions. Most of them are very old existing ideas, if you deep dive into an old floating-point handbook.
>> That's right.
>> Yeah. Like from the '80s, >> and you kind of see all these classic tricks.
>> I see.
>> Yeah.
But the broader trick seems powerful, right? You look at a GPU and it has a bunch of functional units. You look at which ones aren't being used very much, and then you try to figure out whether you can parallelize across them, and you get deeper pipelines. So even though I think you only showed it for exp, the broader trick that you described seems like a good frame for approaching performance problems on modern GPUs.
>> Yes. Yes. So maybe for different functionalities or different architectures it might be useful. Yeah. >> Exactly.
>> So the last contribution I'm going to talk about for the forward pass is how to reduce non-matmul operations in the correction warp group. In the correction warp group, if you look at the standard online softmax, as we iterate across the key-value blocks we have the running max and the running normalizer. At every step we essentially have to scale the accumulated output with the following term, and then at the end we apply the final normalization.
So this can be costly, because it's a vector multiplication, and ideally you want to reduce this. So instead of rescaling whenever the new max jumps, what if we tolerate some slack, for lack of a better expression? So we have this tau factor, where essentially if the jump in the max is less than a certain threshold, we just skip the rescaling, immediately send the signal to the MMA, and start accumulating the output.
Is is that deterministic?
Why would it not be deterministic?
>> Sorry, maybe deterministic is not the right word. It's deterministic because you're always ignoring the slack factors in the same way.
Would you run into weird numeric issues by doing this kind of stuff?
>> So it depends on the factor. This is more of an observational trick.
If you use a larger factor, that might become a problem for BF16, but >> so what is tau typically set to? >> So it's set to eight, which on the exponential, 2 to the 8, is a 256 factor. >> A 256 factor. Okay.
>> Yeah.
So the difference is eight but the scaling is like the base 2 exponential.
>> I see. And then once you ignore it you ignore it. It's not like you're accumulating it for the future.
>> It's not. >> I see. >> But at the end you still have access to the final max and the normalizer, which you do apply.
>> I see. Okay.
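The conditional rescaling discussed above can be sketched as follows. This is an illustrative scalar version under stated assumptions: scores are taken to be in the base-2 log domain, each key has a single scalar value, and the function name and the `rescales` counter are invented for the example. The key observation is that the accumulator and the normalizer always share the same (possibly stale) reference max, so the final division is unaffected by skipped rescales.

```python
import math


def online_softmax_sum(score_blocks, value_blocks, tau=8.0):
    """Softmax-weighted sum over blocks, rescaling only on large max jumps.

    tau = 0 recovers the standard rule (rescale whenever the max increases).
    """
    m = -math.inf   # reference max currently used for the exponentials
    l = 0.0         # running normalizer, relative to m
    acc = 0.0       # running unnormalized output, relative to m
    rescales = 0
    for s_blk, v_blk in zip(score_blocks, value_blocks):
        blk_max = max(s_blk)
        if blk_max - m > tau:
            scale = 2.0 ** (m - blk_max)   # 0.0 on the first block
            l *= scale
            acc *= scale
            m = blk_max
            rescales += 1
        # If the rescale was skipped, these terms can be as large as 2**tau.
        p = [2.0 ** (s - m) for s in s_blk]
        l += sum(p)
        acc += sum(pi * vi for pi, vi in zip(p, v_blk))
    return acc / l, rescales


scores = [[0.0, 1.0], [2.0, 3.0], [12.0, 4.0], [13.0, 5.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(online_softmax_sum(scores, values, tau=8.0))
print(online_softmax_sum(scores, values, tau=0.0))
```

With tau = 8 the intermediate terms are bounded by a factor of 2^8 = 256, which is the slack the speakers discuss as a potential concern for low-precision formats if tau were made larger.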
>> So just to quickly recap, the forward pass has three contributions. On one hand, we're trying to overlap the softmax and the MMA. On the other hand, we have a software emulation that runs in parallel with the special function unit to mitigate the exponential bottleneck. And lastly, we have this conditional online softmax rescaling to minimize non-matmul operations.
So before I deep dive into the backward pass, let me give a quick background. In the backward pass, we partition the key and value sequence length, for, let's say, a given batch and given head, into tiles, and we map each block onto a CTA.
Each CTA essentially gets a unique key and value tile, and it iterates across the query and DO, which is the upstream derivative. Each CTA accumulates the gradients of K and V, gets the final output, and copies it out, and then we have DQ, which is atomically reduced at every iteration. So if you look at this figure, this is inner loop iteration zero: all the CTAs participate to atomically reduce one tile, and then the next iteration, the next iteration, the next iteration. So this global atomic add introduces non-determinism, which matters if you care about it; sometimes for debugging purposes or reproducibility people do care about this. >> Do you think you could get rid of it in another version, a FlashAttention 4.1?
>> So one approach is to have two kernels, where one of them computes DK and DV and the second one computes DQ. The Triton version does this, and the most recent cuDNN versions actually do this. But it would be really interesting to see if you can do this in a single kernel.
>> So in our paper we have a version that does this. It uses a semaphore, and it reaches about 75%.
So a bit of a performance drop, but there is a deterministic mode in the FlashAttention-4 backward pass. >> And I guess before we go to four, because most online tutorials will only cover the FlashAttention forward pass for whatever reason, could you walk us through the evolution from one to four and what the important nuances are here, or do you just view this as a straight-up extension of three in this case? >> So algorithmically speaking, the backward pass hasn't improved that much. The first trick was the delta, >> Mhm. >> which makes it computationally feasible, and most of the backward improvements have been more kernel design choices as opposed to algorithmic changes. I think the algorithm is maybe optimal up to this point. It would be really hard to try to accomplish it with four GEMMs instead of five.
But yeah, you're totally right.
The backward pass seems very abandoned.
I don't know if people don't care, or if it's more that it's difficult to understand.
>> I think backward passes are counterintuitive to explain. >> Counterintuitive. Okay.
>> Interesting. Okay. Yeah. So just to look at the backward pass from a different view: when you call the FlashAttention backward pass, it's essentially a wrapper function where under the hood it's calling three kernels. The first one is the pre-processing kernel that computes this delta.
There are no GEMMs. It's just element-wise operations. And we have the post-processing kernel as well, which is also an element-wise operation, which converts DQ from FP32 to BF16.
Since DQ is being accumulated, or reduced, atomically, we have to do this in FP32 for numerical precision reasons.
So at the end we have to make this conversion. These two kernels, the pre-processing and post-processing, take a fraction of the time of the main kernel that computes the five GEMMs.
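The talk does not spell out the delta formula. In the published FlashAttention papers it is the row-wise sum of the element-wise product of DO and O, which equals the sum over keys of P times DP and is what lets DS be formed as P * (DP - delta) without an extra GEMM. The sketch below checks that identity on a tiny example; all helper names are invented for the illustration and the 1/sqrt(d) scaling is omitted.

```python
import math


def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def softmax_rows(s):
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out


def delta_preprocess(o, do):
    """Pre-processing step: delta_i = sum_c dO[i][c] * O[i][c] (no GEMM needed)."""
    return [sum(x * y for x, y in zip(o_row, do_row))
            for o_row, do_row in zip(o, do)]


q = [[0.2, -0.1], [0.4, 0.3]]
k = [[1.0, 0.5], [-0.5, 2.0], [0.3, 0.3]]
v = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]
do = [[0.1, -0.2], [0.7, 0.4]]          # upstream derivative dO

p = softmax_rows(matmul(q, transpose(k)))
o = matmul(p, v)
dp = matmul(do, transpose(v))           # dP = dO V^T

delta = delta_preprocess(o, do)
delta_via_p = [sum(x * y for x, y in zip(p_row, dp_row))
               for p_row, dp_row in zip(p, dp)]
ds = [[p[i][j] * (dp[i][j] - delta[i]) for j in range(len(k))]
      for i in range(len(q))]
print(delta, delta_via_p)
```

Because delta only needs O and DO, it can be computed once up front, which is consistent with the description of the pre-processing kernel as purely element-wise.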
So an alternative to this, as we just discussed, is that you have two kernels instead of one. One of them computes DK and DV and the other one computes DQ. This makes it deterministic, but you also have to pay the cost of a kernel launch and you have to do seven GEMMs instead of five. So it adds quite a bit of cost, but it's much easier to implement. >> And maybe more of a noob question: these are three different kernels, so why not just inline them all into a single large kernel?
>> So let's say you combine the pre-processing and post-processing; now you also have to load the output O.
So yes, >> so with limited SMEM you kind of want to avoid that; it's much cheaper to just precompute this.
>> And then what about the type conversion at the end, or that's not a function though, that's not a kernel? Oh, it is a kernel.
>> It is a kernel. So this is... >> Same intuition though, right? You just want to... >> So since DQ is being atomically added, you kind of have to do it in FP32. You want to avoid doing it in BF16.
>> Oh, sorry. My question was why not inline it into the previous kernel as well?
>> But if you inline it in the previous kernel, you still have to compute the output of DQ in FP32.
>> You compute that.
>> Oh, I see. I see. I see. The output of DQ is in FP32.
So for DK and DV it's actually fairly straightforward, because each CTA accumulates and writes one unique output, and this happens once in the epilogue, >> and it can just be done in BF16.
DQ is different: at every iteration you're basically writing atomically, >> and it has to be in FP32 because you're doing an atomic add and you want to preserve numerical precision.
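The precision concern with accumulating in BF16 can be illustrated with a small emulation. This is not how the kernel does atomics; it only emulates float32 and bfloat16 rounding on the host (round-to-nearest-even for the BF16 conversion is an assumption of this sketch) to show how many small contributions can be lost when the running sum is kept in BF16.

```python
import struct


def to_f32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    """Round to bfloat16 (8 exponent bits, 7 stored mantissa bits), nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Accumulate 1000 small contributions of 0.001 onto a running value of 1.0.
acc_f32 = 1.0
acc_bf16 = 1.0
for _ in range(1000):
    acc_f32 = to_f32(acc_f32 + 0.001)
    acc_bf16 = to_bf16(acc_bf16 + to_bf16(0.001))

print(acc_f32)   # close to 2.0
print(acc_bf16)  # 1.0
```

Near 1.0 the spacing between BF16 values is 2^-7, so each 0.001 update rounds away entirely, while the FP32 accumulator ends up near 2.0. This is the kind of loss that converting to BF16 only once, after the accumulation, avoids.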
Okay, so this is more of a very high-level overview of how different warps specialize for different components of the backward pass.
I know there are a lot of equations that might be a little difficult to parse, but my point is to show how these different warps try to do their jobs in parallel. Before the main iterations we have the prologue phases, and as we load Q and DO in every iteration we're trying to overlap that with the compute, the MMA. In a parallel warp we have the compute warp, which is doing the element-wise operations for P and DS, and we're trying to compute the attention probabilities and the gradient of S in parallel. So we're trying to parallelize this.
Meanwhile, we have another dedicated warp group for DQ. Essentially, every iteration that the DQ MMA is done, it loads from TMEM to registers, copies from registers to shared memory, and then in stages copies from shared memory to global memory.
So a few questions about this picture.
So first off, this to me seems like it reflects roughly what your process was in designing. You basically just have different columns be different warps, and then you try to draw arrows for sync points. Would you say that's accurate?
>> That's spot on.
>> Spot on. Okay. And then the second part of this picture, I think, is the specific location of every formula, let's say the P equals exp of S minus LSE one. You're only saying when that operation gets launched, not when it completes. So this is not showing overlaps explicitly in this picture, right? >> So this is meant to show overlaps between different warps.
>> Oh, it it is showing overlaps.
>> It is showing overlaps. Yes.
>> Okay.
>> Yeah. So we know from earlier that the exponential can be a bottleneck, and it takes about the same time >> as two GEMMs, right. >> So in parallel, as we >> evaluate the exponential to compute the attention probabilities, >> in the parallel warp, in the MMA warp, we're doing the two GEMMs.
>> Okay. So this is fully overlapped. Okay. And then the other thing is, you mentioned briefly in your annotation loads from TMEM to RMEM. So in TMEM there's basically only P and DS, it looks like from your picture.
>> So every GEMM output is in TMEM.
>> Every GEMM output. >> Every GEMM output.
>> Okay.
>> And DK and DV are sort of accumulated across the main loop. Right.
>> Okay. So those need to be like kept to themselves because you don't want to overwrite them.
>> And each one is let's say head dimension 128. So that's like DK and DV. That's sort of 256 columns gone.
>> So now you have 256 columns to work with for the remaining items that are going to be in TMEM. So that's S, that's DP, DQ, and also P.
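The column budget described here can be written out as arithmetic. The total of 512 TMEM columns is inferred from the "256 gone, 256 left" statement, and treating each FP32 accumulator tile as occupying one column per output column (128 each for S, DP, and DQ) is an assumption made for this illustration; P is left out of the tally because its storage format is not stated.

```python
TMEM_COLUMNS = 512   # assumption: implied by "256 columns gone, 256 left"
HEAD_DIM = 128

# Accumulators that must persist across the whole main loop.
persistent = {"DK": HEAD_DIM, "DV": HEAD_DIM}
remaining = TMEM_COLUMNS - sum(persistent.values())

# Per-iteration GEMM outputs, assumed to need 128 FP32 columns each.
transient = {"S": 128, "DP": 128, "DQ": 128}
needed = sum(transient.values())
must_reuse_columns = needed > remaining

print(remaining)           # 256
print(needed)              # 384
print(must_reuse_columns)  # True
```

Under these assumptions the per-iteration outputs cannot all have dedicated columns, which is why the kernel has to overlap them in the remaining space.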
>> So essentially you're overlapping all of this.
>> I see. And then another noob question: you mentioned TMEM is popular for relieving register pressure, but it looks like loads are still done from TMEM. So could you speak a bit more about that? >> So the register pressure point was, for instance, for every DQ here, if this was FlashAttention-3, the DQ MMA had to be done, you'd have to wait there, then load, and then do any post-processing, and if you're doing some post-processing, let's say scaling it, you maybe have some intermediate values. But here it's just completely done in a parallel warp, okay? So once this MMA is done, you can immediately start computing DP for the next iteration.
>> Okay. Yeah. Yeah, that makes... Oh, I see.
So then what is this point about? What does this line, TMEM to RMEM, mean then?
>> This is just to say that the DQ warp group is solely responsible for copying DQ, essentially getting DQ from TMEM to registers, to shared memory, to global memory in an atomic way. So that's the designated job of this warp group.
>> I see. And then is this label comprehensive? Can you load from RMEM to TMEM? Can you store TMEM to RMEM? Or is what you're showing here basically comprehensive?
>> So anytime you want to load from TMEM, it has to be loaded into registers. But you can copy from shared memory to TMEM.
>> You can copy from shared memory to TMEM.
But anytime you want to come out of TMEM, it always has to be into registers.
>> Okay. And then oh I see this is like the order in which you get them out. I see.
It's like one two three. I see. Okay.
Got >> exactly.
>> Exactly.
>> Got it.
>> So on one hand, this is making the pipeline such that the MMA and the softmax are overlapping, and we also store some of these intermediates in TMEM. The reason is that, as mentioned, the backward pass is bounded by the SMEM bandwidth. So we're trying to relieve this pressure.
So now, if we have five GEMMs, each one has two operands. Two of them are in TMEM and the remaining eight are in SMEM. >> So your longest column is warp 12. What I find interesting is that warps 0 to 3 have like one instruction, the one for loop, and then warp 12 has this mega algorithm. So how do you decide how many warps to allocate for each of these pipeline stages?
>> Oh okay. So this is probably the simplest warp, actually, >> because you have one thread that just calls MMAs. >> And there's a dependency between them, so it's very linear? >> From MMA to MMA there is, and you also have to remember that some of these depend on values that are being processed in parallel. So if this one is waiting for P, you have to make sure P is done and safely stored in TMEM before it can start doing the MMA.
>> Oh, I see. So, so practically speaking here like the way you'd go about deciding this is you just like let's say assume in the very beginning that just like one warp is doing each of these steps and then you try increasing them and you try to find like basically grid search like the best like allocation of like amount of warps to a specific uh pipeline stage.
>> So you have 16 warps per like CTA, right?
>> Yeah. You need one thread to issue the TMA, the tensor memory accelerator. You need one thread to issue the MMA. So each one needs one warp.
>> Now you have 14 left.
>> TMEM has 128 lanes. So it's kind of like a warp group.
>> Uhhuh.
>> So you need four warps to sort of be able to load from all of them.
>> Okay.
>> So the compute needs at least four. The DQ needs at least four.
>> Okay. But then why do you have more than four for... >> For the compute?
>> Yeah. So for the compute you do a lot of element wise operations.
>> Okay.
>> So there's a lot of intermediate values that need to be on the register.
>> Okay.
>> So as we allocate more warps, we allocate more registers, and this portion of the computation is also bottlenecked by, kind of, instructions. So we want some level of instruction-level parallelism.
>> Oh I see. So, so even like within a warp group like just by sort of saying this you're you're basically hinting at the compiler like I want some level of instruction level parallelism here.
>> Yeah.
>> Oh, interesting. I see. Okay.
>> Okay. So, so what happens like let's say each tile is like 128 by 128.
>> So the first warp group gets like 128 rows 64 columns. The second warp group gets 128 rows the second 64 columns.
>> So essentially each thread gets like 64 elements.
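The warp and per-thread numbers from this exchange can be checked with a few lines of arithmetic. The variable names are made up for this sketch, and it only encodes figures stated in the discussion (16 warps per CTA, one warp each for TMA and MMA issue, four warps needed to cover 128 rows, two compute warp groups splitting a 128 x 128 tile by columns).

```python
WARPS_PER_CTA = 16
THREADS_PER_WARP = 32
TILE_ROWS, TILE_COLS = 128, 128

# One warp each is set aside for issuing TMA loads and MMAs.
issue_warps = {"TMA": 1, "MMA": 1}
warps_left = WARPS_PER_CTA - sum(issue_warps.values())

# One thread per tile row means a full warp group is needed to read a tile.
warps_to_cover_rows = TILE_ROWS // THREADS_PER_WARP

# Two compute warp groups split the tile by columns: 128 rows x 64 columns each.
compute_warp_groups = 2
cols_per_group = TILE_COLS // compute_warp_groups
threads_per_group = warps_to_cover_rows * THREADS_PER_WARP
elems_per_thread = (TILE_ROWS * cols_per_group) // threads_per_group

print(warps_left)           # 14
print(warps_to_cover_rows)  # 4
print(elems_per_thread)     # 64
```

Splitting the columns across two warp groups halves the per-thread register footprint compared with a single warp group holding all 128 columns per row.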
>> Got it. Okay. Yeah. Makes sense.
>> Okay. So even with this pipelining and storing P and DS in tensor memory, it's still bottlenecked by shared memory bandwidth. So naturally, as a next step, we tried the 2CTA MMA mode. This may seem a bit unintuitive at first, because you'd maybe use 2CTA in the decoding step of the forward pass, where you're bottlenecked by the KV cache that you have to load at every step, to reduce the bandwidth. But as we saw earlier, the backward pass is bottlenecked by SMEM bandwidth. So essentially we increase the M dimension of the MMA to 256 and we keep N at 128. So this performs a much larger MMA. By keeping the N dimension of operand B at 128, each tensor is essentially split in half.
So each CTA gets only half compared to like the one CTA counterpart.
So this reduces the footprint.
Well, I should say this in quotation marks, but it does reduce the SMEM bandwidth. And the reason I'm saying in quotation marks is, let's say you look at Q and Q transpose here, the transposed versus the non-transposed one: >> the reduction axis for these two is different.
>> So like for instance this guy splits horizontally the other one splits vertically >> for like CTA zero. So there's one quadrant that's like overlapping between the two >> but not >> Mhm.
>> Whereas in the 1CTA version you just need to load Q once, and then for Q transpose you just recast the layout. >> So presumably there's a missing step here then, or no? >> What do you mean, missing step?
>> Are you explicitly reshaping the tile here, basically?
>> So we're not. In the 1CTA version, let's say, you load Q once and then you just recast the pointer to change the layout.
>> Oh okay. Sure. >> Here you just separately load Q transpose and Q, and it does reduce the SMEM bandwidth, because the MMA takes only half of it.
>> I see.
>> But the TMA loads each one.
>> I see.
>> Yeah. So you mean each of the CTAs will load one, right? Or the... >> So you still have to load the full tile using the TMA from global memory to shared memory. But when you're feeding this into the MMA instructions, >> each CTA essentially only needs half of it, its respective half. >> Okay.
>> Yeah. So that's why it reduces the SMEM bandwidth.
>> Got it. Got it.
>> Yeah.
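A rough back-of-the-envelope sketch of the bandwidth argument above, in Python. It assumes that in 2CTA mode operand B is split evenly along N between the two CTAs, that each CTA keeps 128 output rows in both modes, and that the reduction depth is 128; the function names and the reduction depth are illustrative, not taken from the FA4 code.

```python
BYTES_BF16 = 2  # bytes per BF16 element

def operand_b_bytes_per_cta(n, k, cta_group, bytes_per_elem=BYTES_BF16):
    """Bytes of operand B one CTA feeds from its own shared memory per MMA.

    Assumption: operand B (N x K) is split evenly along N across the CTA group.
    """
    if n % cta_group != 0:
        raise ValueError("N must divide evenly across the CTA group")
    return (n // cta_group) * k * bytes_per_elem

def output_rows_per_cta(m, cta_group):
    """Rows of the accumulator tile owned by each CTA."""
    return m // cta_group

# 1CTA: M=128, N=128. 2CTA: M=256, N=128, shared by a pair of CTAs.
one_cta = operand_b_bytes_per_cta(n=128, k=128, cta_group=1)
two_cta = operand_b_bytes_per_cta(n=128, k=128, cta_group=2)
rows_1 = output_rows_per_cta(m=128, cta_group=1)
rows_2 = output_rows_per_cta(m=256, cta_group=2)

print(f"1CTA: {rows_1} output rows per CTA, {one_cta} B of operand B per MMA")
print(f"2CTA: {rows_2} output rows per CTA, {two_cta} B of operand B per MMA")
```

Under these assumptions each CTA produces the same number of output rows while reading half as much of operand B from shared memory, which is the "in quotation marks" footprint reduction: the TMA still brings in the full tile.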
>> But this actually turns out to be quite problematic for the...
>> And why SMEM, why not TMEM in this case? Is it because K and Q are never in TMEM, right? It's purely the output, like full output tensors. You never put individual tiles in TMEM.
>> So the output is always accumulated in TMEM. But you have the option for operand A to be in TMEM.
>> Uh-huh.
>> So in this case P is in TMEM, and the dS for dK is in TMEM.
>> Okay, got it. And then Q is in registers.
>> So operand A can be either in TMEM or SMEM.
Yes.
>> But operand B always has to be in shared memory.
>> Oh, always in shared memory. Okay.
>> Always in SMEM. >> And this is just like an Nvidia limitation for the... this is the limitation of the tcgen05 instruction.
>> Yeah, it's a hard constraint, so that makes sense. Yeah. >> It's kind of consistent with the Hopper version, where operand A was either in registers or shared memory. Operand B was always in shared memory.
>> Okay, got it.
>> Yeah, but this has been kind of problematic for the dQ MMA, where the reduction axis ends up being split between the two CTAs. And we know that in the 2CTA case you have to partition the M dimension and the N dimension, the non-reduction axes, between the two CTAs. And also it's worth mentioning that once you launch the kernel and set the CTA group size, it has to be consistent across all the MMAs. So you can't sort of mix and match and pick and choose.
>> Oh I see.
>> Yeah. So if all of these are running in 2CTA mode, the last one also has to be in 2CTA mode.
>> Yeah.
>> Within an individual kernel launch. But if it was a separate kernel... Yeah. Exactly. Exactly.
Okay. So the way we go about solving this: if you look at the (c) portion of this figure, essentially we want to partition the M dimension of dS. For K, because it's something that's in global memory, we can just load it in the specific tile form we want. But dS is an intermediate value that is computed. And the idea is, if you look at the (d) portion of the figure, we want to exchange the top half of CTA 1 with the bottom half of CTA 0.
So once we do this exchange, essentially we're splitting the M dimension but we're doubling the reduction.
>> What's the intrinsic call for CTAs to exchange data with each other in this way?
>> The what?
>> Like, is there an intrinsic to make CTAs communicate data to each other?
>> So that's a good question. The way we do this, we use distributed shared memory, where by definition the two CTAs are in the same cluster, so they can access each other's shared memory, right?
Yeah. So, we sort of leverage that.
>> And so now the four GEMMs have an M dimension of 256, but the dQ MMA has an M dimension of 128, and it has double the reduction axis.
>> Mhm.
>> So at the end, instead of each CTA writing a full 128 by 128 tile, now each one writes 64 by 128,
>> because you're locally doing a larger reduction.
>> And is there at the very end yet another global reduction, or that's it, once each of the partial reductions is basically...
>> So still, every iteration, each CTA has to write 64 by 128.
>> But that's it. You just write it.
>> Yeah. Every iteration. Yeah.
>> Whereas in the 1CTA case, it's like >> doing 128 by 128.
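The exchange described above can be checked numerically on a toy problem. The sketch below is plain Python, not CuTe DSL: it assumes each CTA starts with all query rows of dS for its own slice of the KV (reduction) axis, swaps halves so that each CTA ends up with half the query rows across the full reduction axis, and confirms that the two half-height dQ tiles together equal the reference dQ = dS · K. The sizes and helper names are illustrative, and the sketch does not model how operand B is laid out across the pair.

```python
import random

def matmul(a, b):
    """Plain nested-list matrix multiply: (m x k) @ (k x n)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def hconcat(left, right):
    """Concatenate two matrices along the column (reduction) axis."""
    return [l + r for l, r in zip(left, right)]

random.seed(0)
Q_ROWS, KV_PER_CTA, HEAD_DIM = 4, 3, 2  # toy stand-ins for 128, 128, 128

# Before the exchange: CTA i holds dS[:, its KV slice] and the matching rows of K.
dS = [[[random.randint(-3, 3) for _ in range(KV_PER_CTA)] for _ in range(Q_ROWS)]
      for _ in range(2)]
K = [[[random.randint(-3, 3) for _ in range(HEAD_DIM)] for _ in range(KV_PER_CTA)]
     for _ in range(2)]

# Reference: one big reduction over both KV slices.
K_full = K[0] + K[1]
dQ_ref = matmul(hconcat(dS[0], dS[1]), K_full)

# Exchange: CTA 0 gives away its bottom half and receives CTA 1's top half.
half = Q_ROWS // 2
dS_cta0 = hconcat(dS[0][:half], dS[1][:half])  # top query rows, full reduction axis
dS_cta1 = hconcat(dS[0][half:], dS[1][half:])  # bottom query rows, full reduction axis

# Each CTA now produces a half-height dQ tile with a doubled reduction length.
dQ_cta0 = matmul(dS_cta0, K_full)
dQ_cta1 = matmul(dS_cta1, K_full)

assert dQ_cta0 + dQ_cta1 == dQ_ref
print(f"each CTA writes a {len(dQ_cta0)} x {HEAD_DIM} tile instead of {Q_ROWS} x {HEAD_DIM}")
```

With the real tile sizes, the same bookkeeping gives the 64 by 128 per-CTA output mentioned in the talk, versus 128 by 128 in the 1CTA case.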
>> Yeah. Okay. So, at first glance, it kind of appears that we're doing this to enable 2CTA mode for dQ as well, in order to reduce SMEM bandwidth. But this turns out to have other benefits as well.
So on one hand we're reducing the SMEM bandwidth, because operand B of the following MMAs essentially needs to stage and load only half of it, right. And now each dQ has a smaller output tile.
This reduces the global atomic adds, and this is being done every iteration in FP32, so it tends to boost the performance quite a bit. And also, now that dQ is smaller and it's being overlapped in TMEM with other tensors, we can change the ordering of the software pipelining and improve the overlap. It's also beneficial in the head dimension 192 case.
>> This is like the DeepSeek shape, right?
>> Yeah, the MLA, exactly, exactly. Okay, so in the next two slides I'm going to talk about some scheduling. If you look at causal attention, it naturally has a load imbalance. And there's also the fact that each CTA gets assigned to a worker. Let's say CTA 0 gets worker one, CTA 2 gets worker two. This is like the shortest processing time first.
>> What's a worker here? Is it a warp group?
>> Uh like an SM >> SM. Okay.
>> Yeah, we sort of improve this using like the longest processing time first.
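To make the scheduling point concrete, here is a minimal Python sketch comparing the two orderings under a simple greedy model. The cost model (tile i of a causal mask costs i + 1 units) and the tile and worker counts are assumptions for illustration, not FA4's actual scheduler.

```python
import heapq

def makespan(tile_costs, num_workers):
    """Greedy list scheduling: each tile goes to the currently least-loaded worker.

    Returns the load of the busiest worker, i.e. when the last worker finishes.
    """
    heap = [(0, w) for w in range(num_workers)]
    heapq.heapify(heap)
    for cost in tile_costs:
        load, worker = heapq.heappop(heap)
        heapq.heappush(heap, (load + cost, worker))
    return max(load for load, _ in heap)

# Assumed causal cost model: tile i has i + 1 blocks of work.
num_tiles, num_workers = 8, 3
costs = [i + 1 for i in range(num_tiles)]

spt = makespan(costs, num_workers)                  # shortest processing time first
lpt = makespan(list(reversed(costs)), num_workers)  # longest processing time first
ideal = -(-sum(costs) // num_workers)               # ceiling of a perfectly even split

print(f"shortest-first makespan: {spt}")
print(f"longest-first makespan:  {lpt}")
print(f"even-split lower bound:  {ideal}")
```

Issuing the heaviest tiles first leaves the small tiles to fill in the gaps at the end, so the longest-first ordering finishes closer to the even-split bound in this toy setting.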
And essentially we iterate in reverse order, and we also process it in head sections and swizzle the head count, and that helps with the L2 cache hit rate. In terms of some results: since cuDNN 9.13, 9.14-ish, we started collaborating with them, where we exchange ideas. So some of these methods, for instance the software emulation, have been used in cuDNN 9.14 and onward. So that's why we benchmark both of them. Compared to the 9.13 version, FA4 consistently outperforms, and with the most recent version, 9.19, which came out I think a few weeks ago, the performance is generally on par for the forward pass.
>> So FlashAttention-3 isn't in this picture because it crashed, because it used intrinsics that aren't on Blackwell.
>> Exactly. Exactly.
>> And then, um, it also seems like for cuDNN and FA4 your measurements really get within noise of each other, and maybe they beat you a bit on lower sequence lengths maybe >> uh, compared to the latest version. Yes.
Yeah.
>> Yes. And then last question is, for Gluon specifically, what set of tricks did they use? Did they also use these softmax tricks? Did they use other things?
>> So they don't use the softmax trick. But Gluon definitely gives you more lower-level control, like warp specialization and so on.
>> So they definitely have better overlap compared to the Triton version.
>> Yeah.
And then like though um how do you find the readability of all these implementations to compare?
>> Um I well Q is close to second.
>> We don't know. Yes, we can guess.
>> Yeah. For Gluon and Triton, I would say it's far easier to read compared to CuTe DSL, >> especially Triton. But one thing I'll say is, for Gluon, even though it's not performing as well, for the lines of code you write to get this level of performance, I would say it's quite impressive.
>> It's significantly fewer lines of code compared to the CuTe DSL versions.
>> But is it as low-level a language >> as CuTe DSL? Oh no.
>> Oh no. I think it's closer to Triton. Thank you.
>> Oh, is it like Triton with with warp specialization basically?
>> And many more other features as well.
>> Oh, I see. Interesting.
>> Oh, so it's like an extension point.
It's not like really new. Okay, I see.
>> It's like a back end the sort of call right. Yeah, >> I see.
Here are some results for head dimension 192. These are like the DeepSeek MLA models. Essentially, FA4 is consistently better, especially at larger sequence lengths in the causal case.
Some backward results. The results I'm reporting here are all using the 2CTA MMA mode, and consistently, from let's say 4K and higher, FA4 is outperforming the baselines. And as you can see, for most of these implementations, like Triton, there's no backward kernel that you can really benchmark, so we compare against cuDNN and FA2. In the case of head dimension 192 it's a bit of mixed results. In the causal case FA4 is doing slightly worse than cuDNN, but this is a very recent effort. It's also worth mentioning that the most recent cuDNN version is using two kernels and FA4 is using a single kernel. And in a two-kernel approach, where say one kernel computes dK and dV and the other one computes dQ, it would be much easier to deal with head dimension 192, in terms of registers and also TMEM and shared memory.
>> Was that the main reason they did it or was it because of some of the determinism stuff that you referred to earlier?
>> Could be that they were motivated by determinism, but I can't really tell because the code is closed source.
>> Yeah, I understand. Yeah.
>> Yeah.
>> But the way I know it's two kernels is, if you profile, it prints the names of all the kernels that are called, so I can see.
>> Yeah. But the earlier versions didn't do this.
So here, for the backward pass with head dimension 192, it's really difficult to do in the single-kernel case. >> Mhm.
>> I mean, think about it. If you have dQ and dK, right, that's each 256, sorry, 192.
>> Mhm.
>> That takes most of the TMEM, but you have all these different elements.
>> So you just wish you had more TMEM, basically.
>> Yeah. And the only way we can do this on a single kernel is because of the 2CTA approach, as discussed earlier.
>> So now that dQ is smaller, it's easier to overlap. With 1CTA mode it would be catastrophic.
>> Oh, I see.
>> It was serialized so much.
>> I see.
>> Yeah.
Yeah. So that's also kind of half the story of Flash 4. On the other hand, it's really easy to build and also to extend to new variants. It's entirely written in CuTe DSL in Python, and this gives you full low-level control.
If there's a feature that's missing, you can literally write inline PTX and get it there. Like, for instance, the DSMEM exchange for the 2CTA.
Yeah, we have to write inline PTX. But generally it's like 20 to 30 times faster to compile. So this naturally helps you debug faster, try more radical ideas every now and then, and it definitely makes the development process much faster. And Flash 4 exposes many new functionalities as composable primitives. So it supports block-sparse patterns, masking strategies, tile scheduling. So if you have a new attention variant you want to try, you just call these abstractions.
One example of this is FlexAttention, where researchers at Meta and Colfax have taken advantage of this to build FlexAttention on top of FlashAttention-4, and you can see they're getting very impressive results.
>> Shout out to my man Dris.
>> Yeah D. Yeah. I'm going to close the talk with the main message, which is that the bottleneck shifts for attention on Blackwell, from just being pure MMA to exponential for the forward pass and shared memory bandwidth for the backward pass. In the forward pass we addressed this by overlapping MMA with softmax, then boosting the exponential throughput with a software-emulated version that runs in parallel with the SFU, and then skipping some rescaling steps. In the backward pass we talked about the 2CTA MMA mode and using TMEM for some of the intermediate results. Right now the code is open source, and we're working to integrate with popular libraries. The goal is to extend some of these ideas beyond just simple attention, and some of these ideas can also be translated to other hardware where compute is outpacing the other functional units.
>> Yeah, for me that last sentence was one of the main takeaways here, which is: find the dead compute units on your hardware, or turn more things into matmuls.
>> Yeah, >> generally are two like very useful frames.
>> Exactly.
>> Ted, thank you so much for doing this.
You're always welcome back and thank you so much for doing our very first in-person lecture. You're always welcome to come do a digital lecture with us as well. Uh and yeah, we'll chat some more. So, thank you for coming. Thank you for listening everyone. Appreciate it.
>> Thank you, Mark. This was a lot of fun.
Thank you.