주제: DT-2 Model Parallel / Tensor Parallel

분류: foundations

타입: concept

난이도: 중상

선수지식: 있음 — DDP, 행렬곱, Transformer 구조

문제 설정

Data Parallel은 데이터를 GPU마다 나누어 처리합니다. 하지만 모델 자체가 너무 커지면 하나의 GPU 메모리에 모델을 올릴 수 없습니다.

이 문제를 해결하기 위해 사용하는 방법이 Model Parallel입니다.

대표적인 방식:

핵심 아이디어는 모델 파라미터 자체를 여러 GPU에 나누는 것입니다.

직관 비유

1. Model Parallel 기본 개념

모델을 여러 GPU에 나누어 배치합니다.

Layer1 -> GPU0
Layer2 -> GPU1
Layer3 -> GPU2

forward 계산:

x -> GPU0 -> GPU1 -> GPU2

기호 의미

수식:

y=f3(f2(f1(x)))

왜 필요한가

모델이 GPU 메모리를 초과할 때 학습 가능하게 만듭니다.

주의점

레이어 간 데이터 전달이 필요하여 통신 비용이 발생합니다.

2. Tensor Parallel

Tensor Parallel은 하나의 레이어 연산을 여러 GPU로 분할하는 방법입니다.

대표적인 예는 Transformer의 linear layer입니다.

y=Wx

여기서 W를 GPU마다 나눕니다.

Column Parallel

가중치 행렬을 열 기준으로 분할합니다.

W = [W1 | W2]

각 GPU 계산:

y1=W1x
y2=W2x

결과:

y=concat(y1,y2)

왜 필요한가

큰 행렬곱을 여러 GPU로 나눠 계산할 수 있습니다.

주의점

출력 결합 과정에서 통신이 필요합니다.

3. Row Parallel

가중치 행렬을 행 기준으로 분할합니다.

W =
[W1
 W2]

계산:

y1=W1x
y2=W2x

결과:

y=y1+y2

기호 의미

왜 필요한가

메모리와 계산을 GPU마다 분산합니다.

4. Transformer Tensor Parallel 예

Transformer FFN:

FFN(x)=W2sigma(W1x)

Tensor parallel 적용:

GPU0 : W1_0
GPU1 : W1_1

연산:

GPU0 -> W1_0 x
GPU1 -> W1_1 x

결과를 concat합니다.

5. 통신 비용

Model Parallel에서는 GPU 간 데이터 전달이 필요합니다.

대표적인 통신 연산:

수식 예

y=concat(y1,y2)

또는

y=Σyi

왜 중요한가

통신 비용이 병목이 될 수 있습니다.

주의점

GPU 간 네트워크 속도가 학습 성능에 큰 영향을 줍니다.

6. Data Parallel vs Tensor Parallel

방법 분할 대상 목적
Data Parallel 데이터 속도 향상
Tensor Parallel 파라미터 메모리 확장

코드-수식 연결

개념 구현 예 설명
Tensor Parallel Megatron-LM LLM tensor parallel 구현
Model Parallel torch.distributed 분산 연산
All-Reduce torch.distributed.all_reduce() gradient 동기화

자주 하는 오해 5개

체크리스트 (스스로 설명 가능해야 하는 질문)