[기본기 0-5] 딥러닝을 위한 수치 안정성 기초
딥러닝은 매우 큰 수와 매우 작은 수를 동시에 계산합니다. 부동소수점 한계 때문에 overflow/underflow가 생기며, 특히 exp가 들어가는 softmax 계열에서 자주 문제가 발생합니다.
입문 단계에서는 모델 구조에 눈이 먼저 가지만, 실제 훈련에서는 숫자가 표현 범위를 넘는 순간 loss가 갑자기 nan이 되거나 gradient가 비정상적으로 바뀔 수 있다. 그래서 softmax, log, exp가 들어가는 식은 항상 "수학적으로 맞는가"와 함께 "숫자로 계산해도 안전한가"를 봐야 한다.
실무에서 프레임워크가 softmax와 log를 따로 계산하지 않고 묶어서 제공하는 이유도 여기 있다. 같은 수식을 더 안정한 형태로 재작성하면 overflow와 underflow를 크게 줄일 수 있기 때문이다.
시각 자료로 먼저 보기
왜 수치 안정성이 중요한가
- overflow: 값이 너무 커져 표현 범위를 넘음
- underflow: 값이 너무 작아 0으로 소실
수치 안정성이 깨지면 loss가 nan으로 터지거나 gradient가 비정상적으로 됩니다.
즉 수치 안정성은 성능 튜닝이 아니라 학습이 돌아가느냐의 문제다. 모델이 아무리 좋아도 계산이 불안정하면 업데이트 방향 자체가 망가져서 수렴하지 못한다. 그래서 안정한 계산 형태를 아는 것은 구현 실수를 줄이는 가장 직접적인 방법 중 하나다.
1) Overflow와 Underflow
exp(10) ≈ 22026
exp(100) ≈ 3.7e43
exp(1000) -> overflow
exp(-1000)-> 0 (underflow)
logits가 큰 분류 문제에서는 이 현상이 현실적으로 자주 발생합니다.
2) Softmax
x_i: 클래스 i의 logit- 출력: 0~1 범위, 전체 합 1
문제: x_i가 크면 e^{x_i} overflow 가능
softmax가 자주 예시로 등장하는 이유는, 분류 문제에서 거의 항상 쓰이는데도 지수함수 때문에 매우 불안정해지기 쉽기 때문이다. 그래서 실제 구현에서는 보통 가장 큰 logit을 먼저 빼서 모든 값을 상대적인 크기로 바꾼 뒤 계산한다.
3) 안정한 Softmax
예: x=[1000,1001,1002]이면 x-\max(x)=[-2,-1,0]으로 변환되어 안정적으로 계산됩니다.
중요: 분자/분모에 같은 상수를 적용하므로 최종 softmax 값은 동일합니다.
4) Log-Sum-Exp Trick
\log\sum_i e^{x_i}는 직접 계산 시 overflow 위험이 큽니다.
안정한 형태:
\log\sum_i e^{x_i}=m+\log\sum_i e^{x_i-m},\; m=\max(x)
핵심 아이디어: 큰 지수 계산을 안전한 구간으로 이동
직관 예시
x = [1000, 1001, 1002]
# 직접 exp는 overflow
# max를 빼면 [-2, -1, 0] -> exp(-2), exp(-1), exp(0) 로 안정
코드-수식 연결
| 수식 | PyTorch 코드 | 설명 |
|---|---|---|
softmax(x) | torch.softmax(x, dim=-1) | 확률 분포 계산 |
\log softmax(x) | torch.log_softmax(x, dim=-1) | 안정한 로그 확률 |
\log\sum e^x | torch.logsumexp(x, dim=-1) | log-sum-exp 안정 계산 |
x-\max(x) | x - x.max(dim=-1, keepdim=True).values | 수치 안정화 전처리 |
자주 하는 오해 5개
- softmax는 항상 안정적이라고 생각한다
- overflow는 매우 드문 현상이라고 생각한다
max(x)를 빼면 결과가 바뀐다고 생각한다- log-sum-exp를 단순 수학 장난으로 본다
log_softmax와softmax+log가 항상 동일한 수치 결과를 낸다고 생각한다
체크리스트
- overflow/underflow가 왜 생기는지 설명할 수 있는가?
- softmax에서 왜
max(x)를 빼는가? - log-sum-exp trick이 해결하는 문제를 말할 수 있는가?
- 왜
log_softmax가 더 안정적인가? - 수치 불안정이 학습에 어떤 증상(
nan, 발산)을 만드는지 설명할 수 있는가?