+eval code for 27M ppl 1.65 BPC 0.72 enwik8 model

main 2.00
BlinkDL 4 years ago
parent 71538e44a9
commit 88e921bf10

run.py

@@ -4,16 +4,18 @@
 ########################################################################################################

 import numpy as np
+import math
 import time
 import types
 import copy
 import torch
 from torch.nn import functional as F
-from src.utils import TOKENIZER
+from src.utils import TOKENIZER, Dataset
 from src.model_run import RWKV_RNN
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
+np.set_printoptions(precision=4, suppress=True, linewidth=200)

 ### Step 1: set model ##################################################################################
@@ -26,9 +28,11 @@ model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
 MODEL_NAME = 'trained-31'
 WORD_NAME = 'vocab' # the .json vocab (generated by train.py

-# ### uncompress enwik8-model.zip to test my enwik8 model
+# ########## Uncomment these to test my 27M params enwik8 model ##########
 # MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
 # WORD_NAME = 'enwik8-vocab'
+# EVAL_DATA = 'enwik8' # uncomment this for EVAL MODE (no text generation)
+# ########################################################################

 # --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
 # --> all unknown tokens in your context will be denoted by it <--
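Note on the new EVAL_DATA switch: turning on EVAL MODE is just a matter of uncommenting those lines. A sketch of the resulting settings in run.py, assuming the unzipped enwik8 checkpoint, vocab and data file sit in the working directory (an assumption, not stated in the diff):

MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
WORD_NAME = 'enwik8-vocab'
EVAL_DATA = 'enwik8'  # defining EVAL_DATA makes run.py run the evaluation loop below instead of generating text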
@@ -50,16 +54,44 @@ top_p_newline = 0.9
 ########################################################################################################

-np.set_printoptions(precision=4, suppress=True, linewidth=200)
+print(f'Loading {MODEL_NAME}...')
+model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
 tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

+########################################################################################################
+
+if 'EVAL_DATA' in vars() or 'EVAL_DATA' in globals():
+    print('Evaluating on ' + EVAL_DATA + ' ...')
+
+    data = open(EVAL_DATA, "r", encoding='utf-8').read()
+    loss_table = np.zeros(ctx_len)
+    N_SAMPLE = 1000
+
+    for iii in range(N_SAMPLE):
+        pos = np.random.randint(0, len(data) - ctx_len - 1)
+        context = data[pos:pos + ctx_len + 1]
+        ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
+
+        model.clear()
+        for i in range(1, ctx_len + 1):
+            x = ctx[:i]
+            out = model.run(x)
+            prob = F.softmax(torch.tensor(out), dim=-1)
+            loss_table[i - 1] += -math.log(prob[ctx[i]])
+
+        print(f'Tested {iii+1} samples: avg_loss over ctx_len =',
+              np.mean(loss_table) / (iii + 1))
+
+    exit(0)
+
+########################################################################################################
+
 context = tokenizer.refine_context(context)
 print('\nYour prompt has ' + str(len(context)) + ' tokens.')
 print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')

-print(f'Loading {MODEL_NAME}...')
-model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)

 for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
     t_begin = time.time_ns()
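The avg_loss printed by this eval loop is the mean cross-entropy per character in nats, averaged over every position of the context window and over all samples tested so far. The perplexity and bits-per-character figures in the commit title follow from it directly; a minimal conversion sketch (the helper name is mine, not part of the commit):

import math

def nats_to_ppl_bpc(avg_loss_nats):
    # perplexity = exp(loss); bits per character = loss / ln(2)
    return math.exp(avg_loss_nats), avg_loss_nats / math.log(2)

# the reported ppl 1.65 corresponds to ~0.50 nats per character, i.e. ~0.72 BPC on enwik8
print(nats_to_ppl_bpc(math.log(1.65)))  # -> (1.65, 0.722...)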

src/trainer.py

@@ -151,7 +151,7 @@ class Trainer:
                     self.avg_loss = self.avg_loss * \
                         (1.0 - factor) + now_loss * factor
                     pbar.set_description(
-                        f"epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
+                        f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")

         self.tokens = 0  # counter used for learning rate decay
         for epoch in range(config.max_epochs):

src/utils.py

@@ -10,6 +10,48 @@ import numpy as np
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torch.utils.data import Dataset
+
+
+class Dataset(Dataset):
+    def __init__(self, data, ctx_len, epoch_length_fixed):
+        print('building token list...', end=' ')
+        unique = sorted(list(set(data)))
+        # print()
+        # for u in unique:
+        #     print(u, end=' ')
+        # print('\n\n')
+        xx = 0
+        xxObj = {}
+        for u in unique:
+            xxObj[xx] = u
+            xx += 1
+        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
+            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
+        data_size, vocab_size = len(data), len(unique)
+        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
+        self.stoi = {ch: i for i, ch in enumerate(unique)}
+        self.itos = {i: ch for i, ch in enumerate(unique)}
+        self.ctx_len = ctx_len
+        self.epoch_length_fixed = epoch_length_fixed
+        self.vocab_size = vocab_size
+        self.data = data
+
+    def __len__(self):
+        return self.epoch_length_fixed
+
+    def __getitem__(self, idx):
+        # cheat: pick a random spot in dataset
+        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
+        chunk = self.data[i:i + self.ctx_len + 1]
+        dix = [self.stoi[s] for s in chunk]
+        x = torch.tensor(dix[:-1], dtype=torch.long,
+                         device=torch.device('cuda'))
+        y = torch.tensor(dix[1:], dtype=torch.long,
+                         device=torch.device('cuda'))
+        return x, y
+
+
 class TOKENIZER():
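Because the Dataset moved into src/utils.py returns CUDA tensors straight from __getitem__ and already picks a random offset per item, it can be consumed with a plain DataLoader. A rough usage sketch under those assumptions (the file name and hyper-parameters below are illustrative, not taken from this diff, and this is not the repo's trainer code):

from torch.utils.data import DataLoader
# from src.utils import Dataset  # the class defined in the hunk above

train_dataset = Dataset(open('enwik8', 'r', encoding='utf-8').read(),
                        ctx_len=1024, epoch_length_fixed=10000)
# num_workers must stay 0 and pin_memory off: the samples are already CUDA tensors.
loader = DataLoader(train_dataset, batch_size=12, shuffle=False, num_workers=0)
x, y = next(iter(loader))  # (12, 1024) LongTensors on GPU; y is x shifted by one character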

train.py

@@ -7,7 +7,7 @@ import datetime
 import json
 from src.model import GPT, GPTConfig
 from src.trainer import Trainer, TrainerConfig
-from torch.utils.data import Dataset
+from src.utils import Dataset
 import torch
 import numpy as np
 torch.backends.cudnn.benchmark = True
@@ -36,13 +36,13 @@ model_type = 'RWKV'
 # If you see "CUDA out of memory", reduce it. Use GPU-Z to find the highest value for your VRAM.
 batch_size = 12

-### Step 4: set learning rate, training 'epochs' #######################################################
+### Step 4: set learning rate, training mini-epochs #######################################################
 lr_init = 6e-4
 lr_final = 1e-5

-# the 'epoch' here is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
+# the mini-epoch is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
 n_epoch = 500
-# 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
+# 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, etc.
 epoch_save_frequency = 30
 epoch_save_path = 'trained-'
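As a quick sanity check on the renamed term, the comment above fixes a mini-epoch at ctx_len * epoch_length_fixed tokens; with the ctx_len used for the enwik8 model (1024, an assumption not shown in this hunk) that works out as:

ctx_len = 1024              # assumed, matching the 1024-ctx enwik8 model above
epoch_length_fixed = 10000  # from train.py
n_epoch = 500               # from train.py
tokens_per_mini_epoch = ctx_len * epoch_length_fixed   # 10,240,000 tokens
total_tokens = tokens_per_mini_epoch * n_epoch          # ~5.1B tokens over a full run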
@@ -50,7 +50,6 @@ epoch_length_fixed = 10000
 ########################################################################################################

 # import src.utils
 # src.utils.set_seed(42) # remember to change seed if you load a model
@@ -71,50 +70,8 @@ num_workers = 0
 ########################################################################################################

 print('loading data... ' + datafile)
+train_dataset = Dataset(open(
+    datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)
-
-class Dataset(Dataset):
-    def __init__(self, data, ctx_len):
-        print('building token list...', end=' ')
-        unique = sorted(list(set(data)))
-        # print()
-        # for u in unique:
-        #     print(u, end=' ')
-        # print('\n\n')
-        xx = 0
-        xxObj = {}
-        for u in unique:
-            xxObj[xx] = u
-            xx += 1
-        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
-            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
-        data_size, vocab_size = len(data), len(unique)
-        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
-        self.stoi = {ch: i for i, ch in enumerate(unique)}
-        self.itos = {i: ch for i, ch in enumerate(unique)}
-        self.ctx_len = ctx_len
-        self.vocab_size = vocab_size
-        self.data = data
-
-    def __len__(self):
-        return epoch_length_fixed
-
-    def __getitem__(self, idx):
-        # cheat: pick a random spot in dataset
-        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
-        chunk = self.data[i:i + self.ctx_len + 1]
-        dix = [self.stoi[s] for s in chunk]
-        x = torch.tensor(dix[:-1], dtype=torch.long,
-                         device=torch.device('cuda'))
-        y = torch.tensor(dix[1:], dtype=torch.long,
-                         device=torch.device('cuda'))
-        return x, y
-
-
-train_dataset = Dataset(
-    open(datafile, "r", encoding=datafile_encoding).read(), ctx_len)

 ########################################################################################################
 # Train model
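One connection worth spelling out: the vocab.json that Dataset.__init__ writes during training is the same file run.py later reads through TOKENIZER (WORD_NAME = 'vocab'). A rough sketch of reading it back, assuming the UTF-16 encoding used above (this is not the repo's TOKENIZER code):

import json

# rebuild the index->char and char->index maps from the vocab.json written by Dataset.__init__
with open('vocab.json', 'r', encoding='utf-16') as f:
    itos = {int(k): v for k, v in json.load(f).items()}
stoi = {ch: i for i, ch in itos.items()}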
