diff --git a/RWKV-v2-RNN/enwik8-vocab.json b/RWKV-v2-RNN/enwik8-vocab.json
new file mode 100644
index 0000000..181cd09
Binary files /dev/null and b/RWKV-v2-RNN/enwik8-vocab.json differ
diff --git a/RWKV-v2-RNN/run.py b/RWKV-v2-RNN/run.py
index 6a1c462..a9bfb59 100644
--- a/RWKV-v2-RNN/run.py
+++ b/RWKV-v2-RNN/run.py
@@ -6,12 +6,13 @@ import numpy as np
 import types
 import copy
+import time
 import torch
 from torch.nn import functional as F
 from src.utils import TOKENIZER
 
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
-torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cuda.matmul.allow_tf32 = True
 
 ### Step 1: set model ##################################################################################
@@ -20,8 +21,9 @@ n_layer = 6
 n_embd = 512
 model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
 
-MODEL_NAME = 'trained-31' # your trained model
-WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
+# your trained model
+MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
+WORD_NAME = 'enwik8-vocab' # the .json vocab (generated by train.py)
 
 # --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
 # --> unknown tokens in your context will be denoted by it <--
@@ -44,12 +46,13 @@ top_p_newline = 0.9
 
 ########################################################################################################
 
-
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
 
 context = tokenizer.refine_context(context)
-print('Your context has ' + str(len(context)) + ' tokens')
+print('\nYour prompt has ' + str(len(context)) + ' tokens.')
+print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
+
 print(f'Loading {MODEL_NAME}...')
@@ -65,7 +68,7 @@ class RWKV_RNN():
         self.w = types.SimpleNamespace()
 
         w = torch.load(MODEL_NAME + '.pth',
-                       map_location=torch.device(RUN_DEVICE))  # .state_dict()
+                       map_location=torch.device(RUN_DEVICE))
         for x in w.keys():
             if '.time_' in x:
                 w[x] = w[x].squeeze()
@@ -195,12 +198,12 @@ class RWKV_RNN():
 
 model = RWKV_RNN(MODEL_NAME)
 
-print('\n\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
-
 for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
+    t_begin = time.time_ns()
+    src_len = len(context)
     ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
-    print(context.replace('\n', '\n '), end='')
+    print(('-' * 30) + context, end='')
 
     model.clear()
     if TRIAL == 0:
@@ -230,7 +233,7 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
         char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                        top_p_usual=top_p, top_p_newline=top_p_newline)
         char = char.item()
-        print(tokenizer.itos[int(char)].replace(
-            '\n', '\n '), end='', flush=True)
+        print(tokenizer.itos[int(char)], end='', flush=True)
         ctx += [char]
-    print('\n' + '-' * 40, end='')
+    t_end = time.time_ns()
+    print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')
diff --git a/RWKV-v2-RNN/train.py b/RWKV-v2-RNN/train.py
index 64291a8..0644ecd 100644
--- a/RWKV-v2-RNN/train.py
+++ b/RWKV-v2-RNN/train.py
@@ -137,5 +137,5 @@ if __name__ == '__main__':
 
     trainer.train()
 
-    torch.save(model, 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
+    torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
                '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
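Note on the train.py hunk: the checkpoint now stores model.state_dict() (a plain name-to-tensor dict) rather than the pickled module, which is presumably also why the "# .state_dict()" comment was dropped from the torch.load call in run.py -- the loaded object is already a state dict whose keys RWKV_RNN.__init__ can walk directly. A minimal sketch of the difference between the two save styles; the Toy module and the .pth file names here are hypothetical illustrations, not from the repo:

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the trained model (the real class lives in src/model.py).
    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(256, 16)

    model = Toy()

    # Old behavior: pickle the whole module. Loading then requires the original
    # class definition to be importable under the same module path.
    torch.save(model, 'whole-model.pth')

    # New behavior (as in the train.py hunk above): save only the tensor dict.
    torch.save(model.state_dict(), 'state-dict.pth')

    # Loading a state dict yields a plain {name: tensor} mapping, which matches
    # how RWKV_RNN.__init__ in run.py iterates w.keys() without needing the class.
    w = torch.load('state-dict.pth', map_location=torch.device('cpu'))
    for k in w.keys():
        print(k, tuple(w[k].shape))

One consequence worth noting: trained-*.pth files saved by the old code are whole-module pickles, not state dicts, so they would need to be re-saved (or loaded with the old code) before the updated run.py can consume them.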