(yet) GPU Programming (6/6) Triton Impl of Ring Attention
18 May 2024
from enum import Enum
from einops import rearrange
from triton_attn import flash_attn_func, ring_attn_func, striped_attn_func, step_attn_func
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
class AttentionType(Enum):
    SDPA = 0
    FLASH = 1
    RING = 2
    STRIPED = 3
    STEP = 4
class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_per_head: int,
        seq_len: int,
        rank: int,
        n_ranks: int,
        dropout: float,
        causal: bool,
        attn_type: AttentionType = AttentionType.FLASH,
    ):
        super().__init__()
        assert isinstance(attn_type, AttentionType), attn_type
        assert dim % dim_per_head == 0, (dim, dim_per_head)
        assert seq_len % n_ranks == 0, (seq_len, n_ranks)
        if attn_type in (AttentionType.STRIPED, AttentionType.STEP):
            # Striped and step layouts are load-balancing schemes for the causal mask.
            assert causal, causal
        self.dim = dim
        self.dim_per_head = dim_per_head
        self.n_heads = dim // dim_per_head
        self.seq_len = seq_len
        self.seq_len_per_rank = seq_len // n_ranks
        self.dropout = dropout
        self.causal = causal
        self.attn_type = attn_type
        self.rank = rank
        self.n_ranks = n_ranks
        # Neighboring ranks on the ring, used to rotate KV blocks.
        self.prev_rank = (self.rank - 1 + n_ranks) % n_ranks
        self.next_rank = (self.rank + 1) % n_ranks
        self.qkv_proj = nn.Linear(dim, 3 * dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        self.o_drop = nn.Dropout(p=dropout)
    def shard(self, x: torch.Tensor) -> torch.Tensor:
        # Slice the full-length input down to this rank's local sequence shard.
        if self.attn_type in (AttentionType.SDPA, AttentionType.FLASH):
            return x
        batch_size, seq_len, dim = x.shape
        if self.attn_type == AttentionType.RING:
            # Contiguous block sharding: rank r keeps the r-th block of seq_len // n_ranks tokens.
            x_shard = torch.split(x, seq_len // self.n_ranks, dim=1)[self.rank].contiguous()
        elif self.attn_type == AttentionType.STRIPED:
            # Striped sharding: rank r keeps every n_ranks-th token, starting at offset r.
            x_shard = x.view(batch_size, seq_len // self.n_ranks, self.n_ranks, dim)[:, :, self.rank]
        elif self.attn_type == AttentionType.STEP:
            # Step sharding: the first half of the batch is block-sharded in rank order,
            # the second half in reversed rank order, so every rank holds one early and one late block.
            x_batch_1, x_batch_2 = torch.split(x, batch_size // 2, dim=0)
            x_shard = torch.cat((
                torch.split(x_batch_1, seq_len // self.n_ranks, dim=1)[self.rank],
                torch.split(x_batch_2, seq_len // self.n_ranks, dim=1)[self.n_ranks - self.rank - 1],
            ), dim=0)
        else:
            raise ValueError(f"Invalid {self.attn_type=}")
        return x_shard
    def unshard(self, x: torch.Tensor) -> torch.Tensor:
        # All-gather the per-rank outputs and undo the sharding layout above.
        if self.attn_type in (AttentionType.SDPA, AttentionType.FLASH):
            return x
        batch_size, shard_seq_len, dim = x.shape
        seq_len = self.n_ranks * shard_seq_len
        all_x = [torch.zeros_like(x) for _ in range(self.n_ranks)]
        dist.all_gather(all_x, x)
        if self.attn_type == AttentionType.RING:
            # Blocks are already in rank order; concatenate along the sequence dimension.
            all_x = torch.cat(all_x, dim=1)
        elif self.attn_type == AttentionType.STRIPED:
            # Re-interleave the stripes: token t lives on rank t % n_ranks.
            all_x = torch.stack(all_x, dim=2).view(batch_size, seq_len, dim)
        elif self.attn_type == AttentionType.STEP:
            all_x = torch.cat(all_x, dim=1)
            x_batch_1, x_batch_2 = torch.split(all_x, batch_size // 2, dim=0)
            # The second half of the batch was sharded in reversed rank order, so its blocks are reordered back.
            batch_2_seq_idxs = torch.cat(torch.split(torch.arange(seq_len), seq_len // self.n_ranks, dim=0)[::-1], dim=0)
            all_x = torch.cat((x_batch_1, x_batch_2[:, batch_2_seq_idxs]), dim=0)
        else:
            raise ValueError(f"Invalid {self.attn_type=}")
        return all_x
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv_proj(x)
        qkv = rearrange(qkv, "b n (three h d) -> three b n h d", three=3, h=self.n_heads, d=self.dim_per_head)
        q, k, v = torch.unbind(qkv, dim=0)
        sm_scale = k.shape[-1] ** -0.5
        if self.attn_type == AttentionType.SDPA:
            # NOTE: this is just for testing purposes
            o = F.scaled_dot_product_attention(
                q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), is_causal=self.causal
            ).permute(0, 2, 1, 3)
        elif self.attn_type == AttentionType.FLASH:
            o = flash_attn_func(q, k, v, None, self.causal, sm_scale)
        elif self.attn_type == AttentionType.RING:
            o = ring_attn_func(q, k, v, self.n_ranks, self.rank, self.prev_rank, self.next_rank, None, self.causal, sm_scale)
        elif self.attn_type == AttentionType.STRIPED:
            o = striped_attn_func(q, k, v, self.n_ranks, self.rank, self.prev_rank, self.next_rank, None, self.causal, sm_scale)
        elif self.attn_type == AttentionType.STEP:
            o = step_attn_func(q, k, v, self.n_ranks, self.rank, self.prev_rank, self.next_rank, None, self.causal, sm_scale)
        else:
            raise ValueError(f"Invalid {self.attn_type=}")
        o = rearrange(o, "b n h d -> b n (h d)")
        o = self.o_proj(o)
        o = self.o_drop(o)  # output dropout declared in __init__
        return o
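The SDPA branch above exists as a plain-PyTorch reference for checking the distributed variants. The listing does not include a driver, so the following is only a minimal correctness-check sketch: it assumes a single-node launch via torchrun (one CUDA device per rank), fp16 tensors, and illustrative sizes (dim=1024, dim_per_head=64, seq_len=4096, batch=2); the tolerances are arbitrary.

# Hypothetical correctness-check driver (not from the post).
# Assumed launch: torchrun --nproc_per_node=<n_ranks> check.py  (single node)
import torch
import torch.distributed as dist

def main() -> None:
    dist.init_process_group(backend="nccl")
    rank, n_ranks = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    torch.manual_seed(0)  # identical weights and inputs on every rank

    kwargs = dict(dim=1024, dim_per_head=64, seq_len=4096, rank=rank,
                  n_ranks=n_ranks, dropout=0.0, causal=True)
    ref = Attention(attn_type=AttentionType.SDPA, **kwargs).to(device, torch.float16)
    ring = Attention(attn_type=AttentionType.RING, **kwargs).to(device, torch.float16)
    ring.load_state_dict(ref.state_dict())  # share weights so outputs are comparable

    x = torch.randn(2, 4096, 1024, device=device, dtype=torch.float16)

    with torch.no_grad():
        # SDPA sees the full sequence on every rank; the ring variant sees only
        # this rank's shard, and the full-length output is rebuilt with unshard().
        o_ref = ref(x)
        o_ring = ring.unshard(ring(ring.shard(x)))

    torch.testing.assert_close(o_ring, o_ref, atol=2e-2, rtol=2e-2)  # illustrative tolerances
    dist.destroy_process_group()

if __name__ == "__main__":
    main()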
Benchmarking
Fig. Benchmark results.
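The harness that produced the figure is not part of the listing above. As a rough idea of how per-variant latencies could be collected, here is a hypothetical CUDA-event timing helper; it reuses the Attention class from the listing and assumes the same distributed setup as the driver sketch above (all ranks call it together, since the ring variants issue collectives inside forward).

# Hypothetical timing helper; `attn` is an Attention instance and `x_shard`
# is this rank's already-sharded input. Returns milliseconds per forward pass.
import torch

def bench_forward(attn: Attention, x_shard: torch.Tensor, warmup: int = 10, iters: int = 50) -> float:
    with torch.no_grad():
        for _ in range(warmup):
            attn(x_shard)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            attn(x_shard)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters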