주제: DT-4 ZeRO / FSDP (Parameter, Gradient, Optimizer Sharding)

분류: foundations

타입: concept

난이도: 중상

선수지식: 있음 — DDP, GPU 메모리 구조, optimizer

문제 설정

대형 모델 학습에서 가장 큰 문제는 GPU 메모리 사용량입니다. 일반적인 DDP에서는 각 GPU가 다음을 모두 복제합니다.

즉 GPU마다 동일한 데이터가 저장됩니다.

Memoryparameters+gradients+optimizerstates

대형 모델에서는 이 메모리 사용량이 매우 커집니다.

이를 해결하기 위해 등장한 방법이 ZeRO (Zero Redundancy Optimizer)FSDP (Fully Sharded Data Parallel)입니다.

직관 비유

1. DDP 메모리 구조

예: 4 GPU

GPU0 : parameters + gradients + optimizer
GPU1 : parameters + gradients + optimizer
GPU2 : parameters + gradients + optimizer
GPU3 : parameters + gradients + optimizer

모든 GPU가 동일한 데이터를 복제합니다.

문제

메모리 사용량이 GPU 수와 무관하게 동일합니다.

2. ZeRO 기본 아이디어

ZeRO는 모델 상태를 여러 GPU에 분산 저장합니다.

redundancy 제거가 핵심입니다.

모델 상태:

3. ZeRO Stage 1

optimizer state만 분산 저장합니다.

GPU0 : parameters + gradients + optimizer_part0
GPU1 : parameters + gradients + optimizer_part1
GPU2 : parameters + gradients + optimizer_part2
GPU3 : parameters + gradients + optimizer_part3

효과

optimizer 메모리 사용량 감소

주의점

parameters와 gradients는 여전히 복제됩니다.

4. ZeRO Stage 2

optimizer state와 gradient를 모두 분산 저장합니다.

GPU0 : parameters + gradient_part0 + optimizer_part0
GPU1 : parameters + gradient_part1 + optimizer_part1
GPU2 : parameters + gradient_part2 + optimizer_part2
GPU3 : parameters + gradient_part3 + optimizer_part3

효과

gradient 메모리 사용량 감소

주의점

forward/backward 시 gradient 통신이 필요합니다.

5. ZeRO Stage 3

parameters까지 분산 저장합니다.

GPU0 : parameter_part0 + gradient_part0 + optimizer_part0
GPU1 : parameter_part1 + gradient_part1 + optimizer_part1
GPU2 : parameter_part2 + gradient_part2 + optimizer_part2
GPU3 : parameter_part3 + gradient_part3 + optimizer_part3

효과

모델 파라미터 메모리도 분산됩니다.

주의점

forward 계산 전에 parameter gather가 필요합니다.

6. FSDP

FSDP는 PyTorch에서 제공하는 ZeRO Stage3와 유사한 방식입니다.

핵심 개념:

forward 과정:

1. parameter all-gather
2. forward 계산
3. backward
4. gradient reduce-scatter

수식 예

y=Wx

W가 GPU마다 분할됩니다.

W=concat(W1,W2,...,Wn)

왜 필요한가

수백억~수천억 파라미터 모델 학습이 가능해집니다.

7. ZeRO Stage 비교

Stage 분산 대상
Stage 1 optimizer state
Stage 2 optimizer + gradient
Stage 3 optimizer + gradient + parameter

8. 메모리 절약 효과

방법 메모리 절약
DDP 없음
ZeRO Stage1 optimizer 절약
ZeRO Stage2 gradient 절약
ZeRO Stage3 / FSDP parameter 절약

코드-수식 연결

개념 PyTorch 코드 설명
FSDP torch.distributed.fsdp.FullyShardedDataParallel parameter sharding
ZeRO DeepSpeed ZeRO optimizer/gradient 분산
all-gather torch.distributed.all_gather() parameter 수집
reduce-scatter torch.distributed.reduce_scatter() gradient 분산

자주 하는 오해 5개

체크리스트 (스스로 설명 가능해야 하는 질문)