주제: IO-4 FlashAttention (Memory IO 최적화 기반 Attention 가속)
문제 설정
Transformer의 Self-Attention은 계산 자체보다 GPU 메모리 접근(IO)이 병목이 되는 경우가 많습니다.
기본 attention 식:
Attention(Q,K,V) = softmax(QKT/√d_k)V
문제는 attention 계산 과정에서 큰 행렬이 여러 번 GPU 메모리(HBM)로 이동한다는 점입니다.
- QKᵀ 계산
- softmax 계산
- V 곱
이 과정에서 메모리 IO 비용이 매우 큽니다.
직관 비유
- 일반 attention -> 데이터를 계속 창고(HBM)에서 꺼내 사용
- FlashAttention -> 작은 작업 단위로 나눠 창고 접근 최소화
핵심 아이디어:
Attention 계산을 타일(block) 단위로 수행하여 메모리 접근을 최소화
1. Standard Attention 메모리 흐름
일반적인 attention 계산 과정:
1. S = QKᵀ
2. P = softmax(S)
3. O = PV
수식
S = QKT
기호 의미
- Q : query matrix
- K : key matrix
- V : value matrix
- S : attention score
- P : attention weight
주의점
행렬 S는 크기가 매우 큽니다.
문제
S를 GPU 메모리에 저장하고 다시 읽어야 합니다.
2. Memory IO 병목
GPU 성능은 다음 두 요소에 의해 제한됩니다.
- compute throughput
- memory bandwidth
attention은 종종 다음 조건을 가집니다.
memory-bound workload
즉:
왜 중요한가
계산 자체보다 데이터 이동이 더 느립니다.
3. FlashAttention 핵심 아이디어
FlashAttention은 attention을 block 단위로 계산합니다.
Q block
K block
V block
작은 타일 단위로 계산하여 GPU의 SRAM(shared memory)을 활용합니다.
수식 관점
attention을 다음과 같이 분할합니다.
O_i = Σ softmax(Q_i K_jT) V_j
기호 의미
- i : query block
- j : key block
왜 필요한가
큰 행렬을 메모리에 저장하지 않아도 됩니다.
4. Online Softmax
FlashAttention은 softmax를 online 방식으로 계산합니다.
softmax:
FlashAttention에서는 다음 값을 유지합니다.
- running max
- running sum
수식
왜 필요한가
전체 attention matrix를 저장하지 않고도 softmax 계산이 가능합니다.
5. 메모리 복잡도
| 방법 | 메모리 사용 |
|---|---|
| Standard attention | O(n²) |
| FlashAttention | O(n) |
기호 의미
- n : sequence length
왜 중요한가
긴 context에서도 메모리 사용을 크게 줄입니다.
6. 속도 향상
FlashAttention은 다음 효과를 가집니다.
- HBM 접근 감소
- SRAM 활용
- GPU 효율 증가
실제 시스템에서:
2~4배 속도 향상
7. FlashAttention 특징
| 특징 | 설명 |
|---|---|
| exact attention | 근사 아님 |
| IO-aware | 메모리 접근 최적화 |
| block algorithm | tile 기반 계산 |
8. 실제 적용
FlashAttention은 다음 시스템에서 사용됩니다.
- GPT 계열 모델
- LLaMA
- LLM inference frameworks
특히 긴 context 모델에서 중요합니다.
코드-수식 연결
| 개념 | 코드 | 설명 |
|---|---|---|
| flash attention | flash_attn.flash_attn_func |
FlashAttention 구현 |
| scaled dot-product | torch.nn.functional.scaled_dot_product_attention |
attention 계산 |
| softmax | torch.softmax() |
attention weight 계산 |
자주 하는 오해 5개
- FlashAttention은 근사 attention이라고 생각한다
- FlashAttention은 계산 복잡도를 줄인다고 생각한다
- GPU compute가 병목이라고 생각한다
- attention matrix가 완전히 사라진다고 생각한다
- FlashAttention은 inference에서만 사용된다고 생각한다
체크리스트 (스스로 설명 가능해야 하는 질문)
- Transformer attention에서 병목이 되는 요소는 무엇인가?
- FlashAttention이 해결하는 문제는 무엇인가?
- 왜 attention 계산은 memory-bound workload인가?
- FlashAttention의 block algorithm은 어떤 원리인가?
- FlashAttention이 메모리 사용을 줄이는 이유는 무엇인가?