Why Data2Vec 2.0 Works Well?


< 목차 >


Motivation and Contribution

Data2vec (d2v) 시리즈는 Meta AI Research (FAIR) 의 논문으로 2022년 ICML 에 1.0 version 이 나왔고 같은해 12월에 2.0 version이 나왔습니다.

d2v (1.0) 의 motivation 은 Speech, Text, Vision 이라는 각기 다른 domain 에서 모두 working 하는 general 한 Self-Supervised Learning (SSL) 을 만들겠다는 것이었는데요, Text 쪽에서는 BERT, GPT 방식으로 학습하는게 일반적이고 음성에서는 이를 조금 변형해 Wav2Vec (w2v) 이라는 방식으로 학습하는 등 서로 다른 방식으로 학습하는 것이 종래에 이런 서로다른 mode 를 합친 multimodal 상황에 잘 맞지 않을 것이라는 이유 때문입니다.

주의할 점은 이 논문 자체에서는 Multimodal 학습을 한 실험결과는 없다는 것입니다. (마치 그 실험을 했을것 처럼 얘기했으나… 그래서 제대로 읽어보지 않은 분들 중에서는 이 논문에서 그럼 3개 modality 를 드디어 같이 학습한 것이 아니냐고 얘기하는 분들도 많았습니다.)

이 post의 최종 goal 인 d2v 2.0 (d2v2) 를 이해하기 위해서 알아얄 몇가지 key module 들이 있는데 이는 다음과 같습니다.

  • BYOL style Self Supervised Learning (SSL) Approach
    • Exponential Moving Everage (EMA) Teacher (Mean Teacher)
  • Masked Auto Encoder (MAE) Style Approach

Data2Vec

d2v_fig1 Fig.

BYOL Style Self-Supervised Learning (SSL)

byol_fig1 Fig.

byol_fig2 Fig.

Masking

Training Targets

Teacher Parameterization

\[\Delta \leftarrow \tau \Delta + (1-\tau) \theta\]

Network Architecture

d2v 1.0 버전의 네트워크 아키텍쳐는 아래와 같이 생겼습니다.

  • domain specific feature extractor
    • speech : 1d-cnn layers (downsampler)
      • num.layers : 7
      • 512 channels
      • strides (5,2,2,2,2,2,2)
      • kernel widths (10,3,3,3,3,2,2)
  • contextualized encdoer
    • for all : transformer encoder layers
      • positional encoding
        • speech : convolutional positional encodig
          • num.layers : 5
      • num.layers : 12 (base), 24 (large)
      • hiddem.dim : 768 (base), 1024 (large)
