Virtual Dense MFMAs for Skinny GEMM
A GEMM A * B = C is called a skinny GEMM when A has a small number of rows
and many columns. The decode phase of LLM inference is a common source of this
problem: a small batch of tokens multiplies against a large weight matrix.
Skinny GEMMs are
less convenient for modern GPU architectures than their non-skinny cousins. One
reason is that modern GPUs rely on matrix core units, whose instructions are
designed specifically for matrix multiplication and operate on fixed tile
sizes. A skinny GEMM has too few rows to fill those tiles, so it cannot use the
instructions at the size they were designed for.
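To make the "too small" point concrete, here is a minimal sketch of how much of a fixed-height tile a skinny GEMM actually fills. It assumes a simplified model in which the rows of A map directly onto the tile's row dimension and any missing rows are padded; the helper name tile_row_utilization and the tile height of 16 are illustrative choices, not part of any real API.

```python
import math

def tile_row_utilization(gemm_rows, tile_rows):
    """Fraction of tile rows holding real data (hypothetical helper).

    Assumes rows of A are padded up to the next multiple of tile_rows.
    """
    padded_rows = math.ceil(gemm_rows / tile_rows) * tile_rows
    return gemm_rows / padded_rows

# Illustrative tile height of 16 rows; small row counts mimic decode batches.
for rows in (1, 4, 16, 64):
    print(rows, tile_row_utilization(rows, 16))
```

Under this model a single-row GEMM fills only 1/16 of each tile, while any row count that is a multiple of the tile height fills it completely.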
On AMD GPUs, and in particular on the Instinct MI3XX (CDNA) series, these
instructions are known as MFMA (Matrix Fused Multiply-Add) instructions; for
example, V_MFMA_F32_16x16x16_F16. One useful part of the name is the MxNxK
tile shape consumed, where M is the number of rows of the left-hand matrix,
N is the number of columns of the right-hand matrix, and K is the dimension
shared by both.
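As a small illustration of reading the tile shape out of an instruction name, the sketch below extracts M, N, and K from a mnemonic written in the form shown above. The function parse_mfma_shape is a hypothetical helper, not part of any AMD tool, and it assumes the name contains a single MxNxK group.

```python
import re

def parse_mfma_shape(name):
    """Return (M, N, K) from an MFMA mnemonic (hypothetical helper)."""
    # Assumes the shape appears as _<M>x<N>x<K>, in upper or lower case.
    match = re.search(r"_(\d+)x(\d+)x(\d+)", name, flags=re.IGNORECASE)
    if match is None:
        raise ValueError(f"no MxNxK shape found in {name!r}")
    return tuple(int(group) for group in match.groups())

M, N, K = parse_mfma_shape("V_MFMA_F32_16x16x16_F16")
# An M x K tile times a K x N tile performs M * N * K multiply-adds.
multiply_adds = M * N * K
print(M, N, K, multiply_adds)
```

For the example name this gives a 16 x 16 output tile with a shared dimension of 16, that is, 4096 multiply-adds per tile product.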