no message

main
BlinkDL 4 years ago
parent 0b6aec3da6
commit 5817d265c3

Binary file not shown.

@ -6,6 +6,7 @@
import numpy as np
import types
import copy
import time
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER
@ -20,8 +21,9 @@ n_layer = 6
n_embd = 512
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
MODEL_NAME = 'trained-31' # your trained model
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
# your trained model
MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
WORD_NAME = 'enwik8-vocab' # the .json vocab (generated by train.py)
# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> unknown tokens in your context will be denoted by it <--
@ -44,12 +46,13 @@ top_p_newline = 0.9
########################################################################################################
np.set_printoptions(precision=4, suppress=True, linewidth=200)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
context = tokenizer.refine_context(context)
print('Your context has ' + str(len(context)) + ' tokens')
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
print(f'Loading {MODEL_NAME}...')
##############################################################################################################
@ -65,7 +68,7 @@ class RWKV_RNN():
self.w = types.SimpleNamespace()
w = torch.load(MODEL_NAME + '.pth',
map_location=torch.device(RUN_DEVICE)) # .state_dict()
map_location=torch.device(RUN_DEVICE))
for x in w.keys():
if '.time_' in x:
w[x] = w[x].squeeze()
@ -195,12 +198,12 @@ class RWKV_RNN():
model = RWKV_RNN(MODEL_NAME)
print('\n\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
t_begin = time.time_ns()
src_len = len(context)
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(context.replace('\n', '\n '), end='')
print(('-' * 30) + context, end='')
model.clear()
if TRIAL == 0:
@ -230,7 +233,7 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
top_p_usual=top_p, top_p_newline=top_p_newline)
char = char.item()
print(tokenizer.itos[int(char)].replace(
'\n', '\n '), end='', flush=True)
print(tokenizer.itos[int(char)], end='', flush=True)
ctx += [char]
print('\n' + '-' * 40, end='')
t_end = time.time_ns()
print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')

@ -137,5 +137,5 @@ if __name__ == '__main__':
trainer.train()
torch.save(model, 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
'-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')

Loading…
Cancel
Save