FlashAttention.jl

This is a Julia implementation of the Flash Attention algorithm.

Usage

using FlashAttention, CUDA

# Random query, key and value arrays (Float16 CuArrays).
Q = CUDA.randn(Float16, 64, 1024, 48, 3);
K = CUDA.randn(Float16, 64, 1024, 48, 3);
V = CUDA.randn(Float16, 64, 1024, 48, 3);

flash_attention(Q, K, V)
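
To sanity-check the result, the output can be compared against a plain attention computation on the CPU. The sketch below is illustrative rather than part of the package: it assumes the inputs are laid out as (head dim, sequence length, heads, batch) as the sizes above suggest, non-causal attention with the usual 1/sqrt(d) scaling, and an output of the same shape as Q; naive_attention is a hypothetical helper, not an API of this package.

# Minimal CPU reference, for checking only (slow: materializes the full score matrix).
function naive_attention(Q, K, V)
    d, N, H, B = size(Q)
    O = similar(Q)
    scale = inv(sqrt(Float32(d)))
    for b in 1:B, h in 1:H
        q = Float32.(@view Q[:, :, h, b])   # (d, N)
        k = Float32.(@view K[:, :, h, b])
        v = Float32.(@view V[:, :, h, b])
        S = (q' * k) .* scale               # (N, N) attention scores
        S .= S .- maximum(S; dims=2)        # subtract row-wise max for a stable softmax
        P = exp.(S)
        P ./= sum(P; dims=2)                # softmax over the key dimension
        O[:, :, h, b] = v * P'              # weighted sum of values (converted back on assignment)
    end
    return O
end

O_ref  = naive_attention(Array(Q), Array(K), Array(V))
O_fast = Array(flash_attention(Q, K, V))
isapprox(O_ref, O_fast; rtol=1e-2)          # loose tolerance for Float16 accumulation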

Profiling

Please refer to the file flash_attention.ncu-rep. This is not the fastest possible implementation because 1) it does not use tensor cores as the C++ implementation does, 2) CUDA.jl does not yet support asynchronous copies from global memory to shared memory, and 3) the kernel's theoretical occupancy (12.5%) is limited by the amount of shared memory it requires.
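
The ncu-rep file is an NVIDIA Nsight Compute report. For a quick look at kernel timings without leaving Julia, CUDA.jl's @profile macro can be used instead; the snippet below is only a sketch and reuses the Q, K and V from the usage example.

using CUDA, FlashAttention

flash_attention(Q, K, V)                # warm up so compilation is excluded
CUDA.@profile flash_attention(Q, K, V)  # recent CUDA.jl versions print per-kernel timings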

Future work

I plan to reimplement the kernel using MoYe.jl to achieve competitive performance.