주제: IO-2 KV Cache (Transformer 디코딩 최적화)
언어 모델은 토큰을 하나씩 생성하므로, 순진하게 구현하면 새 토큰이 하나 나올 때마다 이전 전체 문장에 대한 attention을 처음부터 다시 계산하게 된다. 이건 계산량 낭비가 매우 크다.
문제 설정
LLM은 텍스트를 생성할 때 토큰을 하나씩 생성합니다.
언어 모델 목표:
즉 새 토큰을 생성할 때마다 이전 토큰 전체를 다시 고려해야 합니다.
문제는 naive 구현에서는 매 step마다 attention을 처음부터 다시 계산해야 한다는 점입니다.
KV Cache는 바로 이 낭비를 줄인다. 과거 토큰들의 key와 value는 이미 계산이 끝났으므로, 다음 step에서는 새 토큰에 해당하는 query만 계산하고 기존 cache를 재사용하면 된다. 그래서 긴 문장 생성에서 속도 차이가 크게 난다.
직관 비유
- 매 토큰 생성 시 문장을 다시 읽는 것
- -> 매우 비효율적
이를 해결하기 위해 사용하는 것이 KV Cache입니다.
1. Self-Attention 계산
Transformer attention:
Attention(Q,K,V) = softmax(QKT/√d_k)V
기호 의미
- Q : query
- K : key
- V : value
왜 중요한가
디코딩 단계에서는 매 토큰마다 attention을 계산해야 합니다.
2. Naive 디코딩
토큰을 순차적으로 생성한다고 가정합니다.
Step1: x1
Step2: x1 x2
Step3: x1 x2 x3
Step4: x1 x2 x3 x4
각 step에서 attention 계산:
전체 비용:
기호 의미
- T : 생성 토큰 길이
주의점
긴 생성에서는 매우 비효율적입니다.
3. KV Cache 아이디어
각 토큰의 key/value를 미리 저장합니다.
기호 의미
- K_i : 토큰 i의 key
- V_i : 토큰 i의 value
왜 필요한가
이전 토큰의 attention 정보를 재사용할 수 있습니다.
4. KV Cache 디코딩
새 토큰 t에서 계산되는 값:
attention 계산:
직관 설명
- 이전 토큰의 K,V는 재사용
- 새 토큰의 Q만 계산
주의점
K,V는 계속 cache에 추가됩니다.
5. 계산 복잡도 비교
| 방법 | 계산 비용 |
|---|---|
| naive decoding | O(T³) |
| KV cache | O(T²) |
KV cache는 계산량을 크게 줄입니다.
6. KV Cache 메모리
KV cache 메모리는 다음과 같이 증가합니다.
기호 의미
- layers : Transformer layer 수
- sequence : 토큰 길이
- d_head : head dimension
왜 중요한가
긴 context는 KV cache 메모리를 증가시킵니다.
7. KV Cache 업데이트
디코딩 과정:
Step1
K_cache = [K1]
Step2
K_cache = [K1,K2]
Step3
K_cache = [K1,K2,K3]
V cache도 동일하게 확장됩니다.
8. KV Cache 효과
- 디코딩 속도 증가
- 중복 attention 계산 제거
- LLM inference 필수 기술
코드-수식 연결
| 개념 | 코드 | 설명 |
|---|---|---|
| KV cache | past_key_values |
attention cache |
| attention 계산 | model(..., past_key_values=cache) |
cache 재사용 |
| token 생성 | model.generate() |
자동 decode loop |
자주 하는 오해 5개
- KV cache는 training에서도 사용된다고 생각한다
- KV cache는 선택 사항이라고 생각한다
- KV cache가 latency에 영향을 주지 않는다고 생각한다
- attention을 완전히 제거한다고 생각한다
- KV cache는 메모리를 사용하지 않는다고 생각한다
체크리스트 (스스로 설명 가능해야 하는 질문)
- KV cache는 어떤 문제를 해결하는가?
- 왜 naive decoding은 비효율적인가?
- KV cache는 어떤 값을 저장하는가?
- KV cache 사용 시 계산량은 어떻게 감소하는가?
- KV cache가 메모리에 미치는 영향은 무엇인가?