Tiled Copy

We have already introduced how to copy data using @tile and @parallelize. This process can still be somewhat cumbersome, and TiledCopy exists to simplify it.

Consider the following example, in which six threads copy an array src of shape (4, 9) into another array dst of the same shape. The mapping from logical coordinates to thread IDs can be visualized as:

1 1 1 2 2 2 3 3 3
1 1 1 2 2 2 3 3 3
4 4 4 5 5 5 6 6 6
4 4 4 5 5 5 6 6 6

Here, each thread is assigned a 2×3 block of data described by the layout (2,3):(1,2), and the six threads themselves are arranged by the layout (2,3):(3,1). We refer to these as val_layout and thr_layout, respectively.
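To make the two layouts concrete, here is a plain-Julia sketch that recovers the thread-ID grid above by evaluating thr_layout and val_layout by hand. It does not use MoYe. The helper layout_offset and the constants are hypothetical illustrations. The sketch assumes the usual convention that a layout maps a coordinate to the dot product of its 0-based components with the strides.

```julia
# Hypothetical helper (not a MoYe function): evaluate a 2-D layout with the
# given strides at a 1-based coordinate, returning a 0-based offset.
layout_offset(stride, coord) = (coord[1] - 1) * stride[1] + (coord[2] - 1) * stride[2]

const VAL_SHAPE  = (2, 3)   # val_layout = (2,3):(1,2)
const VAL_STRIDE = (1, 2)
const THR_STRIDE = (3, 1)   # thr_layout = (2,3):(3,1)

function owner_maps()
    thread_id = zeros(Int, 4, 9)   # 1-based ID of the thread owning each element
    value_id  = zeros(Int, 4, 9)   # 0-based index of the element among that thread's values
    for c in 1:9, r in 1:4
        tcoord = (cld(r, VAL_SHAPE[1]), cld(c, VAL_SHAPE[2]))    # which 2×3 block
        vcoord = (mod1(r, VAL_SHAPE[1]), mod1(c, VAL_SHAPE[2]))  # position inside it
        thread_id[r, c] = layout_offset(THR_STRIDE, tcoord) + 1
        value_id[r, c]  = layout_offset(VAL_STRIDE, vcoord)
    end
    return thread_id, value_id
end

thread_id, value_id = owner_maps()
display(thread_id)
```

Inside every block, value_id runs column-major (0 2 4 on the first row, 1 3 5 on the second), which is exactly what the stride (1,2) of val_layout expresses.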

To begin, we initialize these arrays:

julia> using MoYe
julia> src_buffer = collect(1:36) .* 0.1;
julia> src = MoYeArray(src_buffer, @Layout((4,9)))
4×9 MoYeArray{Float64, 2, ViewEngine{Float64, Ptr{Float64}}, StaticLayout{2, Tuple{Static.StaticInt{4}, Static.StaticInt{9}}, Tuple{Static.StaticInt{1}, Static.StaticInt{4}}}} with indices _1:_4×_1:_9:
 0.1  0.5  0.9  1.3  1.7  2.1  2.5  2.9  3.3
 0.2  0.6  1.0  1.4  1.8  2.2  2.6  3.0  3.4
 0.3  0.7  1.1  1.5  1.9  2.3  2.7  3.1  3.5
 0.4  0.8  1.2  1.6  2.0  2.4  2.8  3.2  3.6
julia> dst_buffer = zeros(36);
julia> dst = MoYeArray(dst_buffer, make_layout((_4,_9)));

We then proceed to set up a TiledCopy:

julia> thr_layout = @Layout (2, 3) (3, 1)
(_2, _3):(_3, _1)

julia> val_layout = @Layout (2, 3) (1, 2)
(_2, _3):(_1, _2)

julia> tiled_copy = make_tiled_copy(CopyAtom{UniversalCopy{Float64}, Float64}(), thr_layout, val_layout)
TiledCopy
  Tiler_MN:       (_4, _9)
  TiledLayout_TV: ((_3, _2), (_2, _3)):((_12, _2), (_1, _4))
CopyAtom
  Thread ID:    _1:_0
  ValLayoutSrc: (_1, _1):(_0, _1)
  ValLayoutDst: (_1, _1):(_0, _1)
  ValLayoutRef: (_1, _1):(_0, _1)
  ValueType:    64b
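You can check the printed TiledLayout_TV by hand. The plain-Julia sketch below evaluates ((_3, _2), (_2, _3)):((_12, _2), (_1, _4)) at every (thread, value) pair. Its helper tv_offset is hypothetical, not a MoYe function. The sketch makes two assumptions: each mode splits its index column-major, and the result is a 0-based column-major index into the 4×9 tile. Under those assumptions it reproduces the grid at the top of this section.

```julia
# Hypothetical helper: evaluate the TiledLayout_TV printed above at 0-based
# (thread, value) indices. Each mode is split column-major (leftmost sub-mode
# fastest) and dotted with its strides; the result is a 0-based linear index
# into the column-major 4×9 tile.
function tv_offset(t::Int, v::Int)
    thr = (t % 3) * 12 + (t ÷ 3) * 2   # thread mode: shape (3,2), stride (12,2)
    val = (v % 2) * 1 + (v ÷ 2) * 4    # value mode:  shape (2,3), stride (1,4)
    return thr + val
end

# Rebuild the thread-ID grid by writing the (1-based) thread ID into every
# element that thread owns; linear indexing of a Matrix is column-major.
grid = zeros(Int, 4, 9)
for t in 0:5, v in 0:5
    grid[tv_offset(t, v) + 1] = t + 1
end
display(grid)
```

The stride 12 in the thread mode is three columns of four elements. That is why thread 2 starts at row 1, column 4.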

The second type parameter of CopyAtom, Float64, is the element type of the data being copied. The first, UniversalCopy{Float64}, sets the width of each individual copy. The data is recast to Float64 before being moved, so each copy moves a single element and no vectorization takes place. Here is a vectorized TiledCopy, in which each UInt128 copy moves two Float64 values:

tiled_copy_vec = make_tiled_copy(
	CopyAtom{UniversalCopy{UInt128}, Float64}(),
	thr_layout, 
	val_layout)

Note that a vectorized copy must be compatible with val_layout. Each thread must own enough elements, in a number divisible by the vector width, for them to be vectorized (here, two Float64 values per UInt128).
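Recasting can be illustrated with plain Julia arrays. This is only an illustration, not how MoYe implements UniversalCopy. The sketch reinterprets one thread's six Float64 values as UInt128 words and copies them word by word. It also shows that an element count not divisible by two cannot be recast at all. The function can_vectorize is a hypothetical helper.

```julia
# One thread's six Float64 values (val_layout (2,3):(1,2) has six elements).
vals = collect(1.0:6.0)

# Recast to UInt128: each 16-byte word carries two 8-byte Float64 values,
# so three wide copies replace six scalar ones.
wide = reinterpret(UInt128, vals)

# Copy word by word, then view the destination as Float64 again.
dst_wide = Vector{UInt128}(undef, length(wide))
for i in eachindex(wide)
    dst_wide[i] = wide[i]
end
dst_vals = reinterpret(Float64, dst_wide)

# Hypothetical helper: can n Float64 values be recast to UInt128 words?
# Julia throws an ArgumentError when the byte count is not a multiple of 16.
can_vectorize(n) = try
    reinterpret(UInt128, zeros(Float64, n))
    true
catch err
    err isa ArgumentError || rethrow()
    false
end

println(length(wide), " wide copies; can_vectorize(5) = ", can_vectorize(5))
```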

You can visualize this tiled_copy with print_typst(tiled_copy). Paste the printed string into Typst to render the following image:

[Figure: Typst rendering of tiled_copy]

The two tables show the thread distribution of src and dst, respectively, and here they are identical. Some PTX instructions, however, also redistribute data among threads during the copy. One example is ldmatrix:

print_typst(make_tiled_copy(MoYe.CopyAtom{LDSM_U32x4_N, UInt16}(),
                                          @Layout((16,2)), @Layout((2,4))));

[Figure: Typst rendering of the LDSM_U32x4_N tiled copy]

As you can see, both thr_layout and val_layout are defined with respect to dst.

We will return to ldmatrix when we discuss tensor cores.

Returning to our example, after making the tiled_copy, we can use it to partition data.

julia> thr_idx = 2;
julia> thr_copy = get_slice(tiled_copy, thr_idx);
julia> dst_t = partition_D(thr_copy, dst);
julia> dst_t.layout
((_1, (_2, _3)), _1, _1):((_0, (_1, _4)), _0, _0)

julia> src_t = partition_S(thr_copy, src);

julia> src_t.layout
((_1, (_2, _3)), _1, _1):((_0, (_1, _4)), _0, _0)

julia> copyto!(tiled_copy, dst_t, src_t);

julia> dst
4×9 MoYeArray{Float64, 2, ViewEngine{Float64, Ptr{Float64}}, StaticLayout{2, Tuple{Static.StaticInt{4}, Static.StaticInt{9}}, Tuple{Static.StaticInt{1}, Static.StaticInt{4}}}} with indices _1:_4×_1:_9:
 0.0  0.0  0.0  1.3  1.7  2.1  0.0  0.0  0.0
 0.0  0.0  0.0  1.4  1.8  2.2  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0

You can see that the second thread has completed its copy. dst_t has shape (CPY, CPY_M, CPY_K). CPY is the number of values a thread handles in a single tile. CPY_M and CPY_K are the number of tiles along each dimension of dst; both are 1 here, since the tile is as large as dst. The leftmost sub-mode of CPY is the number of values moved by one vectorized copy. In this case it is 1; try changing to UniversalCopy{UInt128} and see how the layout of dst_t changes.
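The same partition-then-copy step can be emulated with ordinary Julia views, without MoYe. In this sketch the hypothetical function thread_block inverts thr_layout by hand to find the 2×3 block a thread owns. The calls to view stand in for partition_S and partition_D, and only the owner's six elements are copied.

```julia
src = reshape(collect(1:36) .* 0.1, 4, 9)   # same data as src, column-major
dst = zeros(4, 9)

# Hypothetical helper: rows and columns owned by a 1-based thread index.
# thr_layout (2,3):(3,1) maps the 0-based block coordinate (i, j) to 3i + j,
# so divrem(id, 3) recovers (i, j).
function thread_block(thr_idx)
    i, j = divrem(thr_idx - 1, 3)
    rows = 2i+1:2i+2   # each block is 2 rows tall
    cols = 3j+1:3j+3   # and 3 columns wide
    return rows, cols
end

rows, cols = thread_block(2)
dst_t = view(dst, rows, cols)   # plays the role of partition_D(thr_copy, dst)
src_t = view(src, rows, cols)   # plays the role of partition_S(thr_copy, src)
copyto!(dst_t, src_t)
display(dst)
```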

The NVIDIA Ampere architecture supports cuda::memcpy_async, which copies data asynchronously between GPU global memory and shared memory without threads having to orchestrate the data movement. On earlier architectures, a copy from global to shared memory is usually staged through registers, which corresponds to the following pattern:

julia> thr_idx = 3;
julia> thr_copy = get_slice(tiled_copy, thr_idx);
julia> dst_t = partition_D(thr_copy, dst);
julia> src_t = partition_S(thr_copy, src);
julia> dst_r = make_fragment_like(dst_t);
julia> copyto!(tiled_copy, dst_r, src_t);
julia> copyto!(tiled_copy, dst_t, dst_r);
julia> dst
4×9 MoYeArray{Float64, 2, ViewEngine{Float64, Ptr{Float64}}, StaticLayout{2, Tuple{Static.StaticInt{4}, Static.StaticInt{9}}, Tuple{Static.StaticInt{1}, Static.StaticInt{4}}}} with indices _1:_4×_1:_9:
 0.0  0.0  0.0  1.3  1.7  2.1  2.5  2.9  3.3
 0.0  0.0  0.0  1.4  1.8  2.2  2.6  3.0  3.4
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
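Below is a plain-Julia analogue of this two-stage copy. It assumes that thread 3 owns rows 1:2 and columns 7:9, as in the grid at the start of this section. Here similar plays the role of make_fragment_like: it allocates a private buffer that stands in for registers. This is only a host-side illustration of the data flow, not GPU code.

```julia
src = reshape(collect(1:36) .* 0.1, 4, 9)   # same data as src, column-major
dst = zeros(4, 9)

rows, cols = 1:2, 7:9            # block owned by thread 3 (assumed from the grid)
src_t = view(src, rows, cols)    # this thread's slice of the source
dst_t = view(dst, rows, cols)    # this thread's slice of the destination
dst_r = similar(src_t)           # private buffer standing in for registers

copyto!(dst_r, src_t)            # stage 1: source -> "registers"
copyto!(dst_t, dst_r)            # stage 2: "registers" -> destination
display(dst)
```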

TiledMMA

In this section, we'll show you how to use TiledMMA to replace an mma partition. First, invoke the function make_tiled_mma as follows:

julia> TA, TB, TC = Float32, Float32, Float32;   # element types of A, B and C

julia> mma_C = make_tiled_mma(UniversalFMA{TA,TB,TC}(), # MMA operation
                              @Layout((16,16)))        # Atom layout
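Conceptually, tiling a 1×1×1 FMA atom over a 16×16 layout gives 256 threads, each accumulating some elements of C. The plain-Julia model below assumes, for illustration only, that thread (tx, ty) owns rows tx, tx+16, ... and columns ty, ty+16, ... of C. It sketches the idea and is not MoYe's exact partitioning. B is stored as (N, K), matching C = A * B' in the test later on this page.

```julia
# Plain-Julia model (no GPU, no MoYe): a tm×tn grid of scalar FMA "threads"
# sweeping a C tile. The strided ownership pattern is an assumption.
function tiled_fma!(C, A, B; tm = 16, tn = 16)
    M, N = size(C)
    K = size(A, 2)
    for ty in 1:tn, tx in 1:tm              # each (tx, ty) acts as one thread
        for n in ty:tn:N, m in tx:tm:M      # C elements assigned to this thread
            acc = C[m, n]
            for k in 1:K
                acc = muladd(A[m, k], B[n, k], acc)   # one fused multiply-add
            end
            C[m, n] = acc
        end
    end
    return C
end

A = rand(32, 8)     # (M, K)
B = rand(48, 8)     # (N, K)
C = zeros(32, 48)   # (M, N)
tiled_fma!(C, A, B)
```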

You can experiment with replacing UniversalFMA with another MMAOp and use print_typst to view the results. Here are the predefined MMAOps:

julia> MoYe.mma_ops_list
51-element Vector{Any}:
          "MMAOP_8x8x4_F64F64F64F64_TN" => "llvm.nvvm.mma.m8n8k4.row.col.f64"
       "MMAOP_16x8x4_F32TF32TF32F32_TN" => "llvm.nvvm.mma.m16n8k4.row.col.tf32"
       "MMAOP_16x8x8_F32TF32TF32F32_TN" => "llvm.nvvm.mma.m16n8k8.row.col.tf32"
      "MMAOP_16x8x16_F32BF16BF16F32_TN" => "llvm.nvvm.mma.m16n8k16.row.col.bf16"
       "MMAOP_16x8x8_F32BF16BF16F32_TN" => "llvm.nvvm.mma.m16n8k8.row.col.bf16"
          "MMAOP_8x8x4_F16F16F16F16_TT" => "llvm.nvvm.mma.m8n8k4.row.row.f16.f16"
          "MMAOP_8x8x4_F16F16F16F16_NT" => "llvm.nvvm.mma.m8n8k4.col.row.f16.f16"
          "MMAOP_8x8x4_F16F16F16F16_TN" => "llvm.nvvm.mma.m8n8k4.row.col.f16.f16"
          "MMAOP_8x8x4_F16F16F16F16_NN" => "llvm.nvvm.mma.m8n8k4.col.col.f16.f16"
          "MMAOP_8x8x4_F32F16F16F16_TT" => "llvm.nvvm.mma.m8n8k4.row.row.f32.f16"
                                        ⋮
 "MMAOP_16x8x32_S32U8S8S32_TN_SATURATE" => "llvm.nvvm.mma.m16n8k32.row.col.satfinite.u8.s8"
          "MMAOP_16x8x32_S32U8U8S32_TN" => "llvm.nvvm.mma.m16n8k32.row.col.u8"
 "MMAOP_16x8x32_S32U8U8S32_TN_SATURATE" => "llvm.nvvm.mma.m16n8k32.row.col.satfinite.u8"
  "MMAOP_8x8x128_S32B1B1S32_TN_XORPOPC" => "llvm.nvvm.mma.xor.popc.m8n8k128.row.col.b1"
  "MMAOP_8x8x128_S32B1B1S32_TN_ANDPOPC" => "llvm.nvvm.mma.and.popc.m8n8k128.row.col.b1"
 "MMAOP_16x8x128_S32B1B1S32_TN_XORPOPC" => "llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1"
 "MMAOP_16x8x128_S32B1B1S32_TN_ANDPOPC" => "llvm.nvvm.mma.and.popc.m16n8k128.row.col.b1"
 "MMAOP_16x8x256_S32B1B1S32_TN_XORPOPC" => "llvm.nvvm.mma.xor.popc.m16n8k256.row.col.b1"
 "MMAOP_16x8x256_S32B1B1S32_TN_ANDPOPC" => "llvm.nvvm.mma.and.popc.m16n8k256.row.col.b1"
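
Inside a kernel, the per-thread MMA partition then looks like this:

thr_mma = get_slice(mma_C, threadIdx().x);
tCsA = partition_A(thr_mma, sA);   # (MMA, MMA_M, MMA_K)
tCsB = partition_B(thr_mma, sB);   # (MMA, MMA_N, MMA_K)
tCgC = partition_C(thr_mma, gC);   # (MMA, MMA_M, MMA_N)

tCrC = make_fragment_like(tCgC)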

The MMAOPs listed above run on tensor cores, a topic we haven't covered yet (but will soon!).

MatMul

Now, we use TiledCopy and TiledMMA to upgrade the previous matmul_kernel.

using MoYe, CUDA, Test

function matmul_kernel(A, sA_layout, copy_A,
                       B, sB_layout, copy_B,
                       C, mma_C)
    sA = MoYeSharedArray(eltype(A), sA_layout)
    sB = MoYeSharedArray(eltype(B), sB_layout)

    mA = MoYeArray(A)
    mB = MoYeArray(B)
    mC = MoYeArray(C)

    bM = size(sA_layout, 1)
    bN = size(sB_layout, 1)
    bK = size(sB_layout, 2)
    
    gA = @tile mA (bM, bK) (blockIdx().x, :)
    gB = @tile mB (bN, bK) (blockIdx().y, :)
    gC = @tile mC (bM, bN) (blockIdx().x, blockIdx().y)

    # copy partition
    thr_copy_a = get_slice(copy_A, threadIdx().x)      
    tAgA = partition_S(thr_copy_a, gA)                 # (CPY, CPY_M, CPY_K, k)
    tAsA = partition_D(thr_copy_a, sA)                 # (CPY, CPY_M, CPY_K)
    tArA = make_fragment_like(tAsA)                    # (CPY, CPY_M, CPY_K)

    thr_copy_b = get_slice(copy_B, threadIdx().x)
    tBgB = partition_S(thr_copy_b, gB)                 # (CPY, CPY_N, CPY_K, k)
    tBsB = partition_D(thr_copy_b, sB)                 # (CPY, CPY_N, CPY_K)
    tBrB = make_fragment_like(tBsB)                    # (CPY, CPY_N, CPY_K)

    # mma partition
    thr_mma = get_slice(mma_C, threadIdx().x)
    tCsA = partition_A(thr_mma, sA)                    # (MMA, MMA_M, MMA_K)
    tCsB = partition_B(thr_mma, sB)                    # (MMA, MMA_N, MMA_K)
    tCgC = partition_C(thr_mma, gC)                    # (MMA, MMA_M, MMA_N)

    # overlap copy and compute
    copyto!(copy_A, tArA, view(tAgA, :, :, :, _1))
    copyto!(copy_B, tBrB, view(tBgB, :, :, :, _1))

    # accumulator
    tCrC = make_fragment_C(thr_mma, tCgC)
    zeros!(tCrC)

    k_max = size(tAgA, 4)
    for k in 1:k_max
        sync_threads()
        copyto!(tAsA, tArA)
        copyto!(tBsB, tBrB)
        sync_threads()

        # load the next tile
        k_next = k < k_max ? k+1 : k
        copyto!(copy_A, tArA, view(tAgA, :, :, :, k_next))
        copyto!(copy_B, tBrB, view(tBgB, :, :, :, k_next))

        @gc_preserve gemm!(mma_C, tCrC, tCsA, tCsB, tCrC)
    end

    copyto!(tCgC, tCrC)
    return nothing
end


function matmul(A, B, C)
    bM = _128
    bN = _128
    bK = _8
    
    sA_layout = make_layout((bM, bK), (_1, bM + _1))
    sB_layout = make_layout((bN, bK), (_1, bN + _1))

    TA = eltype(A)
    TB = eltype(B)
    TC = eltype(C)
	
    copy_A = make_tiled_copy(CopyAtom{UniversalCopy{TA}, TA}(),
                             @Layout((32, 8)),
                             @Layout((1, 1)))
    copy_B = make_tiled_copy(CopyAtom{UniversalCopy{TB}, TB}(),
                             @Layout((32, 8)),
                             @Layout((1, 1)))

    mma_C = make_tiled_mma(UniversalFMA{TA,TB, TC}(), # MMA operation
                           @Layout((32,8)))          # Atom layout

    threads = Int(size(mma_C))
    blocks = (cld(size(A, 1), bM), cld(size(B, 1), bN))

    @cuda threads=threads blocks=blocks matmul_kernel(A, sA_layout, copy_A,
                                                      B, sB_layout, copy_B,
                                                      C, mma_C)
end

function test()
    A =  CUDA.randn(Float32, 2048, 256)
    B =  CUDA.randn(Float32, 2048, 256)
    C =  CUDA.randn(Float32, 2048, 2048)
    matmul(A, B, C)
    CUDA.synchronize()
    @test C ≈ A * B'   # approximate: floating-point summation order differs
    CUDA.unsafe_free!(A)
    CUDA.unsafe_free!(B)
    CUDA.unsafe_free!(C)
end

test()
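For reference, the kernel's structure can be modeled on the CPU. Each thread block computes one C tile, a main loop runs over the k tiles, and registers prefetch the next tile while the current one is multiplied. This plain-Julia sketch uses no GPU or MoYe calls. Like the kernel, it assumes the matrix sizes are multiples of bM, bN and bK. The name blocked_matmul! is hypothetical.

```julia
# CPU model of matmul_kernel's structure (plain Julia, no GPU, no MoYe).
# Computes C = A * B' with A of size (M, K) and B of size (N, K).
function blocked_matmul!(C, A, B; bM = 128, bN = 128, bK = 8)
    M, K = size(A)
    N = size(B, 1)
    kmax = K ÷ bK
    ktile(k) = (k - 1) * bK + 1 : k * bK
    for bn in 1:N ÷ bN, bm in 1:M ÷ bM        # one iteration per thread block
        ms = (bm - 1) * bM + 1 : bm * bM
        ns = (bn - 1) * bN + 1 : bn * bN
        rA = A[ms, ktile(1)]                   # like tArA: prefetch the first k tile
        rB = B[ns, ktile(1)]
        acc = zeros(eltype(C), bM, bN)         # like tCrC after zeros!
        for k in 1:kmax
            sA = copy(rA)                      # "registers" -> "shared" tile
            sB = copy(rB)
            knext = k < kmax ? k + 1 : k
            rA = A[ms, ktile(knext)]           # load the next tile early
            rB = B[ns, ktile(knext)]
            acc .+= sA * sB'                   # the gemm! step
        end
        C[ms, ns] .= acc                       # like copyto!(tCgC, tCrC)
    end
    return C
end

A = rand(Float32, 256, 16)
B = rand(Float32, 128, 16)
C = zeros(Float32, 256, 128)
blocked_matmul!(C, A, B)
```

The prefetch into rA and rB is what the comment "overlap copy and compute" refers to. On the GPU, those global loads can be in flight while gemm! runs on the tile already in shared memory.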

Vectorized copy

As previously mentioned, you can switch to UniversalCopy{Float64} or UniversalCopy{UInt128} to enable vectorized copies. Since A holds Float32 here, these move two or four elements per copy, respectively. We must also make sure the copies remain coalesced. For example, the following one is not coalesced:

copy_A = make_tiled_copy(CopyAtom{UniversalCopy{Float64}, TA}(),
                             @Layout((32, 8)),
                             @Layout((4, 1)))

since, in the first copy, thread 1 loads elements [1] and [2] while thread 2 loads [5] and [6]. This leaves a gap between the addresses accessed by neighboring threads.

These are coalesced:

copy_A = make_tiled_copy(CopyAtom{UniversalCopy{Float64}, TA}(),
                             @Layout((32, 8)),
                             @Layout((2, 1)))
copy_A = make_tiled_copy(CopyAtom{UniversalCopy{UInt128}, TA}(),
                             @Layout((32, 8)),
                             @Layout((4, 1)))
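The coalescing rule above can be checked with a small model, using the hypothetical helpers touched and coalesced. The model rests on three assumptions:

- Threads 1 to 32 of a warp run along the contiguous M dimension of A, as in @Layout((32, 8)).
- Each thread owns vals consecutive Float32 elements.
- One copy moves vec of them.

A copy counts as coalesced here when the warp's accesses form one gap-free range. This is a simplification of the hardware's actual rules.

```julia
# Hypothetical helper: 1-based element indices (along M) touched by thread t
# in its c-th copy, when each thread owns `vals` consecutive elements and
# moves `vec` of them per copy.
touched(t, c; vals, vec) = (t - 1) * vals + (c - 1) * vec .+ (1:vec)

# Hypothetical helper: do the first copies of threads 1..nthreads cover one
# contiguous range without gaps?
function coalesced(; vals, vec, nthreads = 32)
    idx = reduce(vcat, [collect(touched(t, 1; vals, vec)) for t in 1:nthreads])
    return idx == collect(first(idx):first(idx)+length(idx)-1)
end

println(coalesced(vals = 4, vec = 2))   # Float64 copy, @Layout((4, 1))
println(coalesced(vals = 2, vec = 2))   # Float64 copy, @Layout((2, 1))
println(coalesced(vals = 4, vec = 4))   # UInt128 copy, @Layout((4, 1))
```

The rule of thumb is that coalescing holds when each thread's values fill exactly one vector, so neighboring threads' vectors sit side by side in memory.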