########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math, os
import time
import types
import copy
import torch
from src.utils import TOKENIZER

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)

########################################################################################################
# Step 1: set model
#
# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
#
# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
########################################################################################################

TOKEN_MODE = "pile"  # char / bpe / pile

n_layer = 6
n_embd = 512
ctx_len = 1024

if TOKEN_MODE == "char":
    MODEL_NAME = "trained-500"  # your trained model
    WORD_NAME = "vocab"  # the .json vocab (generated by train.py)
    # set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it
    UNKNOWN_CHAR = " "  # here we just set it to ' ' for simplicity

elif TOKEN_MODE == "bpe":
    MODEL_NAME = "trained-500"  # your trained model
    WORD_NAME = [
        "model-vocab.json",
        "model-merges.txt",
    ]  # [vocab, merge] for your BPE model
    UNKNOWN_CHAR = None

elif TOKEN_MODE == "pile":
    WORD_NAME = [
        "20B_tokenizer.json",
        "20B_tokenizer.json",
    ]  # [vocab, vocab] for Pile model
    UNKNOWN_CHAR = None

    # ---> you can set MODEL_NAME to your fine-tuned model <---

    # MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
    # n_layer = 12
    # n_embd = 768
    # ctx_len = 1024

    # MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
    # n_layer = 24
    # n_embd = 1024
    # ctx_len = 1024

    MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040'
    n_layer = 24
    n_embd = 2048
    ctx_len = 1024

os.environ["RWKV_FLOAT_MODE"] = "fp32"  # currently only supports fp32
os.environ["RWKV_RUN_DEVICE"] = "cpu"  # 'cpu' (already very fast) or 'cuda'
model_type = "RWKV"  # 'RWKV' or 'RWKV-ffnPre'

########################################################################################################
# Step 2: set prompt & sampling stuff
########################################################################################################

# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'

context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333

TEMPERATURE = 1.0
top_p = 0.8
top_p_newline = 0.9  # only used in TOKEN_MODE = char

DEBUG_DEBUG = False  # True False --> show softmax output

########################################################################################################

print(f"Loading {MODEL_NAME}...")
from src.model_run import RWKV_RNN

model = RWKV_RNN(
    MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len
)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

########################################################################################################

# encode the prompt: char-level models use the stoi vocab, BPE/pile models use the tokenizer file(s)
if tokenizer.charMode:
    context = tokenizer.refine_context(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
    ctx = tokenizer.tokenizer.encode(context)
src_len = len(ctx)
src_ctx = ctx.copy()

print("\nYour prompt has " + str(src_len) + " tokens.")
print(
    "\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n"
)

# time_slot = {}
# time_ref = time.time_ns()

# def record_time(name):
#     if name not in time_slot:
#         time_slot[name] = 1e20
#     tt = (time.time_ns() - time_ref) / 1e9
#     if tt < time_slot[name]:
#         time_slot[name] = tt

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    # time_ref = time.time_ns()
    print(("-" * 50) + context, end="")

    ctx = src_ctx.copy()
    model.clear()

    if TRIAL == 0:
        # first trial: feed the prompt through the RNN token by token, then cache the hidden state
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[: i + 1]
            if i == src_len - 1:
                init_state.out = model.forward(x)
            else:
                model.forward(x, preprocess_only=True)
        model.save(init_state)
    else:
        # later trials: restore the cached state instead of re-processing the prompt
        model.load(init_state)

    # record_time('model_pre')

    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        # time_ref = time.time_ns()
        x = ctx[: i + 1]
        x = x[-ctx_len:]

        if i == src_len:
            out = copy.deepcopy(init_state.out)
        else:
            out = model.forward(x)
        # record_time('model_run')

        if DEBUG_DEBUG:
            print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
        if TOKEN_MODE == "pile":
            out[0] = -999999999  # disable <|endoftext|>

        time_ref = time.time_ns()

        char = tokenizer.sample_logits(
            out,
            x,
            ctx_len,
            temperature=TEMPERATURE,
            top_p_usual=top_p,
            top_p_newline=top_p_newline,
        )

        if tokenizer.charMode:
            print(tokenizer.itos[char], end="", flush=True)
        else:
            print(tokenizer.tokenizer.decode(char), end="", flush=True)
        ctx += [char]
        # record_time('model_sampling')

    print()

# print(f'\n\n{time_slot}\n\n')
# print(
#     f"\n--- preprocess {round((t_mid - t_begin) / (10 ** 9), 2)}s, generation {round((t_end - t_mid) / (10 ** 9), 2)}s", end = ''
# )
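
########################################################################################################
# Reference: a minimal sketch of generic top-p (nucleus) sampling, to illustrate the kind of step
# performed by tokenizer.sample_logits() above -- softmax the logits, keep the smallest set of tokens
# whose probabilities reach top_p, rescale by temperature, then sample from what remains.
# This is NOT the implementation in src/utils.py; the helper name _top_p_sample_sketch is hypothetical,
# and the function is defined here for illustration only (it is never called).
########################################################################################################

def _top_p_sample_sketch(logits, temperature=1.0, top_p=0.8):
    logits = np.asarray(logits, dtype=np.float64).ravel()
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()  # softmax over the vocabulary

    sorted_probs = np.sort(probs)[::-1]  # probabilities in descending order
    cumulative = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative >= top_p)])
    probs[probs < cutoff] = 0.0  # keep only the top-p "nucleus"

    if temperature != 1.0:
        probs = probs ** (1.0 / temperature)  # sharpen / flatten the kept distribution
    probs /= probs.sum()
    return int(np.random.choice(len(probs), p=probs))

# Hypothetical usage, mirroring the arguments used in the loop above:
#   token = _top_p_sample_sketch(out.cpu().numpy(), temperature=TEMPERATURE, top_p=top_p)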