diff --git a/llama/__init__.py b/llama/__init__.py
new file mode 100644
index 0000000..e2c5db3
--- /dev/null
+++ b/llama/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from .generation import LLaMA
+from .model import ModelArgs, Transformer
+from .tokenizer import Tokenizer
diff --git a/llama/generation.py b/llama/generation.py
new file mode 100644
index 0000000..5ba68f5
--- /dev/null
+++ b/llama/generation.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from typing import List
+
+import torch
+import traceback
+
+from llama.tokenizer import Tokenizer
+from llama.model import Transformer
+from tqdm import trange
+
+
+class LLaMA:
+    def __init__(self, model: Transformer, tokenizer: Tokenizer):
+        self.model = model
+        self.tokenizer = tokenizer
+
+    def generate(
+        self,
+        prompts: List[str],
+        max_gen_len: int,
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+    ) -> List[str]:
+        bsz = len(prompts)
+        params = self.model.params
+        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
+
+        count_newlines = prompts[0].count("\n")
+
+        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
+
+        min_prompt_size = min([len(t) for t in prompt_tokens])
+        max_prompt_size = max([len(t) for t in prompt_tokens])
+
+        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
+
+        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).long()
+        for k, t in enumerate(prompt_tokens):
+            tokens[k, : len(t)] = torch.tensor(t).long()
+            tokens[k, -1] = self.tokenizer.eos_id
+        input_text_mask = tokens != self.tokenizer.pad_id
+        start_pos = min_prompt_size
+        prev_pos = 0
+        decoded = [None] * bsz
+        for cur_pos in trange(start_pos, total_len, desc="forward"):
+            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+            if temperature > 0:
+                probs = torch.softmax(logits / temperature, dim=-1)
+                next_token = sample_top_p(probs, top_p)
+            else:
+                next_token = torch.argmax(logits, dim=-1)
+            next_token = next_token.reshape(-1).cpu()
+            # only replace token if prompt has already been generated
+            next_token = torch.where(
+                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+            )
+            tokens[:, cur_pos] = next_token
+            prev_pos = cur_pos
+
+            print("-" * 30)
+            for i, t in enumerate(tokens.tolist()):
+                # i: batch index
+                # t: token ids for sample i so far
+                # cut to max gen len
+                # t = t[: len(prompt_tokens[i]) + max_gen_len]
+                t = t[: min(cur_pos, len(prompt_tokens[i]) + max_gen_len)]
+                # cut to eos tok if any
+                try:
+                    t = t[: t.index(self.tokenizer.eos_id)]
+                except ValueError:
+                    pass  # traceback.print_exc()
+                try:
+                    d = self.tokenizer.decode(t)
+                    print([i] * 20)
+                    print(d)
+                    decoded[i] = d
+
+                    result_count_newlines = d.count("\n")
+                    if result_count_newlines > count_newlines:
+                        return decoded
+
+                except IndexError:
+                    traceback.print_exc()
+                    print(t)
+            print("-" * 30)
+        return decoded
+
+
+def sample_top_p(probs, p):
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
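Note: sample_top_p above implements standard nucleus (top-p) sampling: sort the distribution, keep the smallest prefix whose cumulative mass covers p, renormalize, and sample from that prefix. The following snippet is illustrative only and not part of the patch; the toy probabilities are made up to show the truncation step.

    # Illustrative toy run of the same top-p truncation used by sample_top_p.
    import torch

    probs = torch.tensor([[0.5, 0.3, 0.1, 0.1]])  # a proper distribution over 4 tokens
    p = 0.7

    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)           # [0.5, 0.8, 0.9, 1.0]
    mask = probs_sum - probs_sort > p                      # True once the mass before a token exceeds p
    probs_sort[mask] = 0.0                                 # keeps the first two tokens, drops the tail
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  # renormalize: [0.625, 0.375, 0, 0]
    next_token = torch.gather(probs_idx, -1, torch.multinomial(probs_sort, num_samples=1))
    print(next_token)  # index 0 or 1, never 2 or 3
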
diff --git a/llama/model.py b/llama/model.py
new file mode 100644
index 0000000..a7c2cee
--- /dev/null
+++ b/llama/model.py
@@ -0,0 +1,270 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from typing import Optional, Tuple
+from dataclasses import dataclass
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.utils import skip_init
+
+from tqdm import tqdm
+
+@dataclass
+class ModelArgs:
+    dim: int = 512
+    n_layers: int = 8
+    n_heads: int = 8
+    vocab_size: int = -1  # defined later by tokenizer
+    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+    norm_eps: float = 1e-5
+
+    max_batch_size: int = 32
+    max_seq_len: int = 1024
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.n_local_heads = args.n_heads  # // fs_init.get_model_parallel_world_size()
+        self.head_dim = args.dim // args.n_heads
+
+        self.wq = skip_init(nn.Linear,
+            args.dim,
+            args.n_heads * self.head_dim,
+            bias=False,
+        )
+        self.wk = skip_init(nn.Linear,
+            args.dim,
+            args.n_heads * self.head_dim,
+            bias=False,
+        )
+        self.wv = skip_init(nn.Linear,
+            args.dim,
+            args.n_heads * self.head_dim,
+            bias=False,
+        )
+        self.wo = skip_init(nn.Linear,
+            args.n_heads * self.head_dim,
+            args.dim,
+            bias=False,
+        )
+
+        self.cache_k = torch.zeros(
+            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
+        ).cuda()
+        self.cache_v = torch.zeros(
+            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
+        ).cuda()
+
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+        bsz, seqlen, _ = x.shape
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+        self.cache_k = self.cache_k.to(xq)
+        self.cache_v = self.cache_v.to(xq)
+
+        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+
+        keys = self.cache_k[:bsz, : start_pos + seqlen]
+        values = self.cache_v[:bsz, : start_pos + seqlen]
+
+        xq = xq.transpose(1, 2)
+        keys = keys.transpose(1, 2)
+        values = values.transpose(1, 2)
+        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if mask is not None:
+            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
+        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+        output = output.transpose(
+            1, 2
+        ).contiguous().view(bsz, seqlen, -1)
+
+        return self.wo(output)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = skip_init(nn.Linear,
+            dim,
+            hidden_dim,
+            bias=False,
+        )
+        self.w2 = skip_init(nn.Linear,
+            hidden_dim,
+            dim,
+            bias=False,
+        )
+        self.w3 = skip_init(nn.Linear,
+            dim,
+            hidden_dim,
+            bias=False,
+        )
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, layer_id: int, args: ModelArgs):
+        super().__init__()
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.head_dim = args.dim // args.n_heads
+        self.attention = Attention(args)
+        self.feed_forward = FeedForward(
+            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
+        )
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
+        out = h + self.feed_forward.forward(self.ffn_norm(h))
+        return out
+
+# https://github.com/gmorenz/llama/commit/4daf7f1a2f2bb22208b5d464bc2a18511d54408d
+def move_parameters_to_gpu(module):
+    if not hasattr(module, "saved"):
+        module.saved = module._parameters.copy()
+    for k, param in module.saved.items():
+        if param is not None:
+            module._parameters[k] = param.to("cuda", non_blocking=True)
+    for child in module.children():
+        move_parameters_to_gpu(child)
+
+def move_parameters_to_cpu(module):
+    for k, param in module.saved.items():
+        del module._parameters[k]
+        module._parameters[k] = param
+    for child in module.children():
+        move_parameters_to_cpu(child)
+
+
+class Transformer(nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        self.vocab_size = params.vocab_size
+        self.n_layers = params.n_layers
+
+        self.tok_embeddings = skip_init(nn.Embedding,
+            params.vocab_size,
+            params.dim,
+        )
+
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(params.n_layers):
+            self.layers.append(TransformerBlock(layer_id, params))
+
+        self.layer_locations = [None] * len(self.layers)
+
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps).cuda()
+        self.output = skip_init(nn.Linear,
+            params.dim,
+            params.vocab_size,
+            bias=False,
+        ).cuda()
+
+        self.freqs_cis = precompute_freqs_cis(
+            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+        ).cuda()
+
+    @torch.inference_mode()
+    def forward(self, tokens: torch.Tensor, start_pos: int):
+        use_gpu = True  # start_pos == 0
+
+        _bsz, seqlen = tokens.shape
+        h = self.tok_embeddings(tokens)
+        self.freqs_cis = self.freqs_cis
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+        if use_gpu:
+            h = h.cuda()
+
+        mask = None
+        if seqlen > 1:
+            mask = torch.full(
+                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
+            )
+            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+        if use_gpu and mask is not None:
+            mask = mask.cuda()
+
+        for layer in tqdm(self.layers, desc="flayers", leave=True):
+            if use_gpu:
+                move_parameters_to_gpu(layer)
+            h = layer(h, start_pos, freqs_cis, mask)
+            if use_gpu:
+                move_parameters_to_cpu(layer)
+
+        h = self.norm(h)
+        if use_gpu:
+            del mask
+            torch.cuda.empty_cache()
+        output = self.output(h[:, -1, :])  # only compute last logits
+        return output.float()
diff --git a/llama/tokenizer.py b/llama/tokenizer.py
new file mode 100644
index 0000000..5117505
--- /dev/null
+++ b/llama/tokenizer.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from sentencepiece import SentencePieceProcessor
+from logging import getLogger
+from typing import List
+import os
+
+
+logger = getLogger()
+
+
+class Tokenizer:
+    def __init__(self, model_path: str):
+        # reload tokenizer
+        assert os.path.isfile(model_path), model_path
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+        logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.pad_id()
+        logger.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        assert type(s) is str
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        return self.sp_model.decode(t)
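Note: the patch itself does not include a loading script. The sketch below is a hypothetical usage example, not part of the patch: the checkpoint/tokenizer paths, the params.json layout, the fp16 default dtype, and the torch.load call are assumptions about how these classes are typically wired together; the repository's own entry point may differ.

    # Hypothetical usage sketch (assumptions noted above); requires a CUDA device,
    # since the model code moves each TransformerBlock to the GPU for its forward pass.
    import json
    import torch

    from llama import LLaMA, ModelArgs, Tokenizer, Transformer

    ckpt_dir = "weights/7B"                      # assumed checkpoint layout
    tokenizer_path = "weights/tokenizer.model"   # assumed tokenizer path

    tokenizer = Tokenizer(model_path=tokenizer_path)

    with open(f"{ckpt_dir}/params.json") as f:
        params = json.load(f)
    model_args = ModelArgs(max_seq_len=1024, max_batch_size=1, **params)
    model_args.vocab_size = tokenizer.n_words

    torch.set_default_dtype(torch.half)          # keep weights in fp16 so they fit in RAM
    model = Transformer(model_args)              # skip_init leaves weights uninitialized
    checkpoint = torch.load(f"{ckpt_dir}/consolidated.00.pth", map_location="cpu")
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(generator.generate(["The capital of France is"], max_gen_len=64)[0])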