Tiled Matmul

While @tile and @parallelize are powerful tools for data manipulation, they can be cumbersome. TiledCopy and TiledMMA simplify this process.

Tiled Copy

TiledCopy streamlines data transfer between arrays. Consider an example where six threads copy a 4x9 src array to a dst array of the same shape. The mapping of logical coordinates to thread IDs is as follows:

1 1 1 2 2 2 3 3 3
1 1 1 2 2 2 3 3 3
4 4 4 5 5 5 6 6 6
4 4 4 5 5 5 6 6 6

Each thread is assigned a data segment defined by val_layout (2,3):(1,2), while the thread group operates within thr_layout (2,3):(3,1).
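
The mapping table can be reproduced without MoYe. The following is a minimal plain-Julia sketch, assuming (as the table shows) that each thread owns one contiguous 2×3 block and that thr_layout's strides turn the block coordinates into a thread ID. The names owner_id and owner are illustrative and are not part of MoYe.

```julia
thr_shape, thr_stride = (2, 3), (3, 1)   # thr_layout (2,3):(3,1)
val_shape = (2, 3)                        # shape of val_layout
tile_shape = thr_shape .* val_shape       # (4, 9), the shape of src and dst

# 1-based thread ID owning the element at 0-based coordinates (m, n):
# the block coordinates are (m ÷ 2, n ÷ 3), weighted by the thread strides.
owner_id(m, n) = (m ÷ val_shape[1]) * thr_stride[1] + (n ÷ val_shape[2]) * thr_stride[2] + 1

owner = [owner_id(m, n) for m in 0:tile_shape[1]-1, n in 0:tile_shape[2]-1]
display(owner)
```

Because the stride of the first thread mode is 3, thread IDs increase along a row of blocks first, which is why the second row of blocks starts at thread 4.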

First, initialize the arrays:

julia> using MoYe
julia> src_buffer = collect(1:36) .* 0.1;
julia> src = MoYeArray(src_buffer, @Layout((4,9)))
4×9 MoYeArray{Float64, 2, ViewEngine{Float64, Ptr{Float64}}, StaticLayout{2, Tuple{Static.StaticInt{4}, Static.StaticInt{9}}, Tuple{Static.StaticInt{1}, Static.StaticInt{4}}}} with indices _1:_4×_1:_9:
 0.1  0.5  0.9  1.3  1.7  2.1  2.5  2.9  3.3
 0.2  0.6  1.0  1.4  1.8  2.2  2.6  3.0  3.4
 0.3  0.7  1.1  1.5  1.9  2.3  2.7  3.1  3.5
 0.4  0.8  1.2  1.6  2.0  2.4  2.8  3.2  3.6
julia> dst_buffer = zeros(36);
julia> dst = MoYeArray(dst_buffer, make_layout((_4,_9)));

Next, set up the TiledCopy:

julia> thr_layout = @Layout (2, 3) (3, 1)
(_2, _3):(_3, _1)
julia> val_layout = @Layout (2, 3) (1, 2)
(_2, _3):(_1, _2)
julia> tiled_copy = make_tiled_copy(
           CopyAtom{UniversalCopy{Float64}, Float64}(),
           thr_layout,
           val_layout)
TiledCopy
  Tiler_MN:       (_4, _9)
  TiledLayout_TV: ((_3, _2), (_2, _3)):((_12, _2), (_1, _4))
CopyAtom
  Thread ID:    _1:_0
  ValLayoutSrc: (_1, _1):(_0, _1)
  ValLayoutDst: (_1, _1):(_0, _1)
  ValLayoutRef: (_1, _1):(_0, _1)
  ValueType:    64b
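
The printed TiledLayout_TV can be read as a map from (thread, value) to a memory offset in the 4×9 tile. The sketch below evaluates it by hand in plain Julia. It assumes the usual layout convention that a linear index is split across sub-modes with the first sub-mode varying fastest, and that thread IDs are 1-based. The helper names mode_offset and thread_indices are hypothetical and are not MoYe functions.

```julia
# Printed layout: ((_3, _2), (_2, _3)):((_12, _2), (_1, _4))
thr_shape, thr_stride = (3, 2), (12, 2)   # thread mode
val_shape, val_stride = (2, 3), (1, 4)    # value mode

# Offset contributed by a 0-based linear index within one two-level mode.
function mode_offset(idx, shape, stride)
    i0, i1 = idx % shape[1], idx ÷ shape[1]
    return i0 * stride[1] + i1 * stride[2]
end

# 1-based linear indices (column-major 4×9 array) owned by 1-based thread `tid`.
function thread_indices(tid)
    base = mode_offset(tid - 1, thr_shape, thr_stride)
    return [base + mode_offset(v, val_shape, val_stride) + 1 for v in 0:prod(val_shape)-1]
end

println(thread_indices(2))
```

Under these assumptions thread 2 owns linear indices 13, 14, 17, 18, 21, and 22, which are rows 1 to 2 and columns 4 to 6 of the 4×9 array, in agreement with the mapping table.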

The Float64 in CopyAtom specifies the data type. UniversalCopy{Float64} indicates a non-vectorized copy. For vectorized copies, use a type like UInt128:

tiled_copy_vec = make_tiled_copy(
    CopyAtom{UniversalCopy{UInt128}, Float64}(),
    thr_layout,
    val_layout)

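Note that for vectorized copies, the number of elements in val_layout must be divisible by the number of elements moved per access. Here a UInt128 holds two Float64 values and val_layout has six elements, so the requirement is met.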
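
The divisibility requirement comes down to simple size arithmetic. The following plain-Julia sketch uses illustrative variable names and does not call MoYe:

```julia
elem_type   = Float64    # data type held in the arrays
vector_type = UInt128    # type parameter of UniversalCopy

# Number of elements moved by one vectorized access: 16 bytes ÷ 8 bytes.
elems_per_access = sizeof(vector_type) ÷ sizeof(elem_type)

# Number of values each thread owns, from the shape of val_layout (2, 3).
vals_per_thread = prod((2, 3))

# The per-thread values must split evenly into vectorized accesses.
@assert vals_per_thread % elems_per_access == 0
println("accesses per thread: ", vals_per_thread ÷ elems_per_access)
```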

Visualize the tiled_copy using print_typst(tiled_copy) in the Typst web app:

[Figure: thread-value tables for src and dst, rendered from print_typst(tiled_copy)]

The two tables show the thread distribution for src and dst. Some PTX instructions redistribute data between threads, so the two tables can differ. For example:

print_typst(make_tiled_copy(MoYe.CopyAtom{LDSM_U32x4_N, UInt16}(),
                                          @Layout((16,2)), @Layout((2,4))));

[Figure: thread-value tables for src and dst of the LDSM_U32x4_N tiled copy]

As shown, thr_layout and val_layout are defined on dst. We will revisit ldmatrix when discussing Tensor Cores.

After creating the tiled_copy, partition the data:

julia> thr_idx = 2;
julia> thr_copy = get_slice(tiled_copy, thr_idx);
julia> dst_t = partition_D(thr_copy, dst);
julia> dst_t.layout
((_1, (_2, _3)), _1, _1):((_0, (_1, _4)), _0, _0)
julia> src_t = partition_S(thr_copy, src);
julia> src_t.layout
((_1, (_2, _3)), _1, _1):((_0, (_1, _4)), _0, _0)
julia> copyto!(tiled_copy, dst_t, src_t);
julia> dst
4×9 MoYeArray{Float64, 2, ViewEngine{Float64, Ptr{Float64}}, StaticLayout{2, Tuple{Static.StaticInt{4}, Static.StaticInt{9}}, Tuple{Static.StaticInt{1}, Static.StaticInt{4}}}} with indices _1:_4×_1:_9:
 0.0  0.0  0.0  1.3  1.7  2.1  0.0  0.0  0.0
 0.0  0.0  0.0  1.4  1.8  2.2  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0

The second thread has now completed its copy. The shape of dst_t is (CPY, CPY_M, CPY_K), where CPY is the number of vectorized values per thread. In this case, it's 1. Changing to UniversalCopy{UInt128} would alter this.

The NVIDIA Ampere architecture supports cuda::memcpy_async for asynchronous data copies between global and shared memory. On older architectures, the data must be staged through registers, as the third thread does here:

julia> thr_idx = 3;
julia> thr_copy = get_slice(tiled_copy, thr_idx);
julia> dst_t = partition_D(thr_copy, dst);
julia> src_t = partition_S(thr_copy, src);
julia> dst_r = make_fragment_like(dst_t);
julia> copyto!(tiled_copy, dst_r, src_t);
julia> copyto!(tiled_copy, dst_t, dst_r);
julia> dst
4×9 MoYeArray{Float64, 2, ViewEngine{Float64, Ptr{Float64}}, StaticLayout{2, Tuple{Static.StaticInt{4}, Static.StaticInt{9}}, Tuple{Static.StaticInt{1}, Static.StaticInt{4}}}} with indices _1:_4×_1:_9:
 0.0  0.0  0.0  1.3  1.7  2.1  2.5  2.9  3.3
 0.0  0.0  0.0  1.4  1.8  2.2  2.6  3.0  3.4
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
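
The two-step pattern (source to registers, then registers to destination) can be mimicked on the CPU. This is only a plain-Julia analogy, not MoYe code: an ordinary array copy stands in for the register fragment, and thread 3's block is read off the mapping table rather than computed by partition_S and partition_D.

```julia
src = reshape(collect(1:36) .* 0.1, 4, 9)   # same values as src above
dst = zeros(4, 9)

rows, cols = 1:2, 7:9        # thread 3's block, per the mapping table
reg = src[rows, cols]        # step 1: source -> private "register" copy
dst[rows, cols] .= reg       # step 2: "register" copy -> destination

display(dst)
```

Only the six elements owned by thread 3 are written. Everything else in dst is left untouched.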

TiledMMA

TiledMMA simplifies MMA partitions. Invoke make_tiled_mma as follows:

julia> TA, TB, TC = Float32, Float32, Float32;  # element types of A, B, and C (illustrative choice)
julia> mma_C = make_tiled_mma(UniversalFMA{TA,TB, TC}(), # MMA operation
                              @Layout((16,16)));         # Atom layout

You can replace UniversalFMA with other MMAOp types. View the predefined MMAOps with:

julia> MoYe.mma_ops_list
51-element Vector{Any}:
          "MMAOP_8x8x4_F64F64F64F64_TN" => "llvm.nvvm.mma.m8n8k4.row.col.f64"
       "MMAOP_16x8x4_F32TF32TF32F32_TN" => "llvm.nvvm.mma.m16n8k4.row.col.tf32"
       "MMAOP_16x8x8_F32TF32TF32F32_TN" => "llvm.nvvm.mma.m16n8k8.row.col.tf32"
      "MMAOP_16x8x16_F32BF16BF16F32_TN" => "llvm.nvvm.mma.m16n8k16.row.col.bf16"
       "MMAOP_16x8x8_F32BF16BF16F32_TN" => "llvm.nvvm.mma.m16n8k8.row.col.bf16"
          "MMAOP_8x8x4_F16F16F16F16_TT" => "llvm.nvvm.mma.m8n8k4.row.row.f16.f16"
          "MMAOP_8x8x4_F16F16F16F16_NT" => "llvm.nvvm.mma.m8n8k4.col.row.f16.f16"
          "MMAOP_8x8x4_F16F16F16F16_TN" => "llvm.nvvm.mma.m8n8k4.row.col.f16.f16"
          "MMAOP_8x8x4_F16F16F16F16_NN" => "llvm.nvvm.mma.m8n8k4.col.col.f16.f16"
          "MMAOP_8x8x4_F32F16F16F16_TT" => "llvm.nvvm.mma.m8n8k4.row.row.f32.f16"
                                        ⋮
 "MMAOP_16x8x32_S32U8S8S32_TN_SATURATE" => "llvm.nvvm.mma.m16n8k32.row.col.satfinite.u8.s8"
          "MMAOP_16x8x32_S32U8U8S32_TN" => "llvm.nvvm.mma.m16n8k32.row.col.u8"
 "MMAOP_16x8x32_S32U8U8S32_TN_SATURATE" => "llvm.nvvm.mma.m16n8k32.row.col.satfinite.u8"
  "MMAOP_8x8x128_S32B1B1S32_TN_XORPOPC" => "llvm.nvvm.mma.xor.popc.m8n8k128.row.col.b1"
  "MMAOP_8x8x128_S32B1B1S32_TN_ANDPOPC" => "llvm.nvvm.mma.and.popc.m8n8k128.row.col.b1"
 "MMAOP_16x8x128_S32B1B1S32_TN_XORPOPC" => "llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1"
 "MMAOP_16x8x128_S32B1B1S32_TN_ANDPOPC" => "llvm.nvvm.mma.and.popc.m16n8k128.row.col.b1"
 "MMAOP_16x8x256_S32B1B1S32_TN_XORPOPC" => "llvm.nvvm.mma.xor.popc.m16n8k256.row.col.b1"
 "MMAOP_16x8x256_S32B1B1S32_TN_ANDPOPC" => "llvm.nvvm.mma.and.popc.m16n8k256.row.col.b1"
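
These instructions operate on Tensor Cores, which are covered in a later section.

As with TiledCopy, take the per-thread slice of the TiledMMA and use it to partition the operands:

# sA, sB, and gC are the shared-memory and global-memory tiles defined in the kernel below
thr_mma = get_slice(mma_C, threadIdx().x);
tCsA = partition_A(thr_mma, sA);
tCsB = partition_B(thr_mma, sB);
tCgC = partition_C(thr_mma, gC);

tCrC = make_fragment_like(tCgC)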

Matmul with Tiled Operations

Now, let's upgrade the matmul_kernel with TiledCopy and TiledMMA.

using MoYe, CUDA, Test

function matmul_kernel(A, sA_layout, copy_A,
                       B, sB_layout, copy_B,
                       C, mma_C)
    sA = MoYeSharedArray(eltype(A), sA_layout)
    sB = MoYeSharedArray(eltype(B), sB_layout)

    mA = MoYeArray(A)
    mB = MoYeArray(B)
    mC = MoYeArray(C)

    bM = size(sA_layout, 1)
    bN = size(sB_layout, 1)
    bK = size(sB_layout, 2)
    
    gA = @tile mA (bM, bK) (blockIdx().x, :)
    gB = @tile mB (bN, bK) (blockIdx().y, :)
    gC = @tile mC (bM, bN) (blockIdx().x, blockIdx().y)

    # Copy partition
    thr_copy_a = get_slice(copy_A, threadIdx().x)      
    tAgA = partition_S(thr_copy_a, gA)                 # (CPY, CPY_M, CPY_K, k)
    tAsA = partition_D(thr_copy_a, sA)                 # (CPY, CPY_M, CPY_K)
    tArA = make_fragment_like(tAsA)                    # (CPY, CPY_M, CPY_K)

    thr_copy_b = get_slice(copy_B, threadIdx().x)
    tBgB = partition_S(thr_copy_b, gB)                 # (CPY, CPY_N, CPY_K, k)
    tBsB = partition_D(thr_copy_b, sB)                 # (CPY, CPY_N, CPY_K)
    tBrB = make_fragment_like(tBsB)                    # (CPY, CPY_N, CPY_K)

    # MMA partition
    thr_mma = get_slice(mma_C, threadIdx().x)
    tCsA = partition_A(thr_mma, sA)                    # (MMA, MMA_M, MMA_K)
    tCsB = partition_B(thr_mma, sB)                    # (MMA, MMA_N, MMA_K)
    tCgC = partition_C(thr_mma, gC)                    # (MMA, MMA_M, MMA_N)

    # Prefetch the first k-tile from global memory into registers
    copyto!(copy_A, tArA, view(tAgA, :, :, :, _1))
    copyto!(copy_B, tBrB, view(tBgB, :, :, :, _1))

    # Accumulator
    tCrC = make_fragment_C(thr_mma, tCgC)
    zeros!(tCrC)

    k_max = size(tAgA, 4)
    for k in 1:k_max
        sync_threads()
        copyto!(tAsA, tArA)
        copyto!(tBsB, tBrB)
        sync_threads()

        # Load the next tile
        k_next = k < k_max ? k+1 : k
        copyto!(copy_A, tArA, view(tAgA, :, :, :, k_next))
        copyto!(copy_B, tBrB, view(tBgB, :, :, :, k_next))

        @gc_preserve gemm!(mma_C, tCrC, tCsA, tCsB, tCrC)
    end

    copyto!(tCgC, tCrC)
    return nothing
end


function matmul(A, B, C)
    bM = _128
    bN = _128
    bK = _8
    
    sA_layout = make_layout((bM, bK), (_1, bM + _1))
    sB_layout = make_layout((bN, bK), (_1, bN + _1))

    TA = eltype(A)
    TB = eltype(B)
    TC = eltype(C)
	
    copy_A = make_tiled_copy(CopyAtom{UniversalCopy{TA}, TA}(),
                             @Layout((32, 8)),
                             @Layout((1, 1)))
    copy_B = make_tiled_copy(CopyAtom{UniversalCopy{TB}, TB}(),
                             @Layout((32, 8)),
                             @Layout((1, 1)))

    mma_C = make_tiled_mma(UniversalFMA{TA,TB, TC}(), # MMA operation
                           @Layout((32,8)))          # Atom layout

    threads = Int(size(mma_C))
    blocks = (cld(size(A, 1), bM), cld(size(B, 1), bN))

    @cuda threads=threads blocks=blocks matmul_kernel(A, sA_layout, copy_A,
                                                      B, sB_layout, copy_B,
                                                      C, mma_C)
end

function test()
    A =  CUDA.randn(Float32, 2048, 256)
    B =  CUDA.randn(Float32, 2048, 256)
    C =  CUDA.randn(Float32, 2048, 2048)
    matmul(A, B, C)
    CUDA.synchronize()
    @test C ≈ A * B'  # summation order differs from the reference, so compare approximately
    CUDA.unsafe_free!(A)
    CUDA.unsafe_free!(B)
    CUDA.unsafe_free!(C)
end

test()
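
Stripped of threads, shared memory, and partitioning, the kernel is a blocked loop that accumulates one output tile over successive k-tiles. The CPU sketch below shows only that loop structure. It is not the MoYe implementation: the function name tiled_matmul_reference! and the small default tile sizes are illustrative, and ordinary array slices stand in for the shared-memory tiles sA and sB.

```julia
using LinearAlgebra

# Computes C = A * B' tile by tile, where A is M×K, B is N×K, and C is M×N.
function tiled_matmul_reference!(C, A, B; bM = 4, bN = 4, bK = 2)
    M, K = size(A)
    N = size(B, 1)
    for m0 in 1:bM:M, n0 in 1:bN:N            # one "thread block" per output tile
        ms = m0:min(m0 + bM - 1, M)
        ns = n0:min(n0 + bN - 1, N)
        acc = zeros(eltype(C), length(ms), length(ns))   # plays the role of tCrC
        for k0 in 1:bK:K                      # loop over k-tiles
            ks = k0:min(k0 + bK - 1, K)
            sA = A[ms, ks]                    # stand-in for the shared tile of A
            sB = B[ns, ks]                    # stand-in for the shared tile of B
            acc .+= sA * sB'                  # stand-in for gemm!
        end
        C[ms, ns] .= acc                      # write the accumulator back
    end
    return C
end

A = rand(8, 6)
B = rand(12, 6)
C = zeros(8, 12)
tiled_matmul_reference!(C, A, B)
println(C ≈ A * B')
```

The result is compared with ≈ rather than == because summing k-tile by k-tile changes the order of floating-point additions.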

Vectorized Copy and Memory Coalescing

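As mentioned, a copy is vectorized when the type parameter of UniversalCopy is wider than the element type. For Float32 data (as in test() above), UniversalCopy{Float64} moves two elements per access and UniversalCopy{UInt128} moves four. However, it is crucial to ensure that memory accesses are coalesced.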

An uncoalesced copy:

copy_A = make_tiled_copy(CopyAtom{UniversalCopy{Float64}, TA}(),
                             @Layout((32, 8)),
                             @Layout((4, 1)))

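Here, each thread owns four consecutive elements but moves only two per access. In the first access, thread 1 loads [1], [2] and thread 2 loads [5], [6]; the gap between them means the access is not coalesced.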

Coalesced copies:

copy_A = make_tiled_copy(CopyAtom{UniversalCopy{Float64}, TA}(),
                             @Layout((32, 8)),
                             @Layout((2, 1)))
copy_A = make_tiled_copy(CopyAtom{UniversalCopy{UInt128}, TA}(),
                             @Layout((32, 8)),
                             @Layout((4, 1)))          

In these examples, threads access contiguous memory locations, leading to coalesced memory access and better performance.
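
The three configurations can be compared with a small model of the first access issued by each thread. This is a plain-Julia sketch under a simplifying assumption: along the contiguous dimension, thread t owns vals_per_thread consecutive elements and moves vec_width of them per access. The helper names first_access, touched, and is_contiguous are hypothetical and are not part of MoYe.

```julia
# Elements (1-based) that 0-based thread `t` touches in its first access.
first_access(t, vals_per_thread, vec_width) = (t * vals_per_thread) .+ (1:vec_width)

# All elements touched by the first access of `nthreads` consecutive threads.
function touched(vals_per_thread, vec_width; nthreads = 32)
    idx = Int[]
    for t in 0:nthreads-1
        append!(idx, first_access(t, vals_per_thread, vec_width))
    end
    return sort!(idx)
end

# True when the touched elements form one gap-free range.
is_contiguous(idx) = idx == collect(first(idx):last(idx))

println(is_contiguous(touched(4, 2)))   # val layout (4, 1), two Float32 per access
println(is_contiguous(touched(2, 2)))   # val layout (2, 1), two Float32 per access
println(is_contiguous(touched(4, 4)))   # val layout (4, 1), four Float32 per access
```

In this model an access is gap-free exactly when each thread moves all of its consecutive elements in one access, that is, when vals_per_thread equals vec_width.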