@@ -9,13 +9,6 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
-import fairscale.nn.model_parallel.initialize as fs_init
-from fairscale.nn.model_parallel.layers import (
-    ParallelEmbedding,
-    RowParallelLinear,
-    ColumnParallelLinear,
-)
-
 
 @dataclass
 class ModelArgs:
@@ -61,9 +54,9 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
 
 
 def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
+        xq: torch.Tensor,
+        xk: torch.Tensor,
+        freqs_cis: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
@@ -77,44 +70,40 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
 
-        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
+        # self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
+        self.n_local_heads = args.n_heads
         self.head_dim = args.dim // args.n_heads
 
-        self.wq = ColumnParallelLinear(
+        self.wq = nn.Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wk = ColumnParallelLinear(
+
+        self.wk = nn.Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wv = ColumnParallelLinear(
+
+        self.wv = nn.Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wo = RowParallelLinear(
+
+        self.wo = nn.Linear(
             args.n_heads * self.head_dim,
             args.dim,
             bias=False,
-            input_is_parallel=True,
-            init_method=lambda x: x,
         )
 
         self.cache_k = torch.zeros(
             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
-        ).cuda()
+        ).cpu()
         self.cache_v = torch.zeros(
             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
-        ).cuda()
+        ).cpu()
 
     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
         bsz, seqlen, _ = x.shape
@@ -129,8 +118,8 @@ class Attention(nn.Module):
         self.cache_k = self.cache_k.to(xq)
         self.cache_v = self.cache_v.to(xq)
 
-        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+        self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
+        self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv
 
         keys = self.cache_k[:bsz, : start_pos + seqlen]
         values = self.cache_v[:bsz, : start_pos + seqlen]
@@ -152,23 +141,23 @@ class Attention(nn.Module):
 
 class FeedForward(nn.Module):
     def __init__(
-        self,
-        dim: int,
-        hidden_dim: int,
-        multiple_of: int,
+            self,
+            dim: int,
+            hidden_dim: int,
+            multiple_of: int,
     ):
         super().__init__()
         hidden_dim = int(2 * hidden_dim / 3)
         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
 
-        self.w1 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w1 = nn.Linear(
+            dim, hidden_dim, bias=False,
         )
-        self.w2 = RowParallelLinear(
-            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+        self.w2 = nn.Linear(
+            hidden_dim, dim, bias=False,
         )
-        self.w3 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w3 = nn.Linear(
+            dim, hidden_dim, bias=False,
         )
 
     def forward(self, x):
@@ -202,8 +191,8 @@ class Transformer(nn.Module):
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
 
-        self.tok_embeddings = ParallelEmbedding(
-            params.vocab_size, params.dim, init_method=lambda x: x
+        self.tok_embeddings = nn.Embedding(
+            params.vocab_size, params.dim
         )
 
         self.layers = torch.nn.ModuleList()
@@ -211,8 +200,9 @@ class Transformer(nn.Module):
             self.layers.append(TransformerBlock(layer_id, params))
 
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = ColumnParallelLinear(
-            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+
+        self.output = nn.Linear(
+            params.dim, params.vocab_size, bias=False,
         )
 
         self.freqs_cis = precompute_freqs_cis(
@@ -224,7 +214,7 @@ class Transformer(nn.Module):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         self.freqs_cis = self.freqs_cis.to(h.device)
-        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+        freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
 
         mask = None
         if seqlen > 1: