What and Why Training Dynamics?


< 목차 >


Deep Learning (DL)분야에서의 Dynamics란 무엇인가?

요즘 Large Language Model (LLM)같은 거대한 Neural Networks (NNs)을 training 하는 것이 key role이 되었고, 나도 요즘 수십 biliion scale model을 pre-training하는 일을 하다보니 training dynamics라는 얘기를 많이 듣는다. Large NN은 조금만 수가 틀려도 모델이 발산할 수 있다. 라고 paper들에선 겁을 주지만 실제로는 optimization에서 어떤 일이 일어나는지 알 수 없기도 하고, 사실 open source framework들이 이미 tranformer같은 architecture design를 잘 만들어놨고 initialization 등의 안전장치를 해놔서 그렇게 잘 터지진 않긴 한다.

그럼에도 불구하고 Greg Yang의 Tensor Programs series같은 것들을 읽어보면 NN의 width, depth가 커질 때 (크기가 커질 때) 어떤식으로 learning rate, initialization 등을 설정해줘야 하는지, 즉 parameterization을 해줘야 하는지에 따라 결과가 크게 바뀌고 예측이 언정도 가능할지에 대해 생각해보게 되며, 이런 것들을 정립하는 것이 꽤 중요한 일임을 이해할 수 있다.

한 번 돌리면 두 세달, 길게는 반년 일년씩 걸리는 LLM training을 안전하게 끝마치기 위해 (training stability), 그리고 성능을 예측하기 위해 (scaling law prediction) training dynamics 라는 키워드로 paper가 계속 나오고 있는데, 막상 DL dynamics의 정의에 대해 찾아보려고 해도 wiki같은게 딱히 없기 때문에 대충 내 생각을 적어보려고 한다.

DL Theory라는 분야가 있고 그 안에 training dynamics가 있는 것으로 보이는데, 이는 보통 optimization dynamics를 의미한다. 우리가 교육과정에서 Physics 같은데서 말 그대로의 역학 (dynamics)을 배우게 되는데, 이 관점에서 생각해보도록 하자. Wiki를 찾아보면 역학(力學, 문화어: 력학, 영어: Mechanics)은 Physics의 한 분야로, 외력을 받고 있는 물질의 정지 또는 운동 상태를 설명하고 예측하는 자연 과학이라고 되어있다. 예를 들어 중력이 작용해 뉴턴의 머리에 사과가 떨어짐으로써 뉴턴이 \(F=ma\)같은 제2 법칙등을 정의함으로써 고전 역학을 정립했다고 할 수 있는데, 위치 에너지 (potential energy) 같은 개념에 따르면 \(E_p = mgh\)같이 에너지에 따라서 기준 면에서부터의 물체의 위치를 예측할 수 있다.

이 개념을 대입하면 DL분야에서의 training dyanmics란 SGD, Adam optimizer등을 통해 NN을 training하면 특정 optimization timestep, \(t\)에서 model이 어떻게 되는가?를 예측하는 것이 된다. 즉 training trajectory를 예측하는 것을 말한다.

