주제: DT-2 Model Parallel / Tensor Parallel
문제 설정
Data Parallel은 데이터를 GPU마다 나누어 처리합니다. 하지만 모델 자체가 너무 커지면 하나의 GPU 메모리에 모델을 올릴 수 없습니다.
이 문제를 해결하기 위해 사용하는 방법이 Model Parallel입니다.
대표적인 방식:
- Pipeline Model Parallel
- Tensor Parallel
핵심 아이디어는 모델 파라미터 자체를 여러 GPU에 나누는 것입니다.
직관 비유
- Data Parallel -> 여러 사람이 같은 문제집을 풀기
- Model Parallel -> 문제를 여러 사람에게 나눠 풀기
1. Model Parallel 기본 개념
모델을 여러 GPU에 나누어 배치합니다.
Layer1 -> GPU0
Layer2 -> GPU1
Layer3 -> GPU2
forward 계산:
x -> GPU0 -> GPU1 -> GPU2
기호 의미
- x : 입력
- f1, f2 : 레이어 함수
수식:
왜 필요한가
모델이 GPU 메모리를 초과할 때 학습 가능하게 만듭니다.
주의점
레이어 간 데이터 전달이 필요하여 통신 비용이 발생합니다.
2. Tensor Parallel
Tensor Parallel은 하나의 레이어 연산을 여러 GPU로 분할하는 방법입니다.
대표적인 예는 Transformer의 linear layer입니다.
여기서 W를 GPU마다 나눕니다.
Column Parallel
가중치 행렬을 열 기준으로 분할합니다.
W = [W1 | W2]
각 GPU 계산:
결과:
왜 필요한가
큰 행렬곱을 여러 GPU로 나눠 계산할 수 있습니다.
주의점
출력 결합 과정에서 통신이 필요합니다.
3. Row Parallel
가중치 행렬을 행 기준으로 분할합니다.
W =
[W1
W2]
계산:
결과:
기호 의미
- W1, W2 : 파라미터 분할
왜 필요한가
메모리와 계산을 GPU마다 분산합니다.
4. Transformer Tensor Parallel 예
Transformer FFN:
Tensor parallel 적용:
GPU0 : W1_0
GPU1 : W1_1
연산:
GPU0 -> W1_0 x
GPU1 -> W1_1 x
결과를 concat합니다.
5. 통신 비용
Model Parallel에서는 GPU 간 데이터 전달이 필요합니다.
대표적인 통신 연산:
- All-Reduce
- All-Gather
- Reduce-Scatter
수식 예
또는
왜 중요한가
통신 비용이 병목이 될 수 있습니다.
주의점
GPU 간 네트워크 속도가 학습 성능에 큰 영향을 줍니다.
6. Data Parallel vs Tensor Parallel
| 방법 | 분할 대상 | 목적 |
|---|---|---|
| Data Parallel | 데이터 | 속도 향상 |
| Tensor Parallel | 파라미터 | 메모리 확장 |
코드-수식 연결
| 개념 | 구현 예 | 설명 |
|---|---|---|
| Tensor Parallel | Megatron-LM |
LLM tensor parallel 구현 |
| Model Parallel | torch.distributed |
분산 연산 |
| All-Reduce | torch.distributed.all_reduce() |
gradient 동기화 |
자주 하는 오해 5개
- Model Parallel은 Data Parallel의 단순 확장이라고 생각한다
- Tensor Parallel은 GPU 통신이 필요 없다고 생각한다
- 모델을 나누면 항상 속도가 빨라진다고 생각한다
- Tensor Parallel과 Pipeline Parallel을 같은 방식이라고 생각한다
- LLM은 Data Parallel만으로 학습 가능하다고 생각한다
체크리스트 (스스로 설명 가능해야 하는 질문)
- Model Parallel이 필요한 이유는 무엇인가?
- Tensor Parallel은 어떤 방식으로 행렬을 분할하는가?
- Column parallel과 Row parallel의 차이는 무엇인가?
- Tensor parallel에서 통신이 필요한 이유는 무엇인가?
- Data Parallel과 Tensor Parallel의 차이는 무엇인가?