주제: IO-4 FlashAttention (Memory IO 최적화 기반 Attention 가속)

분류: llm

타입: concept

난이도: 중상

선수지식: 있음 — Self-Attention, GPU 메모리 구조, 행렬곱

문제 설정

Transformer의 Self-Attention은 계산 자체보다 GPU 메모리 접근(IO)이 병목이 되는 경우가 많습니다.

기본 attention 식:

Attention(Q,K,V) = softmax(QKT/√d_k)V

문제는 attention 계산 과정에서 큰 행렬이 여러 번 GPU 메모리(HBM)로 이동한다는 점입니다.

이 과정에서 메모리 IO 비용이 매우 큽니다.

직관 비유

핵심 아이디어:

Attention 계산을 타일(block) 단위로 수행하여 메모리 접근을 최소화

1. Standard Attention 메모리 흐름

일반적인 attention 계산 과정:

1. S = QKᵀ
2. P = softmax(S)
3. O = PV

수식

S = QKT

P=softmax(S)
O=PV

기호 의미

주의점

행렬 S는 크기가 매우 큽니다.

SRn×n

문제

S를 GPU 메모리에 저장하고 다시 읽어야 합니다.

2. Memory IO 병목

GPU 성능은 다음 두 요소에 의해 제한됩니다.

attention은 종종 다음 조건을 가집니다.

memory-bound workload

즉:

RuntimeMemoryIO

왜 중요한가

계산 자체보다 데이터 이동이 더 느립니다.

3. FlashAttention 핵심 아이디어

FlashAttention은 attention을 block 단위로 계산합니다.

Q block
K block
V block

작은 타일 단위로 계산하여 GPU의 SRAM(shared memory)을 활용합니다.

수식 관점

attention을 다음과 같이 분할합니다.

O_i = Σ softmax(Q_i K_jT) V_j

기호 의미

왜 필요한가

큰 행렬을 메모리에 저장하지 않아도 됩니다.

4. Online Softmax

FlashAttention은 softmax를 online 방식으로 계산합니다.

softmax:

softmax(xi)=exp(xi)/Σexp(xj)

FlashAttention에서는 다음 값을 유지합니다.

수식

mi=max(score)
li=Σexp(scoremi)

왜 필요한가

전체 attention matrix를 저장하지 않고도 softmax 계산이 가능합니다.

5. 메모리 복잡도

방법 메모리 사용
Standard attention O(n²)
FlashAttention O(n)

기호 의미

왜 중요한가

긴 context에서도 메모리 사용을 크게 줄입니다.

6. 속도 향상

FlashAttention은 다음 효과를 가집니다.

실제 시스템에서:

2~4배 속도 향상

7. FlashAttention 특징

특징 설명
exact attention 근사 아님
IO-aware 메모리 접근 최적화
block algorithm tile 기반 계산

8. 실제 적용

FlashAttention은 다음 시스템에서 사용됩니다.

특히 긴 context 모델에서 중요합니다.

코드-수식 연결

개념 코드 설명
flash attention flash_attn.flash_attn_func FlashAttention 구현
scaled dot-product torch.nn.functional.scaled_dot_product_attention attention 계산
softmax torch.softmax() attention weight 계산

자주 하는 오해 5개

체크리스트 (스스로 설명 가능해야 하는 질문)