Pipelining

Instruction-level parallelism (ILP) improves processor performance by executing multiple independent instructions simultaneously. In GPU programming, we can exploit ILP to overlap memory operations with computation, effectively hiding memory latency.

Overlapping Global-to-Shared Copies with MMA Computation

We can overlap global-to-shared memory copies with MMA (Matrix Multiply-Accumulate) computation. This is achieved by prefetching the next tile of data from global memory into shared memory while the current tile is being processed.

To implement this, we explicitly copy the current tile from shared memory into registers for the MMA computation, and issue the asynchronous global-to-shared copy of the next tile before the computation on the registers begins.
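
In outline, the main loop has the following shape (a schematic only; the names match the full kernel below):

# Pseudocode outline of the single-stage pipeline (schematic only):
#
#   issue async copy of tile 1 from gmem to smem
#   for k in 1:k_max
#       cp_async_wait(); sync_threads()      # tile k has landed in smem
#       copy tile k from smem to registers; sync_threads()
#       if k < k_max, issue async copy of tile k+1 into smem
#       gemm! on the registers               # overlaps the copy in flight
#   end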

using MoYe, CUDA, Test

function matmul_kernel(A, sA_layout, copy_A,
                       B, sB_layout, copy_B,
                       C, mma_C)
    sA = MoYeSharedArray(eltype(A), sA_layout)
    sB = MoYeSharedArray(eltype(B), sB_layout)

    mA = MoYeArray(A)
    mB = MoYeArray(B)
    mC = MoYeArray(C)

    bM = size(sA_layout, 1)
    bN = size(sB_layout, 1)
    bK = size(sB_layout, 2)

    gA = @tile mA (bM, bK) (blockIdx().x, :)
    gB = @tile mB (bN, bK) (blockIdx().y, :)
    gC = @tile mC (bM, bN) (blockIdx().x, blockIdx().y)

    # Copy partition
    thr_copy_a = get_slice(copy_A, threadIdx().x)      
    tAgA = partition_S(thr_copy_a, gA)                 # (CPY, CPY_M, CPY_K, k)
    tAsA = partition_D(thr_copy_a, sA)                 # (CPY, CPY_M, CPY_K)

    thr_copy_b = get_slice(copy_B, threadIdx().x)
    tBgB = partition_S(thr_copy_b, gB)                 # (CPY, CPY_N, CPY_K, k)
    tBsB = partition_D(thr_copy_b, sB)                 # (CPY, CPY_N, CPY_K)

    # Copy gmem to smem for k_tile=1
    copyto!(copy_A, tAsA, view(tAgA, :, :, :, _1))
    copyto!(copy_B, tBsB, view(tBgB, :, :, :, _1))

    # MMA partition
    thr_mma = get_slice(mma_C, threadIdx().x)
    tCsA = partition_A(thr_mma, sA)                    # (MMA, MMA_M, MMA_K)
    tCsB = partition_B(thr_mma, sB)                    # (MMA, MMA_N, MMA_K)
    tCgC = partition_C(thr_mma, gC)                    # (MMA, MMA_M, MMA_N)

    # MMA registers
    tCrA = make_fragment_A(thr_mma, tCsA)                # (MMA, MMA_M, MMA_K)
    tCrB = make_fragment_B(thr_mma, tCsB)                # (MMA, MMA_N, MMA_K)
    tCrC = make_fragment_C(thr_mma, tCgC)                # (MMA, MMA_M, MMA_N)
    zeros!(tCrC)

    k_max = size(tAgA, 4)
    for k in 1:k_max
        cp_async_wait()        # wait until tile k has arrived in smem
        sync_threads()

        # Copy from smem to rmem
        copyto!(tCrA, tCsA)
        copyto!(tCrB, tCsB)
        sync_threads()

        # Prefetch tile k+1 while tile k is consumed by the gemm! below
        if k < k_max
            copyto!(copy_A, tAsA, view(tAgA, :, :, :, k+1))
            copyto!(copy_B, tBsB, view(tBgB, :, :, :, k+1))
        end

        @gc_preserve gemm!(mma_C, tCrC, tCrA, tCrB, tCrC)
    end

    copyto!(tCgC, tCrC)
    return nothing
end
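
A host-side launch for the double-buffered variant is given at the end of this section; for the single-stage kernel above, a minimal sketch would look like the following. The function name matmul_single_stage is illustrative, and the tile sizes and tiled copy/MMA definitions are assumed to match the host code shown later, minus the extra stage dimension in the shared memory layouts.

# A hypothetical launch for the single-stage kernel above; it mirrors the
# `matmul` host function defined later, without the stage dimension
function matmul_single_stage(A, B, C)
    bM, bN, bK = _128, _128, _8
    # single-stage smem layouts, padded to mitigate bank conflicts
    sA_layout = make_layout((bM, bK), (_1, bM + _2))
    sB_layout = make_layout((bN, bK), (_1, bN + _2))

    TA, TB, TC = eltype(A), eltype(B), eltype(C)
    copy_A = make_tiled_copy(CopyAtom{CPOP_ASYNC_CACHEALWAYS{Float64}, TA}(),
                             @Layout((32, 8)), @Layout((2, 1)))
    copy_B = make_tiled_copy(CopyAtom{CPOP_ASYNC_CACHEALWAYS{Float64}, TB}(),
                             @Layout((32, 8)), @Layout((2, 1)))
    mma_C = make_tiled_mma(UniversalFMA{TA,TB,TC}(), @Layout((32, 8)))

    threads = Int(size(mma_C))
    blocks = (cld(size(A, 1), bM), cld(size(B, 1), bN))
    @cuda threads=threads blocks=blocks matmul_kernel(A, sA_layout, copy_A,
                                                      B, sB_layout, copy_B,
                                                      C, mma_C)
end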

Double Buffering

We can also overlap shared-to-register copies with MMA computation. To keep the asynchronous global-to-shared copies in flight at the same time, we use a technique called double buffering.

This involves allocating two shared memory stages: one holding the tile currently being consumed and one receiving the next tile. While computation reads from the first stage, we asynchronously prefetch the next tile from global memory into the second, then swap the roles of the two stages, as sketched below.
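
Stripped of the GPU details, the stage bookkeeping is a simple ping-pong; a minimal standalone sketch (plain Julia, with an illustrative loop bound):

# Two stages alternate between the read role (consumed by the MMA) and
# the write role (filled by the async copy); swap once per k_tile
smem_read, smem_write = 1, 2
for k_tile in 1:4                                 # illustrative bound
    # compute on stage `smem_read`; prefetch into stage `smem_write`
    smem_read, smem_write = smem_write, smem_read # roles swap for next tile
end

The full kernel combines this stage swap with register-level pipelining over the k_blocks of each tile: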

@views function matmul_kernel(A, sA_layout, copy_A,
                              B, sB_layout, copy_B,
                              C, mma_C)
    sA = MoYeSharedArray(eltype(A), sA_layout)        # (bM, bK, 2)
    sB = MoYeSharedArray(eltype(B), sB_layout)        # (bN, bK, 2)

    mA = MoYeArray(A)
    mB = MoYeArray(B)
    mC = MoYeArray(C)

    bM = size(sA_layout, 1)
    bN = size(sB_layout, 1)
    bK = size(sB_layout, 2)

    gA = @tile mA (bM, bK) (blockIdx().x, :)
    gB = @tile mB (bN, bK) (blockIdx().y, :)
    gC = @tile mC (bM, bN) (blockIdx().x, blockIdx().y)

    # Copy partition
    thr_copy_a = get_slice(copy_A, threadIdx().x)      
    tAgA = partition_S(thr_copy_a, gA)                 # (CPY, CPY_M, CPY_K, k)
    tAsA = partition_D(thr_copy_a, sA)                 # (CPY, CPY_M, CPY_K, 2)

    thr_copy_b = get_slice(copy_B, threadIdx().x)
    tBgB = partition_S(thr_copy_b, gB)                 # (CPY, CPY_N, CPY_K, k)
    tBsB = partition_D(thr_copy_b, sB)                 # (CPY, CPY_N, CPY_K, 2)

    # Copy gmem to smem for k_tile=1
    copyto!(copy_A, tAsA[:, :, :, _1], tAgA[:, :, :, _1])
    copyto!(copy_B, tBsB[:, :, :, _1], tBgB[:, :, :, _1])

    # MMA partition
    thr_mma = get_slice(mma_C, threadIdx().x)
    tCsA = partition_A(thr_mma, sA)                    # (MMA, MMA_M, MMA_K, 2)
    tCsB = partition_B(thr_mma, sB)                    # (MMA, MMA_N, MMA_K, 2)
    tCgC = partition_C(thr_mma, gC)                    # (MMA, MMA_M, MMA_N)

    # MMA registers
    tCrA = make_fragment_A(thr_mma, tCsA[:, :, :, _1])    # (MMA, MMA_M, MMA_K)
    tCrB = make_fragment_B(thr_mma, tCsB[:, :, :, _1])    # (MMA, MMA_N, MMA_K)
    tCrC = make_fragment_C(thr_mma, tCgC)                 # (MMA, MMA_M, MMA_N)
    zeros!(tCrC)

    cp_async_wait()        # wait for the first tile to arrive in smem
    sync_threads()

    # Copy smem to rmem for k_block=1
    smem_read = 1
    smem_write = 2
    tCsA_p = view(tCsA, :, :, :, smem_read)
    tCsB_p = view(tCsB, :, :, :, smem_read)
    copyto!(tCrA[:, :, _1], tCsA_p[:, :, _1])
    copyto!(tCrB[:, :, _1], tCsB_p[:, :, _1])

    k_tile_max = size(tAgA, 4)
    k_block_max = static_size(tCrA, 3)
    for k_tile in 1:k_tile_max
        @loopinfo unroll for k_block in _1:k_block_max
            k_block_next = k_block + 1
            if k_block == k_block_max
                # The next k_block lives in the other smem stage: wait for
                # the in-flight copy, then redirect the read views to it
                cp_async_wait()
                sync_threads()
                tCsA_p = view(tCsA, :, :, :, smem_read)
                tCsB_p = view(tCsB, :, :, :, smem_read)
                k_block_next = 1
            end

            # Load the next k_block from smem to registers
            copyto!(tCrA[:, :, k_block_next], tCsA_p[:, :, k_block_next])
            copyto!(tCrB[:, :, k_block_next], tCsB_p[:, :, k_block_next])

            # On the first k_block, start the async copy of the next tile
            # into the write stage, then swap the two stage roles
            if k_block == _1 && k_tile < k_tile_max
                copyto!(copy_A, tAsA[:, :, :, smem_write], tAgA[:, :, :, k_tile+1])
                copyto!(copy_B, tBsB[:, :, :, smem_write], tBgB[:, :, :, k_tile+1])
                smem_read, smem_write = smem_write, smem_read
            end

            @gc_preserve gemm!(mma_C, tCrC, tCrA[:, :, k_block], tCrB[:, :, k_block], tCrC)
        end
    end

    copyto!(tCgC, tCrC)
    return nothing
end

function matmul(A, B, C)
    bM = _128
    bN = _128
    bK = _8
    
    # Two smem stages for double buffering; the leading dimension is padded
    # by 2 elements to mitigate shared memory bank conflicts
    sA_layout = make_layout((bM, bK, _2), (_1, bM + _2, (bM + _2) * bK))
    sB_layout = make_layout((bN, bK, _2), (_1, bN + _2, (bN + _2) * bK))

    TA = eltype(A)
    TB = eltype(B)
    TC = eltype(C)

    # Each thread moves an 8-byte (Float64-sized) chunk per cp.async,
    # i.e. two Float32 values here, with a 32×8 thread arrangement
    copy_A = make_tiled_copy(CopyAtom{CPOP_ASYNC_CACHEALWAYS{Float64}, TA}(),
                             @Layout((32, 8)),
                             @Layout((2, 1)))
    copy_B = make_tiled_copy(CopyAtom{CPOP_ASYNC_CACHEALWAYS{Float64}, TB}(),
                             @Layout((32, 8)),
                             @Layout((2, 1)))

    mma_C = make_tiled_mma(UniversalFMA{TA,TB, TC}(), # MMA operation
                           @Layout((32, 8)))          # Atom layout

    threads = Int(size(mma_C))
    blocks = (cld(size(A, 1), bM), cld(size(B, 1), bN))

    @cuda threads=threads blocks=blocks matmul_kernel(A, sA_layout, copy_A,
                                                      B, sB_layout, copy_B,
                                                      C, mma_C)
end

function test()
    A = CUDA.randn(Float32, 2048, 256)
    B = CUDA.randn(Float32, 2048, 256)
    C = CUDA.randn(Float32, 2048, 2048)
    matmul(A, B, C)
    CUDA.synchronize()
    @test C ≈ A * B'
    CUDA.unsafe_free!(A)
    CUDA.unsafe_free!(B)
    CUDA.unsafe_free!(C)
end

test()