run on cpu

Branch: main
Author: randaller, committed 3 years ago via GitHub
Parent: e04d724890
Commit: 89deefc049

@@ -9,13 +9,6 @@ import torch
 from torch import nn
 import torch.nn.functional as F
-import fairscale.nn.model_parallel.initialize as fs_init
-from fairscale.nn.model_parallel.layers import (
-    ParallelEmbedding,
-    RowParallelLinear,
-    ColumnParallelLinear,
-)
 @dataclass
 class ModelArgs:
@@ -77,44 +70,40 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
+        # self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
+        self.n_local_heads = args.n_heads
         self.head_dim = args.dim // args.n_heads
-        self.wq = ColumnParallelLinear(
+        self.wq = nn.Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wk = ColumnParallelLinear(
+        self.wk = nn.Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wv = ColumnParallelLinear(
+        self.wv = nn.Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wo = RowParallelLinear(
+        self.wo = nn.Linear(
             args.n_heads * self.head_dim,
             args.dim,
             bias=False,
-            input_is_parallel=True,
-            init_method=lambda x: x,
         )
         self.cache_k = torch.zeros(
             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
-        ).cuda()
+        ).cpu()
         self.cache_v = torch.zeros(
             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
-        ).cuda()
+        ).cpu()
     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
         bsz, seqlen, _ = x.shape
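
The substitution above works because, with a model-parallel world size of 1, ColumnParallelLinear and RowParallelLinear hold a full (out_features, in_features) weight and compute the same matmul as a plain nn.Linear, so the published checkpoint tensors load without reshaping; the KV cache is simply kept on the host instead of being pushed to CUDA. A minimal sketch of that equivalence (dimensions chosen to match the 7B config for illustration; this code is not part of the commit):

import torch
from torch import nn

# Illustrative 7B-style dimensions (assumed, not read from the diff).
dim, n_heads, max_seq_len, max_batch_size = 4096, 32, 512, 1
head_dim = dim // n_heads

# nn.Linear stores its weight as (out_features, in_features), which is the
# same shape a checkpoint entry such as "attention.wq.weight" already has.
wq = nn.Linear(dim, n_heads * head_dim, bias=False)
wq.load_state_dict({"weight": torch.randn(n_heads * head_dim, dim)})

# The KV cache is allocated on the CPU; .cpu() is a no-op for a freshly
# created tensor but mirrors the .cuda() call it replaces.
cache_k = torch.zeros(max_batch_size, max_seq_len, n_heads, head_dim).cpu()
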
@@ -129,8 +118,8 @@ class Attention(nn.Module):
         self.cache_k = self.cache_k.to(xq)
         self.cache_v = self.cache_v.to(xq)
-        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+        self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
+        self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv
         keys = self.cache_k[:bsz, : start_pos + seqlen]
         values = self.cache_v[:bsz, : start_pos + seqlen]
@@ -161,14 +150,14 @@ class FeedForward(nn.Module):
         hidden_dim = int(2 * hidden_dim / 3)
         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-        self.w1 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w1 = nn.Linear(
+            dim, hidden_dim, bias=False,
         )
-        self.w2 = RowParallelLinear(
-            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+        self.w2 = nn.Linear(
+            hidden_dim, dim, bias=False,
         )
-        self.w3 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w3 = nn.Linear(
+            dim, hidden_dim, bias=False,
         )
     def forward(self, x):
@@ -202,8 +191,8 @@ class Transformer(nn.Module):
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
-        self.tok_embeddings = ParallelEmbedding(
-            params.vocab_size, params.dim, init_method=lambda x: x
+        self.tok_embeddings = nn.Embedding(
+            params.vocab_size, params.dim
         )
         self.layers = torch.nn.ModuleList()
@@ -211,8 +200,9 @@
             self.layers.append(TransformerBlock(layer_id, params))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = ColumnParallelLinear(
-            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+        self.output = nn.Linear(
+            params.dim, params.vocab_size, bias=False,
         )
         self.freqs_cis = precompute_freqs_cis(
@@ -224,7 +214,7 @@
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         self.freqs_cis = self.freqs_cis.to(h.device)
-        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+        freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
         mask = None
         if seqlen > 1:
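
Taken together, the diff removes every fairscale dependency from the model definition (the file corresponds to the upstream llama/model.py), so the weights can be instantiated on a machine with no GPU and no torch.distributed setup. A rough usage sketch, not from this commit, assuming the upstream ModelArgs/Tokenizer loading conventions and with placeholder file paths:

import json
import torch
from llama.model import ModelArgs, Transformer  # the module modified above
from llama.tokenizer import Tokenizer

# Placeholder paths for wherever the 7B checkpoint files were downloaded.
checkpoint = torch.load("consolidated.00.pth", map_location="cpu")
with open("params.json") as f:
    params = json.load(f)

model_args = ModelArgs(max_seq_len=512, max_batch_size=1, **params)
tokenizer = Tokenizer(model_path="tokenizer.model")
model_args.vocab_size = tokenizer.n_words

model = Transformer(model_args)  # parameters default to CPU float32 tensors
# strict=False, as in the upstream loader, tolerates non-module checkpoint
# entries such as rope.freqs.
model.load_state_dict(checkpoint, strict=False)
model.eval()
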
