주제: DT-4 ZeRO / FSDP (Parameter, Gradient, Optimizer Sharding)
문제 설정
대형 모델 학습에서 가장 큰 문제는 GPU 메모리 사용량입니다. 일반적인 DDP에서는 각 GPU가 다음을 모두 복제합니다.
- 모델 파라미터 (parameters)
- gradient
- optimizer state
즉 GPU마다 동일한 데이터가 저장됩니다.
대형 모델에서는 이 메모리 사용량이 매우 커집니다.
이를 해결하기 위해 등장한 방법이 ZeRO (Zero Redundancy Optimizer)와 FSDP (Fully Sharded Data Parallel)입니다.
직관 비유
- DDP -> 모든 사람이 같은 책을 모두 들고 있음
- ZeRO/FSDP -> 책을 여러 사람에게 나눠서 들고 있음
1. DDP 메모리 구조
예: 4 GPU
GPU0 : parameters + gradients + optimizer
GPU1 : parameters + gradients + optimizer
GPU2 : parameters + gradients + optimizer
GPU3 : parameters + gradients + optimizer
모든 GPU가 동일한 데이터를 복제합니다.
문제
메모리 사용량이 GPU 수와 무관하게 동일합니다.
2. ZeRO 기본 아이디어
ZeRO는 모델 상태를 여러 GPU에 분산 저장합니다.
즉 redundancy 제거가 핵심입니다.
모델 상태:
- parameters
- gradients
- optimizer states
3. ZeRO Stage 1
optimizer state만 분산 저장합니다.
GPU0 : parameters + gradients + optimizer_part0
GPU1 : parameters + gradients + optimizer_part1
GPU2 : parameters + gradients + optimizer_part2
GPU3 : parameters + gradients + optimizer_part3
효과
optimizer 메모리 사용량 감소
주의점
parameters와 gradients는 여전히 복제됩니다.
4. ZeRO Stage 2
optimizer state와 gradient를 모두 분산 저장합니다.
GPU0 : parameters + gradient_part0 + optimizer_part0
GPU1 : parameters + gradient_part1 + optimizer_part1
GPU2 : parameters + gradient_part2 + optimizer_part2
GPU3 : parameters + gradient_part3 + optimizer_part3
효과
gradient 메모리 사용량 감소
주의점
forward/backward 시 gradient 통신이 필요합니다.
5. ZeRO Stage 3
parameters까지 분산 저장합니다.
GPU0 : parameter_part0 + gradient_part0 + optimizer_part0
GPU1 : parameter_part1 + gradient_part1 + optimizer_part1
GPU2 : parameter_part2 + gradient_part2 + optimizer_part2
GPU3 : parameter_part3 + gradient_part3 + optimizer_part3
효과
모델 파라미터 메모리도 분산됩니다.
주의점
forward 계산 전에 parameter gather가 필요합니다.
6. FSDP
FSDP는 PyTorch에서 제공하는 ZeRO Stage3와 유사한 방식입니다.
핵심 개념:
- parameter sharding
- forward 시 parameter gather
- backward 후 parameter shard 유지
forward 과정:
1. parameter all-gather
2. forward 계산
3. backward
4. gradient reduce-scatter
수식 예
W가 GPU마다 분할됩니다.
왜 필요한가
수백억~수천억 파라미터 모델 학습이 가능해집니다.
7. ZeRO Stage 비교
| Stage | 분산 대상 |
|---|---|
| Stage 1 | optimizer state |
| Stage 2 | optimizer + gradient |
| Stage 3 | optimizer + gradient + parameter |
8. 메모리 절약 효과
| 방법 | 메모리 절약 |
|---|---|
| DDP | 없음 |
| ZeRO Stage1 | optimizer 절약 |
| ZeRO Stage2 | gradient 절약 |
| ZeRO Stage3 / FSDP | parameter 절약 |
코드-수식 연결
| 개념 | PyTorch 코드 | 설명 |
|---|---|---|
| FSDP | torch.distributed.fsdp.FullyShardedDataParallel |
parameter sharding |
| ZeRO | DeepSpeed ZeRO |
optimizer/gradient 분산 |
| all-gather | torch.distributed.all_gather() |
parameter 수집 |
| reduce-scatter | torch.distributed.reduce_scatter() |
gradient 분산 |
자주 하는 오해 5개
- ZeRO는 단순한 optimizer라고 생각한다
- DDP보다 항상 빠르다고 생각한다
- ZeRO Stage3는 통신이 필요 없다고 생각한다
- FSDP는 Data Parallel과 동일하다고 생각한다
- 모든 모델에서 ZeRO가 필요하다고 생각한다
체크리스트 (스스로 설명 가능해야 하는 질문)
- DDP에서 메모리 복제가 발생하는 이유는 무엇인가?
- ZeRO Stage1,2,3의 차이는 무엇인가?
- parameter sharding이 왜 필요한가?
- FSDP에서 forward 전에 parameter gather가 필요한 이유는 무엇인가?
- ZeRO와 Tensor Parallel의 차이는 무엇인가?