# -*- coding:utf-8 -*-
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)

### Step 1: set model ##################################################################################

ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'  # 'RWKV' or 'RWKV-ffnPre'

# your trained model
MODEL_NAME = 'trained-31'
WORD_NAME = 'vocab'  # the .json vocab (generated by train.py)

# ########## Uncomment these to test my 27M params enwik8 model ##########
# MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
# WORD_NAME = 'enwik8-vocab'
# EVAL_DATA = 'enwik8'  # uncomment this for EVAL MODE (no text generation)
# ########################################################################

# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' '  # here we just set it to [space] for simplicity

RUN_DEVICE = 'cpu'  # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False  # True False - show softmax output

### Step 2: set context ################################################################################

context = "\nIn the"  # ==> this is your prompt

NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500

TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9

########################################################################################################

print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

########################################################################################################

if 'EVAL_DATA' in vars() or 'EVAL_DATA' in globals():
    print('Evaluating on ' + EVAL_DATA + ' ...')

    data = open(EVAL_DATA, "r", encoding='utf-8').read()
    loss_table = np.zeros(ctx_len)
    N_SAMPLE = 1000

    # For each sample: pick a random window of ctx_len+1 characters, feed it to the
    # RNN one token at a time, and accumulate the negative log-likelihood the model
    # assigns to each next character, bucketed by context position.
    for iii in range(N_SAMPLE):
        pos = np.random.randint(0, len(data) - ctx_len - 1)
        context = data[pos:pos+ctx_len+1]
        ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]

        model.clear()
        for i in range(1, ctx_len+1):
            x = ctx[:i]
            out = model.run(x)
            prob = F.softmax(torch.tensor(out), dim=-1)
            loss_table[i-1] += -math.log(prob[ctx[i]])

        print(f'Tested {iii+1} samples: avg_loss over ctx_len =',
              np.mean(loss_table) / (iii+1))

    exit(0)

########################################################################################################

context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
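########################################################################################################
# Generation loop. To avoid re-reading the prompt on every trial, the first trial
# runs the RNN over the prompt once and snapshots its hidden state (plus the logits
# for the last prompt token) into init_state via model.save(); every later trial
# restores that snapshot with model.load() and starts sampling immediately.
########################################################################################################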
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    t_begin = time.time_ns()

    src_len = len(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
    print(('-' * 30) + context, end='')

    model.clear()
    if TRIAL == 0:
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[:i+1]
            if i == src_len - 1:
                init_state.out = model.run(x)
            else:
                model.run(x)
        model.save(init_state)
    else:
        model.load(init_state)

    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        x = ctx[:i+1]
        x = x[-ctx_len:]

        if i == src_len:
            out = copy.deepcopy(init_state.out)
        else:
            out = model.run(x)
        if DEBUG_DEBUG:
            print('model', np.array(x), '==>', np.array(out), np.max(out), np.min(out))

        char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                       top_p_usual=top_p, top_p_newline=top_p_newline)
        char = char.item()
        print(tokenizer.itos[int(char)], end='', flush=True)

        ctx += [char]

    t_end = time.time_ns()
    print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')
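
########################################################################################################
# For reference only (never called by this script): a minimal sketch of the kind of
# top-p ("nucleus") sampling that tokenizer.sample_logits performs above. This is an
# illustrative assumption, NOT the actual src.utils implementation. The idea: keep
# only the smallest set of tokens whose probabilities sum past top_p, then sample
# from that set; passing a looser threshold when the last token was '\n' (cf. the
# top_p_newline knob) makes the model more willing to end a paragraph.
def sample_top_p_sketch(logits, temperature=1.0, top_p=0.7):
    # float64 keeps np.random.choice's "p must sum to 1" check happy
    probs = F.softmax(torch.tensor(logits, dtype=torch.float64), dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]  # descending
    # smallest probability still inside the top-p nucleus
    cutoff = sorted_probs[np.argmax(np.cumsum(sorted_probs) > top_p)]
    probs[probs < cutoff] = 0  # drop the tail outside the nucleus
    if temperature != 1.0:
        probs = probs ** (1.0 / temperature)
    probs = probs / np.sum(probs)
    return np.random.choice(len(probs), p=probs)
########################################################################################################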