(WIP) CheatSheet for Training Transformer (FLOPs, Time/Space Complexity)
03 Sep 2023< 목차 >
tmp
이번 post에서 정리하고 싶은 내용은 다음과 같다.
- Transformer의 초당 부동 소수점 연산 횟수 FLOPs (FLoating Point Operations per second)
- Time Complextiy
- Transformer의 Forward Pass 연산 시 Module별로 걸리는 소요 시간
- Transformer의 Backward Pass 연산 시 Module별로 걸리는 소요 시간
- Space Complexity
- Transformer의 Parameter를 Loading할 시 Precision에 따라 필요한 총 Memory
- Transformer의 Forward Pass 연산 시 module별로 필요한 Memory
- Transformer의 Backward Pass 연산 시 module별로 필요한 Memory
- Precision
- float32 (fp32): weight 하나 당 4 bytes
- float16 (fp16): weight 하나 당 2 bytes
- brain float16 (bf16): weight 하나 당 2 bytes
-
asd
- number of layers: \(n_{layer}\)
- dimension of the residual stream: \(d_{model}\)
- dimension of the intermediate feed-forward layer: \(d_{ff}\)
- dimension of the attention output: \(d_{attn}\)
-
number of attention heads per layer: \(n_{heads}\)
-
input context tokens: \(n_{ctx}\)
- model size (the number of non-embedding parameters): \(N \approx 2 d_{model} n_{layer} (2 d_{attn}) + d_{ff})\)
- also \(N \approx 12 n_{layer} d^2_{model}\) with the standard \(d_{attn} = d_{ff}/4 = d_{model}\)
- excluding biases and other sub-leading terms such as nonlinearities, layer norm and so on
Fig.