# -*- coding:utf-8 -*-
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

### Step 1: set model ##################################################################################

ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'  # 'RWKV' or 'RWKV-ffnPre'

MODEL_NAME = 'trained-31'  # your trained model
WORD_NAME = 'vocab'  # the .json vocab (generated by train.py)

# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' '  # here we just set it to [space] for simplicity

RUN_DEVICE = 'cpu'  # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False  # True False - show softmax output
DEBUG_TIME = False  # True False - show trained time-coeffs

### Step 2: set context ################################################################################

context = "\n"  # ==> this is your prompt

NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500

TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9

########################################################################################################

np.set_printoptions(precision=4, suppress=True, linewidth=200)

tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
context = tokenizer.refine_context(context)
print('Your context has ' + str(len(context)) + ' tokens')

print(f'Loading {MODEL_NAME}...')

##############################################################################################################

RWKV_K_CLAMP = 60  # clamp k to avoid overflow in exp()
RWKV_K_EPS = 1e-16
RWKV_HEAD_QK_DIM = 256


class RWKV_RNN():
    def __init__(self, MODEL_NAME):
        self.w = types.SimpleNamespace()

        w = torch.load(MODEL_NAME + '.pth',
                       map_location=torch.device(RUN_DEVICE))  # .state_dict()
        for x in w.keys():
            # precompute the time-coeffs so each RNN step stays cheap
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                w[x] = torch.exp(-torch.exp(w[x]))
            if '.time_first' in x:
                w[x] = torch.exp(w[x])

            # mirror the flat state_dict as a tree: self.w.blocks[0].att.key.weight etc.
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):  # reset the recurrent state
        self.xx = {}  # token-shift (time-mix) state, per block
        self.aa = {}  # WKV numerator state, per block
        self.bb = {}  # WKV denominator state, per block
        self.hk = None  # cached head_k outputs for the head-QK mechanism

    def save(self, target):
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):  # channel-mixing (FFN) block, one token per call
        if DEBUG_TIME:
            print(name+'.time_mix', w.time_mix.squeeze().numpy())
        if name not in self.xx:
            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)  # token-shift
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ x)
        k = torch.square(torch.relu(w.key.weight @ x))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):  # time-mixing (WKV attention) block, one token per call
        if DEBUG_TIME:
            print(name+'.time_mix', w.time_mix.squeeze().numpy())
            print(name+'.time_decay', w.time_decay.squeeze().numpy())
            print(name+'.time_first', w.time_first.squeeze().numpy())
        if name not in self.xx:
            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
            self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)
            self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)  # token-shift
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ x)

        k = torch.exp(torch.clamp(w.key.weight @ x, max=RWKV_K_CLAMP))
        v = w.value.weight @ x
        kv = k * v

        # aa/bb carry the exp-decayed numerator/denominator of the weighted average
        a = self.aa[name] + w.time_first * kv
        b = self.bb[name] + w.time_first * k
        self.aa[name] = w.time_decay * self.aa[name] + kv
        self.bb[name] = w.time_decay * self.bb[name] + k

        rwkv = r * a / (b + RWKV_K_EPS)

        return w.output.weight @ rwkv

    def run(self, ctx):  # one RNN step: consume the last token of ctx, return logits
        w = self.w
        x = w.emb.weight[ctx[-1]]

        for i in range(n_layer):
            x = self.LN(x, w.blocks[i].ln1)
            if i == 0 and model_type == 'RWKV-ffnPre':
                x = x + self.FF(x, w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
            x = self.LN(x, w.blocks[i].ln2)
            x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        # head-QK: a small attention over the recent context that lets the model
        # boost logits of tokens it has already seen
        if self.hk is None:
            self.hk = (w.head_k.weight @ x).unsqueeze(0)
        else:
            self.hk = torch.cat(
                [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
        if self.hk.shape[0] > ctx_len:
            self.hk = self.hk[-ctx_len:, :]

        q = w.head_q.weight @ x

        x = w.head.weight @ x
        x = x.cpu().numpy().tolist()

        c = (self.hk @ q) / RWKV_HEAD_QK_DIM
        for i in range(len(c)):
            x[ctx[i]] += c[i]

        return x

##############################################################################################################

model = RWKV_RNN(MODEL_NAME)

print('\n\n--> Currently the first run takes a while if your prompt is long, as we are using an RNN to process the prompt. This will be much faster in future versions. <--\n')

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    src_len = len(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
    print(context.replace('\n', '\n '), end='')

    model.clear()
    if TRIAL == 0:
        # feed the prompt through the RNN once, then cache the state for later trials
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[:i+1]
            if i == src_len - 1:
                init_state.out = model.run(x)
            else:
                model.run(x)
        model.save(init_state)
    else:
        model.load(init_state)

    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        x = ctx[:i+1]
        x = x[-ctx_len:]

        if i == src_len:
            out = copy.deepcopy(init_state.out)
        else:
            out = model.run(x)

        if DEBUG_DEBUG:
            print('model', np.array(x), '==>', np.array(out), np.max(out), np.min(out))

        char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                       top_p_usual=top_p, top_p_newline=top_p_newline)
        char = char.item()
        print(tokenizer.itos[int(char)].replace('\n', '\n '), end='', flush=True)
        ctx += [char]

    print('\n' + '-' * 40, end='')
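
########################################################################################################
# Note: tokenizer.sample_logits (defined in src/utils.py) applies temperature scaling plus top-p
# (nucleus) sampling, with a separate top_p for the newline token (top_p_newline). Below is a
# minimal sketch of the nucleus-sampling idea; sample_top_p_sketch is a hypothetical helper for
# illustration only, not the actual implementation in src/utils.py:
#
#   def sample_top_p_sketch(logits, temperature=1.0, top_p=0.7):
#       probs = F.softmax(torch.tensor(logits) / temperature, dim=-1).numpy()
#       sorted_probs = np.sort(probs)[::-1]                            # descending order
#       cutoff = sorted_probs[np.argmax(np.cumsum(sorted_probs) >= top_p)]
#       probs[probs < cutoff] = 0                                      # drop the low-probability tail
#       return np.random.choice(len(probs), p=probs / probs.sum())     # renormalize and sample
########################################################################################################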