num. shared model params: 314,326,016 (num. trained: 314,326,016)
Data2VecAudioModel(
  (feature_extractor): ConvFeatureExtractionModel(
    (conv_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeChannel()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): TransposeChannel()
        )
        (3): GELU(approximate='none')
      )
      (1-4): 4 x Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeChannel()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): TransposeChannel()
        )
        (3): GELU(approximate='none')
      )
      (5-6): 2 x Sequential(
        (0): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeChannel()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): TransposeChannel()
        )
        (3): GELU(approximate='none')
      )
    )
  )
  (post_extract_proj): Linear(in_features=512, out_features=1024, bias=True)
  (dropout_input): Dropout(p=0.0, inplace=False)
  (dropout_features): Dropout(p=0.0, inplace=False)
  (encoder): TransformerEncoder(
    (pos_conv): Sequential(
      (0): Sequential(
        (0): Conv1d(1024, 1024, kernel_size=(19,), stride=(1,), padding=(9,), groups=16)
        (1): SamePad()
        (2): TransposeLast()
        (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
        (4): TransposeLast()
        (5): GELU(approximate='none')
      )
      (1): Sequential(
        (0): Conv1d(1024, 1024, kernel_size=(19,), stride=(1,), padding=(9,), groups=16)
        (1): SamePad()
        (2): TransposeLast()
        (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
        (4): TransposeLast()
        (5): GELU(approximate='none')
      )
      (2): Sequential(
        (0): Conv1d(1024, 1024, kernel_size=(19,), stride=(1,), padding=(9,), groups=16)
        (1): SamePad()
        (2): TransposeLast()
        (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
        (4): TransposeLast()
        (5): GELU(approximate='none')
      )
      (3): Sequential(
        (0): Conv1d(1024, 1024, kernel_size=(19,), stride=(1,), padding=(9,), groups=16)
        (1): SamePad()
        (2): TransposeLast()
        (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
        (4): TransposeLast()
        (5): GELU(approximate='none')
      )
      (4): Sequential(
        (0): Conv1d(1024, 1024, kernel_size=(19,), stride=(1,), padding=(9,), groups=16)
        (1): SamePad()
        (2): TransposeLast()
        (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
        (4): TransposeLast()
        (5): GELU(approximate='none')
      )
    )
    (layers): ModuleList(
      (0-23): 24 x TransformerSentenceEncoderLayer(
        (self_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
        (dropout3): Dropout(p=0.1, inplace=False)
        (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (final_proj): Linear(in_features=1024, out_features=1024, bias=True)
)
  • teacher and student architectures are exactly same

Representation Collapse

Data2Vec 2.0

d2v2_fig1 Fig.

mae_fig1 Fig.

cs285_lec17_ae1) Fig.

cs285_lec17_dae1) Fig.

mae_fig2 Fig.

Target Representations and Learning Objective

Network Architecture (Asymmetric Encoder/Decoder Architecture)

d2v2 의 전체 모델 구조는 아래와 같습니다.

  • domain specific local_encoder
  • AutoEncoder like modules
    • encoder
      • feature_extractor
        • transformer encoder layers
          • positional encoding
            • speech : convolutional positional encodig
              • num.layers : 5
          • num.layers : 16 (large)
          • hidden.dim : 1024 (large)
          • add alibi bias for relative positional encoding
      • contextualized_encoder
        • transformer encoder layers
          • num.layers : 8 (large)
          • hidden.dim : 1024 (large)
          • add alibi bias for relative positional encoding
    • decoder
      • speech : 1-d conv decoder
        • num.layers : 4

우선 domain 마다 local_encdoer 가 있어서 raw data 에서 feature 를 뽑는 모듈이 하나 있습니다. 저의 경우 speech domain 을 주로 보기 때문에 이 기준으로 설명드리면 1d signal 를 입력으로 사용해 7 개의 서로 다른 kernel, stride size 를 갖는 1d conv layer 가 feature 를 추출합니다. 이 때 speech domain 의 local feature 들 간의 correlation 을 고려한 domain knowledge가 사용되었다고 생각할 수 있고, vision, nlp 에서도 각 domain 에 맞게 디자인된 모듈이 사용됩니다.

이제 MAE 구조를 따르는 d2v2 이기 때문에 크게 encoder / decoder 두개 구조가 있어야 되는데요, encoder 쪽을 보시면 transformer encoder 24층을 2가지로 구분짓습니다.

먼저 local encoder output 을 받아서 일부분을 masking 한 다음에 transformer block 16개를 통과시키는데요, 이를 feature_extracotr 라고 합니다. 이 떄는 speech domain 의 경우 w2v2 에서 처럼 relative positional encoding 정보를 주입해주는데요, local encoder 에서 처럼 conv layer 를 여러번 통과시키는 것이 이 역할을 하게 됩니다.

그리고 마지막으로 8층 정도 되는 context_encoder가 있어 이를 통과하면 모든 encoder를 통과하게 됩니다.

이 때 모든 transformer block 들에는 또 추가적으로 Alibi bias 를 QK attention map 에 더해주는데요, 이미 relative positional encoding 가 들어간 시점에서 꽤 redundant 한 것으로 보이나 실험 결과 성능에 중요한 영향을 끼쳤다고 합니다. Alibi 는 초기화 된 아래처럼 자기 자신과 가까운 부분에만 추가 점수를 주는, 즉 local feature 를 더 보라는 bias 를 주는데, 이를 학습하지 않고 head 별로 (head 가 large 는 16개 있음) learnable scalar 를 둬서 이를 학습하게 했습니다.

아래는 downstream finetuning 을 했을 때 head 별 alibi scalar 값 입니다.

Parameter containing:
tensor([[[[[ 5.9082e-01]],
          [[ 1.2866e-01]],
          [[ 3.2153e-01]],
          [[ 1.4374e-02]],
          [[ 3.9917e-01]],
          [[ 5.4346e-01]],
          [[ 4.3457e-01]],
          [[ 3.7598e-01]],
          [[ 2.4402e-01]],
          [[ 2.0401e-02]],
          [[ 2.9572e-02]],
          [[ 1.3199e-02]],
          [[ 2.1698e-02]],
          [[-2.1100e-05]],
          [[ 3.7346e-03]],
          [[ 1.3786e-02]]]]], device='cuda:5', dtype=torch.float16,
       requires_grad=True)

즉 어떤 head 는 글로벌 하게 보라는 의미죠.

d2v2_alibi Fig. head 를 Average 했고 visualize 했으므로 아무런 차이가 없어보인다.

여기서 아니 transformer encoder 를 굳이 16, 8층으로 나눠서 선언한 이유가 뭐지? 라는 생각이 드실 수 있습니다. 논문에서는 굳이 전체 24층의 encoder 를 2개로 구분지어 서로 다른 역할을 하게 한다는 언급자체는 없는데요, 코드를 살펴본 결과 사실 둘의 차이는 없다고 할 수 있겠습니다. 제 생각에는 d2v 시리즈가 여러 modality 의 입력을 joint 하게 SSL 하기 위해서 이를 나눠둔 것 같습니다.
그러니까 사실상 24층짜리 똑같은 alibi bias 가 들어간 transformer encoder block 인데 나중에 modality fusion 을 하는 부분이 16층 부근이어서 그랬던 것이죠.

이제 마지막으로 이를 다시 reconstruct masking 되기 전의 feature vector 들로 복원해줄 decoder 가 필요한데요, 이는 4개의 conv layer 로 구성되어 있습니다.

num. shared model params: 315,184,144 (num. trained: 315,184,144)
Data2VecMultiModel(
  (modality_encoders): ModuleDict(
    (AUDIO): AudioEncoder(
      (local_encoder): ConvFeatureExtractionModel(
        (conv_layers): ModuleList(
          (0): Sequential(
            (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): Sequential(
              (0): TransposeChannel()
              (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (2): TransposeChannel()
            )
            (3): GELU(approximate='none')
          )
          (1-4): 4 x Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): Sequential(
              (0): TransposeChannel()
              (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (2): TransposeChannel()
            )
            (3): GELU(approximate='none')
          )
          (5-6): 2 x Sequential(
            (0): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): Sequential(
              (0): TransposeChannel()
              (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (2): TransposeChannel()
            )
            (3): GELU(approximate='none')
          )
        )
      )
      (project_features): Sequential(
        (0): TransposeLast()
        (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (2): Linear(in_features=512, out_features=1024, bias=True)
      )
      (relative_positional_encoder): Sequential(
        (0): TransposeLast()
        (1): Sequential(
          (0): Conv1d(1024, 1024, kernel_size=(19,), stride=(1,), padding=(9,), groups=16)
          (1): SamePad()
          (2): TransposeLast()
          (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
          (4): TransposeLast()
          (5): GELU(approximate='none')
        )
        (2): Sequential(
          (0): Conv1d(1024, 1024, kernel_size=(19,), stride=(1,), padding=(9,), groups=16)
          (1): SamePad()
          (2): TransposeLast()
          (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
          (4): TransposeLast()
          (5): GELU(approximate='none')
        )
        (3): Sequential(
          (0): Conv1d(1024, 1024, kernel_size=(19,), stride=(1,), padding=(9,), groups=16)
          (1): SamePad()
          (2): TransposeLast()
          (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
          (4): TransposeLast()
          (5): GELU(approximate='none')
        )
        (4): Sequential(
          (0): Conv1d(1024, 1024, kernel_size=(19,), stride=(1,), padding=(9,), groups=16)
          (1): SamePad()
          (2): TransposeLast()
          (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
          (4): TransposeLast()
          (5): GELU(approximate='none')
        )
        (5): Sequential(
          (0): Conv1d(1024, 1024, kernel_size=(19,), stride=(1,), padding=(9,), groups=16)
          (1): SamePad()
          (2): TransposeLast()
          (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
          (4): TransposeLast()
          (5): GELU(approximate='none')
        )
        (6): TransposeLast()
      )
      (context_encoder): BlockEncoder(
        (blocks): ModuleList(
          (0-7): 8 x AltBlock(
            (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (attn): AltAttention(
              (qkv): Linear(in_features=1024, out_features=3072, bias=True)
              (attn_drop): Dropout(p=0.1, inplace=False)
              (proj): Linear(in_features=1024, out_features=1024, bias=True)
              (proj_drop): Dropout(p=0.1, inplace=False)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=1024, out_features=4096, bias=True)
              (act): GELU(approximate='none')
              (drop1): Dropout(p=0.0, inplace=False)
              (norm): Identity()
              (fc2): Linear(in_features=4096, out_features=1024, bias=True)
              (drop2): Dropout(p=0.0, inplace=False)
            )
            (post_mlp_dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=True)
      )
      (decoder): Decoder1d(
        (blocks): Sequential(
          (0): Sequential(
            (0): Conv1d(1024, 768, kernel_size=(7,), stride=(1,), padding=(3,), groups=16)
            (1): SamePad()
            (2): TransposeLast()
            (3): LayerNorm((768,), eps=1e-05, elementwise_affine=False)
            (4): TransposeLast()
            (5): GELU(approximate='none')
          )
          (1): Sequential(
            (0): Conv1d(768, 768, kernel_size=(7,), stride=(1,), padding=(3,), groups=16)
            (1): SamePad()
            (2): TransposeLast()
            (3): LayerNorm((768,), eps=1e-05, elementwise_affine=False)
            (4): TransposeLast()
            (5): GELU(approximate='none')
          )
          (2): Sequential(
            (0): Conv1d(768, 768, kernel_size=(7,), stride=(1,), padding=(3,), groups=16)
            (1): SamePad()
            (2): TransposeLast()
            (3): LayerNorm((768,), eps=1e-05, elementwise_affine=False)
            (4): TransposeLast()
            (5): GELU(approximate='none')
          )
          (3): Sequential(
            (0): Conv1d(768, 768, kernel_size=(7,), stride=(1,), padding=(3,), groups=16)
            (1): SamePad()
            (2): TransposeLast()
            (3): LayerNorm((768,), eps=1e-05, elementwise_affine=False)
            (4): TransposeLast()
            (5): GELU(approximate='none')
          )
        )
        (proj): Linear(in_features=768, out_features=1024, bias=True)
      )
    )
  )
  (dropout_input): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-15): 16 x AltBlock(
      (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (attn): AltAttention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.1, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (post_mlp_dropout): Dropout(p=0.1, inplace=False)
    )
  )
)

Implementation

References