Let's talk about Nvidia Hopper's new feature WGMMA

In-depth exploration of Nvidia Hopper GPU's new matrix multiplication operation WGMMA, mastering the efficient GEMM kernel design.
Core content:
1. Introducing the WGMMA instruction of the Tensor Core under the Hopper architecture
2. Analyzing the impact of the WGMMA instruction on the GEMM algorithm
3. Exploring the efficient implementation of WGMMA in the CUTLASS library
Last time, we introduced the new feature TMA on Hopper. This time, let’s take a look at the new matrix multiplication operation WGMMA on Hopper.
Introduction
A CUDA tutorial would not be complete without a chapter on general matrix multiplication (GEMM). GEMM is arguably the most important routine on modern GPUs, constituting the bulk of computation in neural networks, large language models, and many graphics applications. Despite its ubiquity, GEMM is notoriously difficult to implement efficiently.
This three-part tutorial series is designed to provide readers with a comprehensive understanding of how to write efficient GEMM kernels on NVIDIA Hopper GPUs using the CUTLASS library.
[Part 1] (this post) discusses the warp group matrix multiply-accumulate (WGMMA) instructions. These are the native instructions for the Tensor Cores of NVIDIA GPUs based on the Hopper architecture. [Part 2] will discuss the overall design of an efficient GEMM kernel, including advanced techniques used in the CUTLASS kernels, such as warp specialization and ping-pong scheduling. [Part 3] will discuss persistent kernels and Stream-K, a load-balancing strategy for GEMMs that achieves state-of-the-art efficiency across a wide range of problem geometries.
The three parts of this series roughly follow the development of a GEMM kernel as a whole, but in an "inside-out" manner. First, we have the tile-level GEMM primitive, which calls the Tensor Cores to ultimately perform the computation. Second, we have the design of the GEMM kernel from the perspective of each cooperative thread array (CTA), consisting of a prologue, main loop, and epilogue, where the main challenge is to keep memory loads from becoming a bottleneck for the fast Tensor Cores. Finally, we have the scheduling of CTAs at the outermost grid level, where load-balancing considerations become primary.
We hope that after reading this series, readers will become experts in the GEMM algorithm and be able to use some of the good ideas in the algorithm to design and implement other kernels in their own work.
Asynchronous Warpgroup MMA (WGMMA)
Hopper introduces the asynchronous warp group-level matrix multiply-accumulate operation (WGMMA). A warp group consists of four contiguous warps, that is, 128 contiguous threads, where the warp rank of the first warp is a multiple of 4. The wgmma.mma_async instruction is executed collectively by all 128 threads in the warp group. This operation typically takes one of the following forms, where the matrix C is used as the accumulator:
C = A * B + C
C = A * B, where the input from accumulator C is disabled.
A notable requirement of WGMMA is that operand B must always be stored in shared memory (SMEM). In contrast, operand A can be located in either shared memory or register memory (RMEM), and the accumulator C is always kept in RMEM.
This blog post is structured as follows. First, we discuss how to call wgmma.mma_async in CUTLASS. This involves building the relevant TiledMMA, and creating and partitioning SMEM tensors to be compatible with WGMMA. Second, we discuss the synchronization mechanisms required to ensure the correctness of WGMMA. Finally, we discuss the layouts used in WGMMA in more detail, including the concept of core matrices and matrix descriptors for operands from SMEM.
Throughout, for the sake of brevity, we abbreviate wgmma.mma_async as wgmma.
WGMMA in CUTLASS kernel
In this tutorial, our main goal is to explain the wgmma primitive for calling the Hopper Tensor Cores to perform tile-based GEMM, and how to invoke it as part of a cute::gemm call. To set the stage, consider a standard GEMM kernel that takes input matrices A and B with problem dimensions MxNxK and computes C = A*B. To parallelize the computation, the kernel fixes static tile sizes bM, bN, and bK and launches a ⌈M/bM⌉ x ⌈N/bN⌉ grid of thread blocks (CTAs), with each CTA computing a bMxbN tile rC of the output matrix. This tile is held in the CTA's registers (RMEM) before being written back to the global C matrix.
Within each CTA, we then have the kernel's main loop. Over multiple iterations, we loop over the inner dimension and load bMxbK and bNxbK tiles of A and B, respectively, from global into shared memory as sA and sB; note that in CUTLASS we fix the shape of sB to be the transpose of what it is mathematically. (In fact, reflecting common practice, we load tiles of A and B into a circular SMEM buffer, where the number of stages is given by a compile-time integer such as 2 or 3. The last mode of the shape tuples of sA and sB is then given by this stage count.) The cute::gemm call then computes the product of the (stage-wise slices of) sA and sB and successively accumulates the values into rC. After the main loop completes, rC is finally written out to global memory.
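To make the tiling arithmetic concrete, here is a small host-side sketch in plain C++ with no CUTLASS dependency. The helper names ceil_div, grid_for, and k_tiles are our own, chosen for illustration; they are not CUTLASS functions, and the problem sizes in the note below are arbitrary.

```cpp
// Ceiling division for positive integers.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Hypothetical helper type: number of CTAs launched along the M and N modes.
struct GridShape {
  int ctas_m;
  int ctas_n;
};

// One CTA per bM x bN output tile, so the grid is ceil(M/bM) x ceil(N/bN).
constexpr GridShape grid_for(int M, int N, int bM, int bN) {
  return GridShape{ceil_div(M, bM), ceil_div(N, bN)};
}

// Number of main-loop iterations each CTA performs over the inner dimension K.
constexpr int k_tiles(int K, int bK) { return ceil_div(K, bK); }
```

For example, with M = N = K = 4096 and the tile sizes bM = bN = 128, bK = 64 used later in this post, grid_for gives a 32x32 grid of CTAs, and each CTA runs 64 main-loop iterations.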
We now wish to interpret the following cute::gemm call and its arguments.
template <class TiledMMA, ... >
__global__ void device_gemm(TiledMMA tiled_mma, ...) {
  // PROLOGUE
  // ...
  // Define A/B partitioning and C accumulators
  ThrMMA thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
  Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)

  // Allocate accumulators and clear them
  Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
  clear(tCrC);

  // Allocate "fragments"
  Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)

  // PIPELINED MAIN LOOP
  while (k_tile_count > -K_PIPE_MAX) {
    // ...
    // MMAs to cover 1 K_TILE
    cute::warpgroup_arrive();
    // (V,M,K) x (V,N,K) => (V,M,N)
    cute::gemm(tiled_mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC);
    cute::warpgroup_commit_batch();
    // Wait for all MMAs in a K_TILE to complete
    cute::warpgroup_wait<0>();
    // ...
  }

  // EPILOGUE
  // ...
}
In the CUTLASS paradigm for MMA (matrix multiply-accumulate), the cute::gemm method is designed to expose the MMA instructions of a particular architecture through a uniform interface. (In fact, if you look at the GEMM kernel in the SM80 tutorial, you will see that the cute::gemm call there is syntactically identical to the one above.) However, the definitions of the arguments involved in the cute::gemm call include many WGMMA-specific aspects:
- The definition of the TiledMMA object tiled_mma encapsulates the information needed by cute::gemm to dispatch to a specific wgmma PTX instruction.
- The layouts of the SMEM tensors sA and sB must be defined to be compatible with wgmma.
- The fragments tCrA, tCrB, and tCrC are constructed as thread-level partitions of the data using the TiledMMA object, and therefore have WGMMA-specific layouts that the programmer should be aware of.
- The fragments tCrA (if operand A is sourced from SMEM) and tCrB are not register-backed tensors whose values are copied from SMEM, but rather matrix descriptors built on top of SMEM.
Finally, of course, there are the warp group synchronization primitives surrounding the cute::gemm call. We will explain all of these concepts in turn.
TiledMMA Object in WGMMA
In the following, we assume that the data type is FP16 and that A and B are MN-major, so in BLAS notation we are computing an NT gemm.
We construct the TiledMMA object on the host using the cute::make_tiled_mma method as follows:
TiledMMA tiled_mma = cute::make_tiled_mma(
SM90_64x64x16_F16F16F16_SS<GMMA::Major::MN,GMMA::Major::MN>{});
Although cute::make_tiled_mma also takes some optional arguments, let's focus on the one at hand: the matrix multiply-accumulate atom (MMA atom). This is a struct that wraps an underlying PTX call, in this case:
wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16
The CUTLASS naming scheme lets one immediately read off the relationship between the wrapped PTX instruction and the MMA atom. First, SM90 is another name for the Hopper architecture. Then, the SM90 MMA atoms are labeled SM90_MxNxK_XYZ_SS or SM90_MxNxK_XYZ_RS, with two template parameters that can each be either GMMA::Major::MN or GMMA::Major::K. Their meanings are as follows:
- X and Y are the data types of the operands.
- Z is the data type of the accumulator.
- MxNxK is the tile size that the wgmma instruction computes with: the "wgmma atom". Not all values of MxNxK are allowed. The allowed shapes are as follows: M is always 64, N is a multiple of 8 from 8 to 256, and K is 16 for 16-bit operand data types (more generally, K is fixed to 32 bytes).
- The suffix RS or SS indicates whether operand A comes from registers (R) or from shared memory (S). Operand B always comes from shared memory, hence the S.
- The two template parameters indicate whether operands A and B are memory-contiguous in the MN mode or the K mode. For example, in BLAS notation, both operands being K-major would correspond to a TN gemm (see this table). Note that for 16-bit operand data types, the memory layout can be either MN-major or K-major. However, for non-16-bit operand data types, the layout must always be K-major.
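The shape rules just listed are easy to encode as a compile-time check. The following is a minimal sketch in plain C++; the function names are hypothetical, and it encodes only the rules stated in this post (the PTX documentation may impose further per-data-type restrictions that are not modeled here).

```cpp
// Returns true if (M, N, K) satisfies the wgmma atom shape rules stated above
// for operands whose elements are elem_bytes bytes wide:
//   M == 64, N a multiple of 8 in [8, 256], and K spanning exactly 32 bytes.
constexpr bool is_valid_wgmma_shape(int M, int N, int K, int elem_bytes) {
  return M == 64 && N >= 8 && N <= 256 && N % 8 == 0 && K * elem_bytes == 32;
}

// Per the text, only 16-bit operand types may be MN-major;
// every other operand type must be K-major.
constexpr bool mn_major_allowed(int elem_bytes) { return elem_bytes == 2; }
```

For example, the 64x64x16 FP16 atom used in this post passes the check, while a 64x12x16 shape does not because 12 is not a multiple of 8.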
That's all you need to know about the syntax of the MMA atom! Now, we've stressed that WGMMA is a warp group-wide instruction. In code, you can retrieve the number of threads participating in the MMA operation defined by the TiledMMA object using its size. For example, consider the following host code:
dim3 dimBlock(cute::size(tiled_mma));
This specifies that each CTA in the kernel is launched with 1 warp group of 128 threads. Suppose instead that we want 2 warp groups to perform WGMMA, with different warp groups independently computing half of the output tile (and each warp group issuing its own wgmma instructions). To do this, we can pass a non-trivial layout (AtomLayoutMNK) as the second argument to the make_tiled_mma method. For example, consider the following code:
TiledMMA tiled_mma = make_tiled_mma(
SM90_64x64x16_F16F16F16_SS{},
Layout<Shape<_2,_1,_1>>{});
This defines a WGMMA operation in which warp groups 1 and 2 compute the upper and lower halves of the output tile, respectively, divided along the M mode (now assuming that bM is a multiple of 128). Additionally, size(tiled_mma) will then be equal to 256.
In general, the two optional layout parameters of make_tiled_mma - AtomLayoutMNK and PermutationMNK - apply equally to any MMA atom.
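As a quick sanity check on the thread counts quoted above, here is a sketch in plain C++ of how the size of the tiled MMA scales with the atom layout. The helper names are our own and only mirror the arithmetic; in real code you would query cute::size(tiled_mma) instead.

```cpp
// One wgmma atom is executed by one warp group of 4 warps x 32 threads.
constexpr int kWarpGroupThreads = 128;

// Threads needed for an atom layout of shape (atoms_m, atoms_n, atoms_k):
// one warp group per atom in the layout.
constexpr int tiled_mma_threads(int atoms_m, int atoms_n, int atoms_k) {
  return kWarpGroupThreads * atoms_m * atoms_n * atoms_k;
}

// Rows of the output tile covered by one pass of the tiled MMA along M,
// which is why bM must be a multiple of this value.
constexpr int tiled_mma_rows(int atom_m, int atoms_m) {
  return atom_m * atoms_m;
}
```

With the default layout this gives 128 threads, and with Layout<Shape<_2,_1,_1>> it gives 256 threads covering 128 rows per pass, matching the requirement that bM be a multiple of 128.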
Shared Memory Layout Constraints for WGMMA
Next, we explain the constraints on the tile sizes and layouts of the operand matrices in shared memory, given the choice of MMA atom. First, as for any MMA instruction, the MxNxK of the MMA atom needs to divide the operand and accumulator tile sizes. In our case, this means that bM should be a multiple of 64, bN a multiple of 64, and bK a multiple of 16.
Second, WGMMA imposes an additional constraint on the shared memory layout (both shape and stride) of sA and sB, and this constraint varies with the chosen swizzle mode. In particular, the layout of (stage-sliced) sA is generally not simply (bM,bK):(1,bM) or (bM,bK):(bK,1), and the same is true for sB.
To understand these requirements in depth, we need the concept of a "core matrix", which we will introduce below. However, in practice, we can always use one of the predefined layout atoms provided by CUTLASS and then use the cute::tile_to_shape method to construct SMEM layouts that are guaranteed to be compatible with wgmma:
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int<64>{};
auto bP = Int< 3>{}; // Pipeline
auto sA = cute::tile_to_shape(
GMMA::Layout_MN_SW128_Atom<T>{},
cute::make_shape(bM, bK, bP)
);
auto sB = cute::tile_to_shape(
GMMA::Layout_MN_SW128_Atom<T>{},
cute::make_shape(bN, bK, bP)
);
Here, MN means that the layout atom applies to MN-major operands, and SW128 is the 128-byte swizzle mode. Printing sA or sB shows the following:
Sw<3,4,3> o smem_ptr[16b]( unset ) o ((_64,_2),(_8,_8),_3):((_1,_512),(_64,_1024),_8192)
Where does this layout come from? cute::tile_to_shape takes a layout (the eponymous tile) and replicates it to tile over a larger shape (similar to numpy.tile). Leaving aside the swizzle function Sw<3,4,3>, we see that the layout atom is given by (64,8):(1,64) and is tiled over the shape (128,64,3) in column-major fashion. So for the MxK shape, the smaller outer stride of 512 lies in the M mode, and the larger outer stride of 1024 lies in the K mode. (The largest stride of 8192 lies in the stage count P mode, which makes sense since different stage slices of sA or sB should not be mixed in memory.)
Note that 64 times sizeof(half_t) equals 128 bytes, which is the name of the swizzle pattern. This is by design: due to the way core matrices work, we always lay out the length of the atoms in the contiguous direction to equal the number of swizzle bytes - either 16 for no swizzle, or one of 32, 64, or 128.
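To see how the printed layout maps coordinates to memory, the sketch below evaluates ((_64,_2),(_8,_8),_3):((_1,_512),(_64,_1024),_8192) by hand and then applies a swizzle of the form Sw<3,4,3>. This is plain C++ and the function names are our own. Two assumptions are worth flagging: the swizzle formula (XOR the B bits starting at bit M+S into the B bits starting at bit M) reflects our understanding of CuTe's Swizzle<B,M,S>, and we assume it acts on byte offsets rather than element offsets.

```cpp
#include <cstdint>

// Element offset of logical coordinate (m, k, p) in the unswizzled layout
// ((64,2),(8,8),3):((1,512),(64,1024),8192) for bM=128, bK=64, bP=3.
constexpr uint32_t sA_offset_mn(uint32_t m, uint32_t k, uint32_t p) {
  return (m % 64) * 1u      // position inside the 64-long contiguous run
       + (m / 64) * 512u    // which atom along M
       + (k % 8) * 64u      // row of the (64,8) atom along K
       + (k / 8) * 1024u    // which atom along K
       + p * 8192u;         // pipeline stage
}

// Assumed form of Swizzle<B, M, S>: take the B bits at position M+S and
// XOR them into the B bits at position M.
template <int B, int M, int S>
constexpr uint32_t swizzle(uint32_t x) {
  return x ^ ((x & (((1u << B) - 1u) << (M + S))) >> S);
}

// Byte offset after swizzling, for 2-byte (half_t) elements.
constexpr uint32_t swizzled_byte_offset(uint32_t elem_offset) {
  return swizzle<3, 4, 3>(elem_offset * 2u);
}
```

Under these assumptions, stepping m from 63 to 64 jumps the offset from 63 to 512, and the swizzle permutes 16-byte chunks within each 128-byte row while leaving the row itself in place. Applying the swizzle twice returns the original offset.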
In contrast, if we consider:
auto sA = cute::tile_to_shape(
GMMA::Layout_K_SW128_Atom<T>{},
cute::make_shape(bM,bK,bP)
);
auto sB = cute::tile_to_shape(
GMMA::Layout_K_SW128_Atom<T>{},
cute::make_shape(bN,bK,bP)
);
Printing sA then gives:
Sw<3,4,3> o smem_ptr[16b]( unset ) o (_128,_64,_3):(_64,_1,_8192)
This is because we are instead tiling (8,64):(64,1) over (128,64,3). (Note that the layout ((_8,_16),(_64,_1),_3):((_64,_512),(_1,_0),_8192) coalesces to (_128,_64,_3):(_64,_1,_8192).)
In general, we can choose among eight possibilities for layout atoms, corresponding to MN-major or K-major and one of four swizzle modes:
- No swizzle: No swizzling. Implicit 16-byte boundary.
- 32-byte swizzle: Swizzle 2 consecutive 16-byte segments.
- 64-byte swizzle: Swizzle 4 consecutive 16-byte segments.
- 128-byte swizzle: Swizzle 8 consecutive 16-byte segments.
GMMA::Layout_MN_INTER_Atom<T>
GMMA::Layout_MN_SW32_Atom<T>
GMMA::Layout_MN_SW64_Atom<T>
GMMA::Layout_MN_SW128_Atom<T>
GMMA::Layout_K_INTER_Atom<T>
GMMA::Layout_K_SW32_Atom<T>
GMMA::Layout_K_SW64_Atom<T>
GMMA::Layout_K_SW128_Atom<T>
These layout atoms must then be passed into tile_to_shape, with the shared memory (SMEM) shape of sA and sB given by make_shape(bM,bK,bP) or make_shape(bN,bK,bP), with the modes of the shape given in that order, such that the tile sizes of the layout atoms divide those of the larger SMEM shape. This is ultimately a constraint imposed on the SMEM shape by the choice of swizzle mode, and is separate from the other constraint imposed by the matrix multiply-accumulate (MMA) atom shape.
WGMMA fragments and descriptors
We created the TiledMMA object and prepared the shared memory (SMEM) layouts accordingly on the host. Now, on the device, we can use the TiledMMA object tiled_mma to construct the appropriate partitioned tensors to pass to the cute::gemm call. First, we create a ThrMMA object called thr_mma by invoking the get_thread_slice method on tiled_mma with the thread index, which in our case runs from 0 to 127.
Next, referring to the kernel code snippet above, printing the tensors tCsA and tCsB for any thread index shows the following:
tCsA: Sw<3,4,3>_smem_ptr[16b](0x7f8800000400) o
((_64,(_8,_2)),_2,_4,_3):((_1,(_64,_1024)),_512,_2048,_8192)
tCsB: Sw<3,4,3>_smem_ptr[16b](0x7f880000c400) o
((_64,(_8,_2)),_2,_4,_3):((_1,(_64,_1024)),_512,_2048,_8192)
According to the comments, the shape of tCsA should be thought of as (MMA, MMA_M, MMA_K, PIPE):
- MMA is the MxK shape of the MMA atom (64x16 here, printed as (_64,(_8,_2))).
- MMA_M and MMA_K are the extents of its tiling across the M and K modes of sA (so MMA_M = bM/64 = 2 and MMA_K = bK/16 = 4).
- PIPE is the number of stages.
Next, printing the fragments tCrA and tCrB for any thread index shows:
tCrA: GMMA::DescriptorIterator o (_1,_2,_4,_3):(_0,_64,_256,_1024)
tCrB: GMMA::DescriptorIterator o (_1,_2,_4,_3):(_0,_64,_256,_1024)
Internally, CUTLASS constructs a "matrix descriptor", which is a 64-bit value held in registers that describes shared memory (SMEM) in a way suitable for use by wgmma instructions. The most important thing for programmers to remember is that shared memory values are not copied into register memory (RMEM); instead, accessing the values of tCrA and tCrB actually accesses these 64-bit descriptors. Additionally, these tensors acting as "iterators" means that at any time, for a given wgmma instruction, only one 64-bit descriptor is held in registers (as opposed to all 24, for example).
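The numbers in these printouts can be cross-checked with a little arithmetic. The sketch below (plain C++, hypothetical helper names) counts the descriptors and relates the strides of tCrA to those of tCsA. The conversion rests on an observation of ours, not a statement from the text: the descriptor iterator strides (64, 256, 1024) match the element strides of tCsA (512, 2048, 8192) if they are measured in 16-byte units.

```cpp
// Total number of descriptors addressed by tCrA: one per (MMA_M, MMA_K, PIPE).
constexpr int num_descriptors(int mma_m, int mma_k, int pipe) {
  return mma_m * mma_k * pipe;
}

// Convert a stride measured in elements to one measured in 16-byte units
// (assumption: this is the unit used by GMMA::DescriptorIterator strides).
constexpr int elems_to_16B_units(int elem_stride, int elem_bytes) {
  return elem_stride * elem_bytes / 16;
}
```

For the FP16 example, num_descriptors(2, 4, 3) gives the 24 descriptors mentioned above, and the tCsA strides 512, 2048, and 8192 convert to 64, 256, and 1024, which are exactly the strides printed for tCrA.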
Compared to operands, accumulator tensors are defined in a more standard way. Printing tCgC and tCrC of thread 0 shows:
tCgC: gmem_ptr[16b](0x7f877a780000) o ((_2,_2,_8),_2,_2):((512,_8,4096),_64,32768)
tCrC: ptr[16b](0x7feee1fffbe0) o ((_2,_2,_8),_2,_2):((_1,_2,_4),_32,_64)
tCgC is the portion of the output GMEM tensor to which we want to copy the accumulator values in the epilogue, while tCrC is the register-backed tensor created to hold these values while they are computed in the main loop. The (MMA, MMA_M, MMA_N) shape of these tensors can be interpreted as follows: in the MxN=64x64 output tile of the MMA atom, each of the 128 threads holds 32=2*2*8 values, and MMA_M=MMA_N=2 is the same as for tCsA and tCsB.
Each thread holds its 32 values of the atom in a way that requires decomposing 32 into the shape (2,2,8) in order to define the corresponding strides for the layout of tCgC. The exact partitioning scheme can be read from the accumulator fragment figure in the PTX documentation.
Gemm call
Let's go back to the cute::gemm call in the kernel code snippet above:
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC);
The various overloads of the cute::gemm method first serve to loop over the outer modes MMA_M/N and MMA_K. Once these coordinates are chosen, we are computing with the tile shape of the MMA atom. In other words, we first reduce to the overload of cute::gemm that dispatches the shape (V)x(V)=>(V).
The code then calls the fma operation of the MMA atom (to be exact, inside mma_unpack). This contains the inline PTX assembly:
CUTE_HOST_DEVICE static void
fma(uint64_t const& desc_a,
    uint64_t const& desc_b,
    uint32_t& d00, uint32_t& d01, uint32_t& d02, uint32_t& d03,
    uint32_t& d04, uint32_t& d05, uint32_t& d06, uint32_t& d07,
    uint32_t& d08, uint32_t& d09, uint32_t& d10, uint32_t& d11,
    uint32_t& d12, uint32_t& d13, uint32_t& d14, uint32_t& d15,
    GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
{
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
  asm volatile(
    "{\n"
    ".reg .pred p;\n"
    "setp.ne.b32 p, %18, 0;\n"
    "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 "
    "{%0, %1, %2, %3, %4, %5, %6, %7, "
    "%8, %9, %10, %11, %12, %13, %14, %15},"
    " %16,"
    " %17,"
    " p, %19, %20, %21, %22;\n"
    "}\n"
    : "+r" (d00), "+r" (d01), "+r" (d02), "+r" (d03),
      "+r" (d04), "+r" (d05), "+r" (d06), "+r" (d07),
      "+r" (d08), "+r" (d09), "+r" (d10), "+r" (d11),
      "+r" (d12), "+r" (d13), "+r" (d14), "+r" (d15)
    : "l" (desc_a),
      "l" (desc_b),
      "r" (int32_t(scale_D)),
      "n" (int32_t(scaleA)),
      "n" (int32_t(scaleB)),
      "n" (int32_t(tnspA)),
      "n" (int32_t(tnspB)));
#else
  CUTE_INVALID_CONTROL_PATH(
    "Attempting to use SM90_64x64x16_F16F16F16_SS "
    "without CUTE_ARCH_MMA_SM90A_ENABLED" );
#endif
}
The PTX documentation for this syntax is here. Consistent with the description of the tensors tCrA, tCrB, and tCrC above, note that we have uint64 variables desc_a and desc_b for the operands, and 16 uint32 variables for the accumulators. scale_D is either 0 or 1, and controls whether the accumulators are zero-initialized.
In addition, the variables scaleA, scaleB, tnspA, and tnspB are determined at compile time through template parameters outside the fma method. The values of scaleA and scaleB are 1 or -1, which are used to negate the operands, while tnspA and tnspB indicate whether to transpose the operands. A value of 0 corresponds to GMMA::Major::K, and a value of 1 corresponds to GMMA::Major::MN.
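The count of 16 uint32 accumulator registers in the fma signature follows directly from the atom shape. Here is a short sketch of that arithmetic in plain C++; the helper names are hypothetical and the code only restates the numbers given in the text.

```cpp
// Each of the 128 threads in the warp group holds an equal share of the
// M x N accumulator tile of one wgmma atom.
constexpr int acc_values_per_thread(int M, int N) { return M * N / 128; }

// Accumulator values are packed into 32-bit registers, so two FP16 values
// share one register while one FP32 value fills a register by itself.
constexpr int acc_registers_per_thread(int M, int N, int acc_bytes) {
  return acc_values_per_thread(M, N) * acc_bytes / 4;
}
```

For the 64x64 atom with FP16 accumulation this gives 32 values packed into 16 registers per thread, matching d00 through d15 above. With FP32 accumulation the same atom would need 32 registers per thread.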
Synchronization of WGMMA
Next we need to explain the synchronization primitives around the cute::gemm call:
cute::warpgroup_arrive();
cute::gemm(tiled_mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC);
cute::warpgroup_commit_batch();
cute::warpgroup_wait<0>();
Why are these extra instructions necessary at all? It all has to do with the nature of wgmma as an asynchronous instruction. In the context of the Hopper architecture, "asynchronous" means that wgmma can run concurrently with other operations, so for steps that have dependencies, a synchronization mechanism is required. This mechanism is detailed in the PTX memory consistency model. Improper synchronization in your code can lead to (a) subtle race conditions that cause tricky bugs, (b) the compiler serializing the wgmma instructions, which can cause significant performance degradation, or (c) undefined behavior.
The cute methods wrap the following PTX instructions:
- cute::warpgroup_arrive() wraps wgmma.fence.sync.aligned;
- cute::warpgroup_commit_batch() wraps wgmma.commit_group.sync.aligned;
- cute::warpgroup_wait<N>() wraps wgmma.wait_group.sync.aligned N;
(Note that we have been using wgmma as a shorthand for wgmma.mma_async, but in this section only we distinguish between the two.) Let's relate the use of these instructions to the following description of WGMMA-based GEMM, taken from the PTX documentation:
1. Load matrices A, B, and D into registers or into shared memory.
2. Perform the following fence operations: a wgmma.fence operation to indicate that the registers/shared memory across the warp group have been written, and a fence.proxy.async operation to make the generic proxy operations visible to the async proxy.
3. Issue the asynchronous matrix multiply and accumulate operations on the input matrices using wgmma.mma_async. The wgmma.mma_async operation is performed in the async proxy.
4. Create a wgmma-group and commit all the prior outstanding wgmma.mma_async operations into the group, using the wgmma.commit_group operation.
5. Wait for the completion of the required wgmma-group using wgmma.wait_group.
6. Once the wgmma-group completes, all of its wgmma.mma_async operations have been performed and completed.
We explain these points in order. First, the wgmma.fence instruction ensures that wgmma.mma_async accesses certain register memory (RMEM) addresses only after all previous accesses to those addresses have completed. Without the wgmma.fence, the behavior is undefined. An exception to this rule is that the Hopper architecture allows multiple wgmma.mma_async instructions to be in flight simultaneously. As long as these wgmma.mma_async instructions have the same accumulator shape, they can share the same accumulator tensor, that is, write to the same register memory addresses. In this case, no fence is required. For example, we don't need to insert a wgmma.fence inside the loop over MMA_K performed as part of the cute::gemm call.
As with Tensor Memory Accelerator (TMA) operations, wgmma.mma_async is executed in the async proxy. Therefore, if operations performed in the generic proxy affect the shared memory (SMEM) read by wgmma.mma_async, we need to issue fence.proxy.async. For example, this would be the case if we copied matrices A and B into shared memory via ordinary ld.global / st.shared operations. Since we use TMA loads, we don't need fence.proxy.async in our example, and indeed, it does not appear in the WGMMA tutorial code or in the main loop of the CUTLASS Hopper GEMM kernels. (To verify this, note that fence.proxy.async is wrapped by cutlass::arch::fence_view_async_shared().)
The wgmma.commit_group instruction creates a new wgmma-group per warp group and batches into it all prior wgmma.mma_async instructions that were initiated by the executing warp group but not yet committed to any wgmma-group. In our example, cute::warpgroup_commit_batch() batches MMA_M * MMA_N * MMA_K wgmma.mma_async instructions into one wgmma-group.
Finally, the wgmma.wait_group instruction with argument N makes the executing thread wait until N or fewer of the most recent wgmma-groups are pending and all prior wgmma-groups committed by the executing threads are complete. In our example, we set N to 0, so the warp group simply waits for the entire wgmma-group to complete before continuing to execute any subsequent instructions.
The flexibility of the parameter N comes in handy in cases where the warp group has the opportunity to perform independent computation. For example, the GEMM-softmax overlapping strategy adopted in the design of FlashAttention-3 takes advantage of this.
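The bookkeeping behind commit_group and wait_group can be summarized with two small formulas. The sketch below is a toy host-side model in plain C++, not device code; the function names are our own, and it captures only the counting described above, not the actual hardware behavior.

```cpp
// Number of wgmma.mma_async instructions that one cute::gemm call issues,
// and that a following commit therefore batches into a single wgmma-group.
constexpr int mmas_per_group(int mma_m, int mma_n, int mma_k) {
  return mma_m * mma_n * mma_k;
}

// Toy model of wgmma.wait_group N: given `pending` committed but unfinished
// wgmma-groups, return how many of the oldest groups must complete before
// the wait returns, so that at most N groups remain pending.
constexpr int groups_to_retire(int pending, int n) {
  return pending > n ? pending - n : 0;
}
```

In our example, one commit batches 2 * 2 * 4 = 16 instructions, and warpgroup_wait<0> with one pending group waits for that group. With warpgroup_wait<1> and two pending groups, only the older group must finish, which leaves the newer one free to overlap with independent work.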
WGMMA Core Matrices
This last section discusses in more detail the layout requirements for the tiles of matrices A and B loaded into shared memory (SMEM), assuming that both operands of wgmma come from shared memory. To simplify the discussion, first assume that A is row-major and B is column-major (that is, both are K-major). Also recall that the tile shape MxNxK of the wgmma instruction is constrained so that M is 64, K times the data type size is 32 bytes, and N is a multiple of 8 ranging from 8 to 256. To avoid confusion with A/B or sA/sB, we denote the WGMMA atom tiles as wA and wB.
Matrices wA and wB are divided into many smaller matrices, called core matrices. Each core matrix has a strided direction and a contiguous direction, with a length of 8 in the strided direction and a length of 16 bytes in the contiguous direction. Matrix wA consists of 8x2 core matrices, and matrix wB consists of 2x(N/8) core matrices. The PTX documentation includes a figure showing the tiling of wA and wB by core matrices.
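The core matrix counts quoted here follow from dividing the atom tile by the 8 x 16-byte core matrix. A minimal sketch of that arithmetic in plain C++ follows; the struct and function names are hypothetical.

```cpp
// Number of core matrices along the strided and contiguous directions.
struct CoreGrid {
  int strided;
  int contiguous;
};

// A core matrix spans 8 rows in the strided direction and 16 bytes in the
// contiguous direction. For K-major operands, the strided extent is M (for
// wA) or N (for wB), and the contiguous extent is K in bytes, which is 32.
constexpr CoreGrid core_matrices(int strided_extent, int contiguous_bytes) {
  return CoreGrid{strided_extent / 8, contiguous_bytes / 16};
}

// Elements per 16-byte core matrix row; this is the T used in the layouts
// listed later in this section.
constexpr int elems_per_core_row(int elem_bytes) { return 16 / elem_bytes; }
```

For wA (M = 64, 32 bytes along K) this gives the 8x2 grid stated above, and for wB with N = 128 it gives 16 core matrices along N by 2 along K. For FP16, each core matrix row holds T = 8 elements.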
As mentioned above, wgmma in SS mode (both operands sourced from shared memory) requires matrix descriptors, namely the descriptor of wA (desc-a) and the descriptor of wB (desc-b), as input. Each descriptor encodes five parameters:
- Start address: The starting base address of the operand in shared memory (SMEM).
- Leading dimension byte offset (LBO): The byte distance between two adjacent core matrices in the K dimension.
- Stride dimension byte offset (SBO): The byte distance between two adjacent core matrices in the M or N dimension.
- Swizzle mode: None, 32 bytes, 64 bytes, or 128 bytes.
- Matrix base offset: This offset resolves shared memory alignment issues when the shared memory address is not aligned to the byte boundary of the repeating pattern of the swizzle mode.
The LBO and SBO are indicated in the PTX documentation's figure of the core matrix tiling.
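To illustrate how five parameters fit into one 64-bit value, here is a sketch of a packing function in plain C++. The bit positions (start address in bits 0 to 13, LBO in bits 16 to 29, SBO in bits 32 to 45, base offset in bits 49 to 51, swizzle mode in bits 62 to 63, with addresses and offsets stored without their 4 least significant bits) and the mode encoding are our reading of the PTX documentation and the CUTLASS GmmaDescriptor bitfield. Treat them as assumptions to verify against those sources; in real code, let make_gmma_desc build the descriptor. All names below are hypothetical.

```cpp
#include <cstdint>

// Assumed encoding of the swizzle mode field.
enum class SwizzleMode : uint64_t { None = 0, B128 = 1, B64 = 2, B32 = 3 };

// Addresses and byte offsets are stored in 14 bits without their 4 LSBs.
constexpr uint64_t desc_encode(uint64_t bytes) {
  return (bytes & 0x3FFFF) >> 4;
}

// Pack the five descriptor parameters into a 64-bit value.
constexpr uint64_t make_desc(uint32_t smem_addr, uint32_t lbo_bytes,
                             uint32_t sbo_bytes, uint32_t base_offset,
                             SwizzleMode mode) {
  return desc_encode(smem_addr)                       // bits [0, 14)
       | (desc_encode(lbo_bytes) << 16)               // bits [16, 30)
       | (desc_encode(sbo_bytes) << 32)               // bits [32, 46)
       | (static_cast<uint64_t>(base_offset & 0x7) << 49)  // bits [49, 52)
       | (static_cast<uint64_t>(mode) << 62);         // bits [62, 64)
}
```

For example, make_desc(0x400, 16, 1024, 0, SwizzleMode::B128) stores 0x40 as the encoded start address, 1 as the encoded LBO, and 64 as the encoded SBO. The input values here are arbitrary and chosen only to exercise each field.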
The CUTLASS make_gmma_desc method constructs a descriptor (as an instance of GmmaDescriptor) based on the layout of the shared memory (SMEM) tensor provided as input. As long as the layout of the input tensor was created using one of the eight canonical GMMA layout atoms and tile_to_shape (as described in detail previously in "Shared Memory Layout Constraints for WGMMA"), make_gmma_desc accurately calculates the LBO and SBO, determines the swizzle mode, and constructs the descriptor. For example, GmmaDescriptor describes the following admissible WGMMA layouts in the K-major case (where T*sizeof(dtype)=16):
No swizzle: Swizzle<0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO))
32-byte swizzle : Swizzle<1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T ))
64-byte swizzle : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T ))
128-byte swizzle : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T ))
Most notably, for the 64-byte and 128-byte swizzle modes, the strides are such that the given admissible WGMMA layouts are not compact layouts. Instead, there are sets of 2 or 4 WGMMA atom operand tiles stacked side by side in the K direction, resulting in strides of 4T and 8T in the M mode of the core matrices. In other words, when swizzling, the 2, 4, or 8 core matrices that are logically adjacent in the K mode are interleaved in memory, and for the 64-byte and 128-byte swizzle modes these core matrices belong to different WGMMA atoms.
For completeness, we also give the admissible WGMMA layouts in the MN-major case:
No swizzle: Swizzle<0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO))
32-byte swizzle : Swizzle<1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO))
64-byte swizzle : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO))
128-byte swizzle : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO))
Conclusion
In [Part 1] of this GEMM series, we explored the core concepts involved in using warp group matrix multiply-accumulate (WGMMA) as the basic operation in GEMM on the Hopper architecture.
WGMMA requires a warp group of 128 threads to perform matrix multiplication cooperatively, and can only operate on specific fragments of the matrices. We delved into the special shapes and layouts involved, focusing on how to use the canonical GMMA layout => tile_to_shape pattern to construct operand layouts that are guaranteed to be accepted by WGMMA.
To ensure well-defined behavior, WGMMA also requires specific synchronization mechanisms. To this end, we explained the use of wgmma.fence, fence.proxy.async, wgmma.commit_group, and wgmma.wait_group in relation to wgmma.mma_async.
Finally, we explained in detail the inner workings of WGMMA core matrices and how CUTLASS builds matrix descriptors for operands sourced from shared memory (SMEM).
Overall, this blog post should enable programmers to write CUTLASS kernels that use WGMMA on the Hopper architecture. In [Part 2], we will expand the discussion to introduce the Tensor Memory Accelerator (TMA) and how to use TMA and WGMMA together in a Hopper GEMM kernel to overlap data copying with computation.