왜 training dyanmics라는 분야에 대해 gpt-o1에 물어본 결과는 다음과 같다. (chatgpt의 결과를 맹신해서는 안되지만 요즘 모델이 뱉는 답변들을 보면 깜짝 놀랄정도로 퀄리티가 좋은 것 같다)

  • Why Training Dynamics?
    • training 과정의 이해: training dynamics를 연구하면 NN의 parameter와 activation 값이 시간에 따라 어떻게 변화하는지 파악할 수 있다. 이를 통해 모델이 어떻게 패턴을 training하고 일반화하는지 이해할 수 있다.
      • IMO, 이를 정확히 예측하는 것이 아니더라도, activation function의 형태, layernorm 등을 추가하는 것에 따라 activation, gradient가 어떻게 변하는지에 대해 분석하고 이를 개선하는 것 등도 포함
    • 일반화 성능 향상: training 동역학을 분석하면 모델이 overfitting이나 underfitting에 빠지는 원인을 찾을 수 있다. 이를 기반으로 일반화 능력을 향상시키는 규제 기법이나 optimization 방법을 개발할 수 있다.
    • optimization 과정 개선: training dynamics는 optimization algorithm의 성능과 한계를 이해하는 데 도움이 됩니다. 예를 들어, learning rate 조정이나 모멘텀과 같은 hyperparameter가 training에 미치는 영향을 분석할 수 있다.
    • 이론적 기반 강화: DL은 실험적으로 성공을 거두었지만, 그 이면의 이론은 아직 완전히 이해되지 않은 부분이 많습니다. training dynamics를 연구함으로써 DL의 수학적이고 이론적인 이해를 심화시킬 수 있다.
    • 안정성과 해석 가능성: training 과정에서 발생하는 불안정성이나 예측 불가능한 행동을 파악하여 모델의 안정성을 높일 수 있다. 또한, 모델의 내부 작동 방식을 이해함으로써 해석 가능성을 향상시킬 수 있다.
  • Dynamics in Physics vs DL
    • 시스템의 상태 변화 분석:
      • Physics의 역학: 물체나 입자의 위치, 속도, 가속도 등의 물리적 상태가 시간에 따라 어떻게 변화하는지를 연구한다. 예를 들어, 뉴턴의 운동 법칙을 통해 힘과 가속도의 관계를 분석한다.
      • DL의 트레이닝 다이내믹스: NN의 parameter와 bias 등이 training 과정에서 어떻게 변화하는지를 분석한다. 이를 통해 모델이 데이터를 training하면서 내부 parameter가 어떻게 조정되는지 이해한다.
    • 수학적 모델링과 방정식:
      • Physics: 미분 방정식과 수학적 모델을 사용하여 시스템의 미래 상태를 예측한다. 예를 들어, 해밀토니안 역학이나 라그랑지안 역학을 통해 에너지와 운동의 관계를 해석한다.
      • DL: optimization algorithm(예: gradient descent)을 미분 방정식으로 표현하여 training 과정을 수학적으로 모델링한다. loss function의 기울기를 따라 parameter를 업데이트하는 과정이 이에 해당한다.
    • 에너지 개념의 활용:
      • Physics: 잠재 에너지와 운동 에너지 등 에너지 개념을 사용하여 시스템의 안정성과 평형 상태를 분석한다.
      • DL: loss function는 시스템의 ‘에너지’로 간주될 수 있으며, training의 목표는 이 에너지를 최소화하는 것입니다.
    • 안정성과 평형 상태:
      • Physics: 시스템이 평형 상태에 도달하거나 불안정한 상태로부터 어떻게 변화하는지 연구한다.
      • DL: training 과정에서 loss function이 최소점에 수렴하는지, 또는 지역 최소점이나 안장점에 머무르는지 분석한다.
    • 노이즈와 확률적 요소:
      • Physics: 통계 역학에서는 입자의 무작위 운동과 열역학적 특성을 확률적으로 분석한다.
      • DL: stochastic gradient descent (SGD) 등에서 배치 노이즈와 같은 확률적 요소가 training에 미치는 영향을 연구한다.
    • 복잡계와 비선형성:
      • Physics: 복잡한 상호 작용을 가진 시스템에서 카오스나 패턴 형성을 연구한다.
      • DL: 심층 NN의 nonlinear activation function를 통해 복잡한 데이터 패턴을 training한다.
  • Main Topics of Training Dynamics
    • training curve 분석: loss function이나 정확도가 training 단계에 따라 어떻게 변화하는지 관찰한다.
    • optimization 경로: parameter 공간에서 모델이 어떤 경로를 따라 이동하는지 분석한다.
    • hyperparameter 영향: training률, 모멘텀, 배치 크기 등의 hyperparameter가 training 동역학에 미치는 영향을 연구한다.
    • 평균장 이론 (Mean-field Theory; MFT): 큰 규모의 NN에서 parameter의 통계적 거동을 분석한다.