Tensor Cores
Tensor Cores are specialized hardware accelerators designed to optimize matrix operations, crucial for deep learning and AI algorithms.
Enabling Tensor Cores can be as simple as modifying a single line in the matmul_kernel function:
mma = make_tiled_mma(MMAOP_8x8x4_F32F16F16F32_NT(),
                     atom_layout,
                     tiler)
The NT in MMAOP_8x8x4_F32F16F16F32_NT indicates that matrix A is in M-major order and matrix B is in N-major order.
Let's explore a minimal example:
mma = make_tiled_mma(MMAOP_16x8x8_F32TF32TF32F32_TN())
print_typst(mma)
At first glance, the diagram may seem complex, but the concept is straightforward. Threads collectively load data from matrices A and B according to the specified layout. During the matrix multiply-accumulate (MMA) computation, data is internally shared among threads, a process that is not transparent to the user. Once the computation is complete, each thread stores the results as dictated by the layout of matrix C.
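Each of these warp-level MMA operations is performed cooperatively by the 32 threads of a warp, which you can confirm directly from the tiled MMA (a quick sketch, assuming MoYe is loaded):

mma = make_tiled_mma(MMAOP_16x8x8_F32TF32TF32F32_TN())
Int(size(mma))  # 32: one warp cooperates on each 16x8x8 MMA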
function matmul_kernel(A, sA_layout, copy_A,
                       B, sB_layout, copy_B,
                       C, mma)
    sA = MoYeSharedArray(eltype(A), sA_layout)
    sB = MoYeSharedArray(eltype(B), sB_layout)

    mA = MoYeArray(A)
    mB = MoYeArray(B)
    mC = MoYeArray(C)

    bM = size(sA_layout, 1)
    bN = size(sB_layout, 1)
    bK = size(sB_layout, 2)

    gA = @tile mA (bM, bK) (blockIdx().x, :)
    gB = @tile mB (bN, bK) (blockIdx().y, :)
    gC = @tile mC (bM, bN) (blockIdx().x, blockIdx().y)

    # Copy partition
    thr_copy_a = get_slice(copy_A, threadIdx().x)
    tAgA = partition_S(thr_copy_a, gA)  # (CPY, CPY_M, CPY_K, k)
    tAsA = partition_D(thr_copy_a, sA)  # (CPY, CPY_M, CPY_K)

    thr_copy_b = get_slice(copy_B, threadIdx().x)
    tBgB = partition_S(thr_copy_b, gB)  # (CPY, CPY_N, CPY_K, k)
    tBsB = partition_D(thr_copy_b, sB)  # (CPY, CPY_N, CPY_K)

    # MMA partition
    thr_mma = get_slice(mma, threadIdx().x)
    tCsA = partition_A(thr_mma, sA)  # (MMA, MMA_M, MMA_K)
    tCsB = partition_B(thr_mma, sB)  # (MMA, MMA_N, MMA_K)
    tCgC = partition_C(thr_mma, gC)  # (MMA, MMA_M, MMA_N)

    tCrA = make_fragment_A(thr_mma, tCsA)  # (MMA, MMA_M, MMA_K)
    tCrB = make_fragment_B(thr_mma, tCsB)  # (MMA, MMA_N, MMA_K)
    tCrC = make_fragment_C(thr_mma, tCgC)  # (MMA, MMA_M, MMA_N)
    zeros!(tCrC)

    # Copy from global to shared memory
    copyto!(copy_A, tAsA, view(tAgA, :, :, :, _1))
    copyto!(copy_B, tBsB, view(tBgB, :, :, :, _1))

    cp_async_wait()

    # Copy from shared memory to registers
    copyto!(tCrA, tCsA)
    copyto!(tCrB, tCsB)

    @gc_preserve gemm!(mma, tCrC, tCrA, tCrB, tCrC)
    copyto!(tCgC, tCrC)

    @inbounds tCrC[1]  # Compiler bug: must load after copyto!
    return nothing
end
function matmul(A, B, C)
    bM = _16
    bN = _8
    bK = _8

    sA_layout = make_layout((bM, bK), (_1, bM))
    sB_layout = make_layout((bN, bK), (bK, _1))

    TA = eltype(A)
    TB = eltype(B)
    TC = eltype(C)

    copy_A = make_tiled_copy(CopyAtom{CPOP_ASYNC_CACHEALWAYS{UInt128}, TA}(),
                             @Layout((4, 8)),
                             @Layout((4, 1)))
    copy_B = make_tiled_copy(CopyAtom{CPOP_ASYNC_CACHEALWAYS{UInt64}, TB}(),
                             @Layout((8, 4), (4, 1)),
                             @Layout((1, 2)))
    mma = make_tiled_mma(MMAOP_16x8x8_F32TF32TF32F32_TN())

    threads = Int(size(mma))
    blocks = (cld(size(A, 1), bM), cld(size(B, 1), bN))

    @cuda threads=threads blocks=blocks matmul_kernel(A, sA_layout, copy_A,
                                                      B, sB_layout, copy_B,
                                                      C, mma)
end
function test()
    A = CuArray(reshape(collect(1:16*8) .* 1f0, (16, 8)))
    B = CuArray(reshape(collect(1:8*8) .* 1f0, (8, 8)))
    C = CuArray(ones(Float32, (16, 8)))
    matmul(A, B', C)
    CUDA.synchronize()
    @test C == A * B
    CUDA.unsafe_free!(A)
    CUDA.unsafe_free!(B)
    CUDA.unsafe_free!(C)
end
LDMatrix
The ldmatrix instruction loads data from shared memory into registers and shuffles it to align with a Tensor Core MMA operation. Given a Tensor Core MMA operation, this shuffling can be "inverted" to obtain a TiledCopy for the operation.
mma = make_tiled_mma(MMAOP_16x8x8_F32TF32TF32F32_TN())
smem_copy_A = make_tiled_copy_A(CopyAtom{LDSM_U32x4_N, Float32}(), mma)
print_typst(smem_copy_A)
The resulting layout on the right-hand side matches the layout of A in the mma.
- The TN in MMAOP_16x8x8_F32TF32TF32F32_TN means that both A and B are in K-major order.
- The N in LDSM_U32x4_N means the source array is in K-major order. ldmatrix requires four consecutive threads to load 16 consecutive bytes, meaning the layout of A in shared memory must meet this specification.
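As an illustration, here is a minimal sketch of a K-major shared-memory layout for A that satisfies this requirement (the 16×8 tile shape is only an example, not the one used in the kernels below):

# Hypothetical (M, K) = (16, 8) tile of Float32, K-major: stride 1 along K
sA_layout_kmajor = @Layout (16, 8) (8, 1)
# Four consecutive K elements are contiguous in memory:
# 4 * sizeof(Float32) = 16 bytes, the 16 consecutive bytes that four consecutive threads must load.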
For B:
smem_copy_B = make_tiled_copy_B(CopyAtom{LDSM_U32x2_N, Float32}(), mma)
print_typst(smem_copy_B)
We then use smem_copy_A and smem_copy_B to re-tile the shared memory and registers:
smem_thr_copy_A = get_slice(smem_copy_A, threadIdx().x)
smem_thr_copy_B = get_slice(smem_copy_B, threadIdx().x)
tCsA_retiled = partition_S(smem_thr_copy_A, sA)
tCsB_retiled = partition_S(smem_thr_copy_B, sB)
tCrA_retiled = retile_D(smem_thr_copy_A, tCrA)
tCrB_retiled = retile_D(smem_thr_copy_B, tCrB)
Complete Code
function matmul_kernel(A, sA_layout, gmem_copy_A, smem_copy_A,
                       B, sB_layout, gmem_copy_B, smem_copy_B,
                       C, mma)
    sA = MoYeSharedArray(eltype(A), sA_layout)
    sB = MoYeSharedArray(eltype(B), sB_layout)

    mA = MoYeArray(A)
    mB = MoYeArray(B)
    mC = MoYeArray(C)

    bM = size(sA_layout, 1)
    bN = size(sB_layout, 1)
    bK = size(sB_layout, 2)

    gA = @tile mA (bM, bK) (blockIdx().x, :)
    gB = @tile mB (bN, bK) (blockIdx().y, :)
    gC = @tile mC (bM, bN) (blockIdx().x, blockIdx().y)

    # Gmem copy partition
    gmem_thr_copy_a = get_slice(gmem_copy_A, threadIdx().x)
    tAgA = partition_S(gmem_thr_copy_a, gA)  # (CPY, CPY_M, CPY_K, k)
    tAsA = partition_D(gmem_thr_copy_a, sA)  # (CPY, CPY_M, CPY_K)

    gmem_thr_copy_b = get_slice(gmem_copy_B, threadIdx().x)
    tBgB = partition_S(gmem_thr_copy_b, gB)  # (CPY, CPY_N, CPY_K, k)
    tBsB = partition_D(gmem_thr_copy_b, sB)  # (CPY, CPY_N, CPY_K)

    # Copy from global to shared memory
    copyto!(gmem_copy_A, tAsA, view(tAgA, :, :, :, _1))
    copyto!(gmem_copy_B, tBsB, view(tBgB, :, :, :, _1))

    # MMA partition
    thr_mma = get_slice(mma, threadIdx().x)
    tCsA = partition_A(thr_mma, sA)  # (MMA, MMA_M, MMA_K)
    tCsB = partition_B(thr_mma, sB)  # (MMA, MMA_N, MMA_K)
    tCgC = partition_C(thr_mma, gC)  # (MMA, MMA_M, MMA_N)

    tCrA = make_fragment_A(thr_mma, tCsA)  # (MMA, MMA_M, MMA_K)
    tCrB = make_fragment_B(thr_mma, tCsB)  # (MMA, MMA_N, MMA_K)
    tCrC = make_fragment_C(thr_mma, tCgC)  # (MMA, MMA_M, MMA_N)
    zeros!(tCrC)

    # Retile
    smem_thr_copy_A = get_slice(smem_copy_A, threadIdx().x)
    smem_thr_copy_B = get_slice(smem_copy_B, threadIdx().x)

    tCsA_retiled = partition_S(smem_thr_copy_A, sA)
    tCsB_retiled = partition_S(smem_thr_copy_B, sB)

    tCrA_retiled = retile_D(smem_thr_copy_A, tCrA)
    tCrB_retiled = retile_D(smem_thr_copy_B, tCrB)

    cp_async_wait()

    # Copy from shared memory to registers
    copyto!(smem_copy_A, tCrA_retiled, tCsA_retiled)
    copyto!(smem_copy_B, tCrB_retiled, tCsB_retiled)

    @gc_preserve gemm!(mma, tCrC, tCrA, tCrB, tCrC)
    copyto!(tCgC, tCrC)

    @inbounds tCrC[1]  # Compiler bug: must load after copyto!
    return nothing
end
function matmul(A, B, C)
    bM = _16
    bN = _8
    bK = _8

    sA_layout = make_layout((bM, bK), (_1, bM))
    sB_layout = make_layout((bN, bK), (bK, _1))

    TA = eltype(A)
    TB = eltype(B)
    TC = eltype(C)

    gmem_copy_A = make_tiled_copy(CopyAtom{CPOP_ASYNC_CACHEALWAYS{UInt128}, TA}(),
                                  @Layout((4, 8)),
                                  @Layout((4, 1)))
    gmem_copy_B = make_tiled_copy(CopyAtom{CPOP_ASYNC_CACHEALWAYS{UInt64}, TB}(),
                                  @Layout((8, 4), (4, 1)),
                                  @Layout((1, 2)))

    mma = make_tiled_mma(MMAOP_16x8x8_F32TF32TF32F32_TN())

    # Note: A is M-major, so we can only use `UniversalCopy`
    smem_copy_A = make_tiled_copy_A(CopyAtom{UniversalCopy{TA}, TA}(), mma)
    smem_copy_B = make_tiled_copy_B(CopyAtom{LDSM_U32x2_N, TB}(), mma)

    threads = Int(size(mma))
    blocks = (cld(size(A, 1), bM), cld(size(B, 1), bN))

    @cuda threads=threads blocks=blocks matmul_kernel(A, sA_layout, gmem_copy_A, smem_copy_A,
                                                      B, sB_layout, gmem_copy_B, smem_copy_B,
                                                      C, mma)
end
Double Buffering
Double buffering keeps two shared-memory stages per input: while the Tensor Cores compute on the stage that was loaded previously, asynchronous copies prefetch the next K tile into the other stage, and the read and write stages are swapped on each iteration of the main loop.
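The indexing pattern at the heart of the kernel below can be sketched in plain Julia (the prefetch/compute placeholders are hypothetical, standing in for the cp.async copies and gemm! calls):

# Ping-pong over two staging buffers: prefetch tile k+1 while computing on tile k.
function ping_pong_sketch(num_tiles)
    smem_read, smem_write = 1, 2       # stage we compute from / stage we prefetch into
    # prefetch!(1, smem_read)          # hypothetical: fill the first stage up front
    for k_tile in 1:num_tiles
        if k_tile < num_tiles
            # prefetch!(k_tile + 1, smem_write)  # hypothetical async copy into the idle stage
        end
        # compute!(smem_read)          # hypothetical MMA work on the already-loaded stage
        smem_read, smem_write = smem_write, smem_read  # swap stages for the next iteration
    end
end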
@views function matmul_kernel(A, sA_layout, gmem_copy_A, smem_copy_A,
                              B, sB_layout, gmem_copy_B, smem_copy_B,
                              C, mma)
    sA = MoYeSharedArray(eltype(A), sA_layout)  # (bM, bK, 2)
    sB = MoYeSharedArray(eltype(B), sB_layout)  # (bN, bK, 2)

    mA = MoYeArray(A)
    mB = MoYeArray(B)
    mC = MoYeArray(C)

    bM = size(sA_layout, 1)
    bN = size(sB_layout, 1)
    bK = size(sB_layout, 2)

    gA = @tile mA (bM, bK) (blockIdx().x, :)
    gB = @tile mB (bN, bK) (blockIdx().y, :)
    gC = @tile mC (bM, bN) (blockIdx().x, blockIdx().y)

    # Gmem copy partition
    gmem_thr_copy_A = get_slice(gmem_copy_A, threadIdx().x)
    tAgA = partition_S(gmem_thr_copy_A, gA)  # (CPY, CPY_M, CPY_K, k)
    tAsA = partition_D(gmem_thr_copy_A, sA)  # (CPY, CPY_M, CPY_K, 2)

    gmem_thr_copy_B = get_slice(gmem_copy_B, threadIdx().x)
    tBgB = partition_S(gmem_thr_copy_B, gB)  # (CPY, CPY_N, CPY_K, k)
    tBsB = partition_D(gmem_thr_copy_B, sB)  # (CPY, CPY_N, CPY_K, 2)

    # Copy gmem to smem for k_tile=1
    copyto!(gmem_copy_A, tAsA[:, :, :, _1], tAgA[:, :, :, _1])
    copyto!(gmem_copy_B, tBsB[:, :, :, _1], tBgB[:, :, :, _1])

    # MMA partition
    thr_mma = get_slice(mma, threadIdx().x)
    tCsA = partition_A(thr_mma, sA)  # (MMA, MMA_M, MMA_K, 2)
    tCsB = partition_B(thr_mma, sB)  # (MMA, MMA_N, MMA_K, 2)
    tCgC = partition_C(thr_mma, gC)  # (MMA, MMA_M, MMA_N)

    tCrA = make_fragment_A(thr_mma, tCsA[:, :, :, _1])  # (MMA, MMA_M, MMA_K)
    tCrB = make_fragment_B(thr_mma, tCsB[:, :, :, _1])  # (MMA, MMA_N, MMA_K)
    tCrC = make_fragment_C(thr_mma, tCgC)               # (MMA, MMA_M, MMA_N)
    zeros!(tCrC)

    # Retile
    smem_thr_copy_A = get_slice(smem_copy_A, threadIdx().x)
    smem_thr_copy_B = get_slice(smem_copy_B, threadIdx().x)

    tCsA_retiled = partition_S(smem_thr_copy_A, sA)  # (MMA, MMA_M, MMA_K, 2)
    tCsB_retiled = partition_S(smem_thr_copy_B, sB)  # (MMA, MMA_N, MMA_K, 2)

    tCrA_retiled = retile_D(smem_thr_copy_A, tCrA)   # (MMA, MMA_M, MMA_K)
    tCrB_retiled = retile_D(smem_thr_copy_B, tCrB)   # (MMA, MMA_N, MMA_K)

    cp_async_wait()
    sync_threads()

    # Copy smem to rmem for k_block=1
    smem_read = 1
    smem_write = 2

    tCsA_p = view(tCsA_retiled, :, :, :, smem_read)
    tCsB_p = view(tCsB_retiled, :, :, :, smem_read)

    copyto!(smem_copy_A, tCrA_retiled[:, :, _1], tCsA_p[:, :, _1])
    copyto!(smem_copy_B, tCrB_retiled[:, :, _1], tCsB_p[:, :, _1])

    k_tile_max = size(tAgA, 4)
    k_block_max = static_size(tCrA, 3)

    for k_tile in 1:k_tile_max
        @loopinfo unroll for k_block in _1:k_block_max
            k_block_next = k_block + 1
            if k_block == k_block_max
                cp_async_wait()
                sync_threads()
                tCsA_p = view(tCsA_retiled, :, :, :, smem_read)
                tCsB_p = view(tCsB_retiled, :, :, :, smem_read)
                k_block_next = 1
            end

            copyto!(smem_copy_A, tCrA_retiled[:, :, k_block_next], tCsA_p[:, :, k_block_next])
            copyto!(smem_copy_B, tCrB_retiled[:, :, k_block_next], tCsB_p[:, :, k_block_next])

            if k_block == _1 && k_tile < k_tile_max
                copyto!(gmem_copy_A, tAsA[:, :, :, smem_write], tAgA[:, :, :, k_tile + 1])
                copyto!(gmem_copy_B, tBsB[:, :, :, smem_write], tBgB[:, :, :, k_tile + 1])
                smem_read, smem_write = smem_write, smem_read
            end

            @gc_preserve gemm!(mma, tCrC, tCrA[:, :, k_block], tCrB[:, :, k_block], tCrC)
        end
    end

    copyto!(tCgC, tCrC)
    sync_threads()
    return nothing
end
function matmul(A, B, C)
    bM = _128
    bN = _128
    bK = _16

    TA = eltype(A)
    TB = eltype(B)
    TC = eltype(C)

    mma = make_tiled_mma(MMAOP_16x8x8_F32TF32TF32F32_TN(),
                         @Layout((2, 2, 1), (2, 1, 1)),
                         (_32, _32, _8))

    gmem_copy_A = make_tiled_copy(CopyAtom{CPOP_ASYNC_CACHEALWAYS{UInt128}, TA}(),
                                  @Layout((16, 8)),
                                  @Layout((4, 1)))
    gmem_copy_B = make_tiled_copy(CopyAtom{CPOP_ASYNC_CACHEALWAYS{UInt128}, TB}(),
                                  @Layout((32, 4), (4, 1)),
                                  @Layout((1, 4)))

    # A is M-major, so we cannot use LDSM_U32x4_N
    smem_copy_A = make_tiled_copy_A(CopyAtom{UniversalCopy{TA}, TA}(), mma)
    smem_copy_B = make_tiled_copy_B(CopyAtom{LDSM_U32x4_N, TB}(), mma)

    sA_layout = @Layout (128, 16, 2) (1, 128, 2048)
    sB_layout = @Layout (128, 16, 2) (16, 1, 2048)

    threads = Int(size(mma))
    blocks = (cld(size(A, 1), bM), cld(size(B, 1), bN))

    @cuda threads=threads blocks=blocks matmul_kernel(A, sA_layout, gmem_copy_A, smem_copy_A,
                                                      B, sB_layout, gmem_copy_B, smem_copy_B,
                                                      C, mma)
end
function test()
    A = CUDA.randn(Float32, 2048, 256)   # M-major
    B = CUDA.randn(Float32, 256, 2048)   # K-major
    C = CUDA.randn(Float32, 2048, 2048)
    matmul(A, B', C)
    CUDA.synchronize()
    @test C == A * B
    CUDA.unsafe_free!(A)
    CUDA.unsafe_free!(B)
    CUDA.unsafe_free!(C)
end
test()