Bypass LDS for scale B operand for skinny gemms#817
Conversation
| mlir::triton::LinearLayout scaleBLayout = | ||
| mlir::triton::gpu::toLinearLayout(scaleBTy.getShape(), | ||
| scaleBTy.getEncoding()); | ||
| bypassLDS = bypassLDS || |
There was a problem hiding this comment.
What is this doing here? Is it checking if bypassing LDS succeeded?
There was a problem hiding this comment.
I think @plognjen wanted to restore the previous condition, i.e. width < 32 should bypassLDS.
If this is the case, maybe we can use another variable to store the value of (width < 32) rather than bypassLDS to avoid any confusions.
There was a problem hiding this comment.
yes, this was to restore the previous condition. I will change the name.
| loadInfo.usedByDot = true; | ||
| // If the max continugous bits we can read is < 32, buffer in registers. | ||
| if (width >= 32) { | ||
| bool bypassLDS = width < 32; |
There was a problem hiding this comment.
So, we're only bypassing LDS when the we're loading smaller than dword, such as buffer_load_short or buffer_load_ushort?
Are there other cases when bypass LDS could be beneficial? If so, let's add a comment reminding us of those additional scenarios.
There was a problem hiding this comment.
Due to preshuffling, width is guaranteed to be >= 32. Therefore, it's confusing to enable bypassLDS only when width < 32.
More generally, bypassLDS should not check width. Later it checks if the loaded layout is the same as the scale layout, and this makes sure width = 32.
Skip LDS for the scale B tensor when warpsPerCTA is {1, numWarps} and
the load layout matches the expected layout for scale B in the dotScaled op.