Add files via upload
parent
e0a601ac8e
commit
5844747a1b
@ -0,0 +1,6 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
|
||||||
|
|
||||||
|
from .generation import LLaMA
|
||||||
|
from .model import ModelArgs, Transformer
|
||||||
|
from .tokenizer import Tokenizer
|
||||||
@ -0,0 +1,99 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from llama.tokenizer import Tokenizer
|
||||||
|
from llama.model import Transformer
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
|
|
||||||
|
class LLaMA:
|
||||||
|
def __init__(self, model: Transformer, tokenizer: Tokenizer):
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
max_gen_len: int,
|
||||||
|
temperature: float = 0.8,
|
||||||
|
top_p: float = 0.95,
|
||||||
|
) -> List[str]:
|
||||||
|
bsz = len(prompts)
|
||||||
|
params = self.model.params
|
||||||
|
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||||
|
|
||||||
|
count_newlines = prompts[0].count("\n")
|
||||||
|
|
||||||
|
prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
|
||||||
|
|
||||||
|
min_prompt_size = min([len(t) for t in prompt_tokens])
|
||||||
|
max_prompt_size = max([len(t) for t in prompt_tokens])
|
||||||
|
|
||||||
|
total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
|
||||||
|
|
||||||
|
tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).long()
|
||||||
|
for k, t in enumerate(prompt_tokens):
|
||||||
|
tokens[k, : len(t)] = torch.tensor(t).long()
|
||||||
|
tokens[k, -1] = self.tokenizer.eos_id
|
||||||
|
input_text_mask = tokens != self.tokenizer.pad_id
|
||||||
|
start_pos = min_prompt_size
|
||||||
|
prev_pos = 0
|
||||||
|
decoded = [None] * bsz
|
||||||
|
for cur_pos in trange(start_pos, total_len, desc="forward"):
|
||||||
|
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||||
|
if temperature > 0:
|
||||||
|
probs = torch.softmax(logits / temperature, dim=-1)
|
||||||
|
next_token = sample_top_p(probs, top_p)
|
||||||
|
else:
|
||||||
|
next_token = torch.argmax(logits, dim=-1)
|
||||||
|
next_token = next_token.reshape(-1).cpu()
|
||||||
|
# only replace token if prompt has already been generated
|
||||||
|
next_token = torch.where(
|
||||||
|
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
|
||||||
|
)
|
||||||
|
tokens[:, cur_pos] = next_token
|
||||||
|
prev_pos = cur_pos
|
||||||
|
|
||||||
|
print("-" * 30)
|
||||||
|
for i, t in enumerate(tokens.tolist()):
|
||||||
|
# i = cur_pos
|
||||||
|
# t = next_token
|
||||||
|
# cut to max gen len
|
||||||
|
# t = t[: len(pr-ompt_tokens[i]) + max_gen_len]
|
||||||
|
t = t[: min(cur_pos, len(prompt_tokens[i]) + max_gen_len)]
|
||||||
|
# cut to eos tok if any
|
||||||
|
try:
|
||||||
|
t = t[: t.index(self.tokenizer.eos_id)]
|
||||||
|
except ValueError:
|
||||||
|
pass # traceback.print_exc()
|
||||||
|
try:
|
||||||
|
d = self.tokenizer.decode(t)
|
||||||
|
print([i] * 20)
|
||||||
|
print(d)
|
||||||
|
decoded[i] = d
|
||||||
|
|
||||||
|
result_count_newlines = d.count("\n")
|
||||||
|
if result_count_newlines > count_newlines:
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
except IndexError:
|
||||||
|
traceback.print_exc()
|
||||||
|
print(t)
|
||||||
|
print("-" * 30)
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
|
||||||
|
def sample_top_p(probs, p):
|
||||||
|
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||||
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
|
mask = probs_sum - probs_sort > p
|
||||||
|
probs_sort[mask] = 0.0
|
||||||
|
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||||
|
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||||
|
next_token = torch.gather(probs_idx, -1, next_token)
|
||||||
|
return next_token
|
||||||
@ -0,0 +1,270 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.utils import skip_init
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs:
|
||||||
|
dim: int = 512
|
||||||
|
n_layers: int = 8
|
||||||
|
n_heads: int = 8
|
||||||
|
vocab_size: int = -1 # defined later by tokenizer
|
||||||
|
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||||
|
norm_eps: float = 1e-5
|
||||||
|
|
||||||
|
max_batch_size: int = 32
|
||||||
|
max_seq_len: int = 1024
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
output = self._norm(x.float()).type_as(x)
|
||||||
|
return output * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||||
|
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||||
|
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||||
|
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||||
|
return freqs_cis
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||||
|
ndim = x.ndim
|
||||||
|
assert 0 <= 1 < ndim
|
||||||
|
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||||
|
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||||
|
return freqs_cis.view(*shape)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(
|
||||||
|
xq: torch.Tensor,
|
||||||
|
xk: torch.Tensor,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||||
|
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||||
|
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||||
|
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||||
|
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||||
|
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.n_local_heads = args.n_heads # // fs_init.get_model_parallel_world_size()
|
||||||
|
self.head_dim = args.dim // args.n_heads
|
||||||
|
|
||||||
|
self.wq = skip_init(nn.Linear,
|
||||||
|
args.dim,
|
||||||
|
args.n_heads * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.wk = skip_init(nn.Linear,
|
||||||
|
args.dim,
|
||||||
|
args.n_heads * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.wv = skip_init(nn.Linear,
|
||||||
|
args.dim,
|
||||||
|
args.n_heads * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.wo = skip_init(nn.Linear,
|
||||||
|
args.n_heads * self.head_dim,
|
||||||
|
args.dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cache_k = torch.zeros(
|
||||||
|
(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
|
||||||
|
).cuda()
|
||||||
|
self.cache_v = torch.zeros(
|
||||||
|
(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
|
||||||
|
bsz, seqlen, _ = x.shape
|
||||||
|
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||||
|
|
||||||
|
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||||
|
xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||||
|
xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||||
|
|
||||||
|
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||||
|
|
||||||
|
self.cache_k = self.cache_k.to(xq)
|
||||||
|
self.cache_v = self.cache_v.to(xq)
|
||||||
|
|
||||||
|
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
||||||
|
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
||||||
|
|
||||||
|
keys = self.cache_k[:bsz, : start_pos + seqlen]
|
||||||
|
values = self.cache_v[:bsz, : start_pos + seqlen]
|
||||||
|
|
||||||
|
xq = xq.transpose(1, 2)
|
||||||
|
keys = keys.transpose(1, 2)
|
||||||
|
values = values.transpose(1, 2)
|
||||||
|
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)
|
||||||
|
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||||
|
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
|
||||||
|
output = output.transpose(
|
||||||
|
1, 2
|
||||||
|
).contiguous().view(bsz, seqlen, -1)
|
||||||
|
|
||||||
|
return self.wo(output)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
multiple_of: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
hidden_dim = int(2 * hidden_dim / 3)
|
||||||
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||||
|
|
||||||
|
self.w1 = skip_init(nn.Linear,
|
||||||
|
dim,
|
||||||
|
hidden_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.w2 = skip_init(nn.Linear,
|
||||||
|
hidden_dim,
|
||||||
|
dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.w3 = skip_init(nn.Linear,
|
||||||
|
dim,
|
||||||
|
hidden_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, layer_id: int, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.n_heads = args.n_heads
|
||||||
|
self.dim = args.dim
|
||||||
|
self.head_dim = args.dim // args.n_heads
|
||||||
|
self.attention = Attention(args)
|
||||||
|
self.feed_forward = FeedForward(
|
||||||
|
dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
|
||||||
|
)
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||||
|
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
|
||||||
|
h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||||
|
out = h + self.feed_forward.forward(self.ffn_norm(h))
|
||||||
|
return out
|
||||||
|
|
||||||
|
# https://github.com/gmorenz/llama/commit/4daf7f1a2f2bb22208b5d464bc2a18511d54408d
|
||||||
|
def move_parameters_to_gpu(module):
|
||||||
|
if not hasattr(module, "saved"):
|
||||||
|
module.saved = module._parameters.copy()
|
||||||
|
for k, param in module.saved.items():
|
||||||
|
if param is not None:
|
||||||
|
module._parameters[k] = param.to("cuda", non_blocking=True)
|
||||||
|
for child in module.children():
|
||||||
|
move_parameters_to_gpu(child)
|
||||||
|
|
||||||
|
def move_parameters_to_cpu(module):
|
||||||
|
for k, param in module.saved.items():
|
||||||
|
del module._parameters[k]
|
||||||
|
module._parameters[k] = param
|
||||||
|
for child in module.children():
|
||||||
|
move_parameters_to_cpu(child)
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(self, params: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.params = params
|
||||||
|
self.vocab_size = params.vocab_size
|
||||||
|
self.n_layers = params.n_layers
|
||||||
|
|
||||||
|
self.tok_embeddings = skip_init(nn.Embedding,
|
||||||
|
params.vocab_size,
|
||||||
|
params.dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = torch.nn.ModuleList()
|
||||||
|
for layer_id in range(params.n_layers):
|
||||||
|
self.layers.append(TransformerBlock(layer_id, params))
|
||||||
|
|
||||||
|
self.layer_locations = [None] * len(self.layers)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(params.dim, eps=params.norm_eps).cuda()
|
||||||
|
self.output = skip_init(nn.Linear,
|
||||||
|
params.dim,
|
||||||
|
params.vocab_size,
|
||||||
|
bias=False,
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
self.freqs_cis = precompute_freqs_cis(
|
||||||
|
self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def forward(self, tokens: torch.Tensor, start_pos: int):
|
||||||
|
use_gpu = True # start_pos == 0
|
||||||
|
|
||||||
|
_bsz, seqlen = tokens.shape
|
||||||
|
h = self.tok_embeddings(tokens)
|
||||||
|
self.freqs_cis = self.freqs_cis
|
||||||
|
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
||||||
|
if use_gpu:
|
||||||
|
h = h.cuda()
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if seqlen > 1:
|
||||||
|
mask = torch.full(
|
||||||
|
(1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
|
||||||
|
)
|
||||||
|
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
|
||||||
|
|
||||||
|
if use_gpu and mask is not None:
|
||||||
|
mask = mask.cuda()
|
||||||
|
|
||||||
|
for layer in tqdm(self.layers, desc="flayers", leave=True):
|
||||||
|
if use_gpu:
|
||||||
|
move_parameters_to_gpu(layer)
|
||||||
|
h = layer(h, start_pos, freqs_cis, mask)
|
||||||
|
if use_gpu:
|
||||||
|
move_parameters_to_cpu(layer)
|
||||||
|
|
||||||
|
h = self.norm(h)
|
||||||
|
if use_gpu:
|
||||||
|
del mask
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
output = self.output(h[:, -1, :]) # only compute last logits
|
||||||
|
return output.float()
|
||||||
@ -0,0 +1,40 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
|
||||||
|
|
||||||
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import List
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
logger = getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer:
|
||||||
|
def __init__(self, model_path: str):
|
||||||
|
# reload tokenizer
|
||||||
|
assert os.path.isfile(model_path), model_path
|
||||||
|
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
||||||
|
logger.info(f"Reloaded SentencePiece model from {model_path}")
|
||||||
|
|
||||||
|
# BOS / EOS token IDs
|
||||||
|
self.n_words: int = self.sp_model.vocab_size()
|
||||||
|
self.bos_id: int = self.sp_model.bos_id()
|
||||||
|
self.eos_id: int = self.sp_model.eos_id()
|
||||||
|
self.pad_id: int = self.sp_model.pad_id()
|
||||||
|
logger.info(
|
||||||
|
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
|
||||||
|
)
|
||||||
|
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
||||||
|
|
||||||
|
def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
|
||||||
|
assert type(s) is str
|
||||||
|
t = self.sp_model.encode(s)
|
||||||
|
if bos:
|
||||||
|
t = [self.bos_id] + t
|
||||||
|
if eos:
|
||||||
|
t = t + [self.eos_id]
|
||||||
|
return t
|
||||||
|
|
||||||
|
def decode(self, t: List[int]) -> str:
|
||||||
|
return self.sp_model.decode(t)
|
||||||
Loading…
Reference in New Issue