(WIP) CheatSheet for Training Transformer (FLOPs, Time/Space Complexity)


< 목차 >


tmp

이번 post에서 정리하고 싶은 내용은 다음과 같다.

  • Transformer의 초당 부동 소수점 연산 횟수 FLOPs (FLoating Point Operations per second)
  • Time Complextiy
    • Transformer의 Forward Pass 연산 시 Module별로 걸리는 소요 시간
    • Transformer의 Backward Pass 연산 시 Module별로 걸리는 소요 시간
  • Space Complexity
    • Transformer의 Parameter를 Loading할 시 Precision에 따라 필요한 총 Memory
    • Transformer의 Forward Pass 연산 시 module별로 필요한 Memory
    • Transformer의 Backward Pass 연산 시 module별로 필요한 Memory
  • Precision
    • float32 (fp32): weight 하나 당 4 bytes
    • float16 (fp16): weight 하나 당 2 bytes
    • brain float16 (bf16): weight 하나 당 2 bytes
  • asd

  • number of layers: \(n_{layer}\)
  • dimension of the residual stream: \(d_{model}\)
  • dimension of the intermediate feed-forward layer: \(d_{ff}\)
  • dimension of the attention output: \(d_{attn}\)
  • number of attention heads per layer: \(n_{heads}\)

  • input context tokens: \(n_{ctx}\)

  • model size (the number of non-embedding parameters): \(N \approx 2 d_{model} n_{layer} (2 d_{attn}) + d_{ff})\)
    • also \(N \approx 12 n_{layer} d^2_{model}\) with the standard \(d_{attn} = d_{ff}/4 = d_{model}\)
    • excluding biases and other sub-leading terms such as nonlinearities, layer norm and so on

scaling_laws_table1 Fig.

tmp

Reference