saves vocab.json, and the model every X epochs

main · BlinkDL · 4 years ago · parent 689a6a924d · commit 76e241b71e

@@ -8,8 +8,8 @@ from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
logger = logging.getLogger(__name__)
print('logging to wandb... (comment it if you don\'t have wandb)')
import wandb # comment it if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')
# import wandb # comment this if you don't have wandb
class TrainerConfig:
max_epochs = 10
@@ -22,7 +22,8 @@ class TrainerConfig:
lr_decay = False # linear warmup followed by cosine decay
warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper
final_tokens = 260e9 # at which point do we reach lr_final
ckpt_path = None
epoch_save_frequency = 0
epoch_save_path = 'trained-'
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
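
The hunk above swaps the single ckpt_path for epoch_save_frequency / epoch_save_path. For reference, the new fields slot into the usual minGPT-style config pattern, where any keyword argument overrides a class-level default. A minimal sketch (the __init__ body is not shown in this hunk, so the setattr loop below is an assumption):

class TrainerConfigSketch:
    # class-level defaults, mirroring the fields in the hunk above
    max_epochs = 10
    epoch_save_frequency = 0      # 0 = never save mid-run
    epoch_save_path = 'trained-'  # filename prefix for the per-epoch .pth files

    def __init__(self, **kwargs):
        # assumed body: every keyword argument simply overrides the matching default
        for k, v in kwargs.items():
            setattr(self, k, v)

cfg = TrainerConfigSketch(max_epochs=50, epoch_save_frequency=10)
print(cfg.epoch_save_frequency, cfg.epoch_save_path)  # 10 trained-
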
@@ -56,11 +57,6 @@ class Trainer:
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def save_checkpoint(self): # DataParallel wrappers keep raw model object in .module attribute
raw_model = self.model.module if hasattr(self.model, "module") else self.model
logger.info("saving %s", self.config.ckpt_path)
torch.save(raw_model.state_dict(), self.config.ckpt_path)
def train(self):
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
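
The removed save_checkpoint method and the raw_model line above rely on the same idiom: nn.DataParallel keeps the user's model in its .module attribute, so anything that saves or introspects the model should unwrap it first. A small self-contained sketch of that idiom (the helper name is made up here):

import torch.nn as nn

def unwrap(model: nn.Module) -> nn.Module:
    # nn.DataParallel (and DistributedDataParallel) expose the wrapped model
    # as .module; a bare model has no such attribute, so return it unchanged.
    return model.module if hasattr(model, "module") else model

net = nn.Linear(4, 4)
assert unwrap(nn.DataParallel(net)) is net   # wrapped case
assert unwrap(net) is net                    # plain case
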
@@ -77,12 +73,11 @@
pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
for it, (x, y) in pbar:
x = x.to(self.device) # place data on the correct device
y = y.to(self.device)
with torch.set_grad_enabled(is_train):
logits, loss = model(x, y) # forward the model
_, loss = model(x, y) # forward the model
loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
if is_train: # backprop and update the parameters
@@ -94,14 +89,15 @@
if config.lr_decay: # decay the learning rate based on our progress
self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
lr_final_factor = config.lr_final / config.learning_rate
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
lr_mult = lr_final_factor + (1 - lr_final_factor) * float(self.tokens) / float(config.warmup_tokens)
progress = 0
else:
# cosine learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
lr_final_factor = config.lr_final / config.learning_rate
# progress = min(progress * 1.1, 1.0) # more fine-tuning with low LR
lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
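
The hunk above rewrites the schedule so the multiplier never drops below lr_final / learning_rate: warmup now ramps linearly from that floor up to 1.0, and the shifted cosine then decays back down to the same floor by final_tokens. Pulled out as a standalone function (the argument names are mine; the arithmetic is taken from the lines above):

import math

def lr_multiplier(tokens_seen, warmup_tokens, final_tokens, lr_init, lr_final):
    f = lr_final / lr_init                      # floor of the multiplier
    if tokens_seen < warmup_tokens:
        # linear warmup from f up to 1.0
        return f + (1 - f) * tokens_seen / max(1, warmup_tokens)
    # shifted cosine: 1.0 at the end of warmup, f at final_tokens
    progress = (tokens_seen - warmup_tokens) / max(1, final_tokens - warmup_tokens)
    return (0.5 + f / 2) + (0.5 - f / 2) * math.cos(math.pi * progress)

# with lr_init = 8e-4 and lr_final = 1e-5 (the values set further down in train.py),
# the multiplier runs from 1.0 down to 0.0125, i.e. the LR ends exactly at lr_final.
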
@@ -118,20 +114,17 @@
if self.avg_loss < 0:
self.avg_loss = now_loss
else:
factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1))
# factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1))
factor = 1 / (it + 1)
self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
best_loss = float('inf')
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
while True:
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
run_epoch('train')
if self.test_dataset is not None:
test_loss = run_epoch('test')
# supports early stopping based on the test loss, or just save always if no test set is provided
good_model = self.test_dataset is None or test_loss < best_loss
if self.config.ckpt_path is not None and good_model:
best_loss = test_loss
self.save_checkpoint()
run_epoch('train')
if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
raw_model = self.model.module if hasattr(self.model, "module") else self.model # DataParallel wrappers keep raw model object in .module
torch.save(raw_model, self.config.epoch_save_path + str(epoch+1) + '.pth')
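
Two behaviour changes land in this last trainer hunk: the smoothed loss shown in the progress bar becomes a plain cumulative average over the current epoch (factor = 1/(it+1); the factor is 1 at it = 0, so the average restarts each epoch), and checkpointing switches from best-test-loss state_dicts to saving the whole model object every epoch_save_frequency epochs plus once at the end. A compact sketch of both, reusing the names from the diff:

# 1) cumulative average of the loss within an epoch
avg_loss = -1.0
for it, now_loss in enumerate([2.5, 2.3, 2.1]):
    if avg_loss < 0:
        avg_loss = now_loss
    else:
        factor = 1 / (it + 1)
        avg_loss = avg_loss * (1 - factor) + now_loss * factor
print(avg_loss)  # 2.3, the arithmetic mean of the three values

# 2) when to write trained-<epoch+1>.pth
def should_save(epoch, max_epochs, epoch_save_frequency):
    # mirrors the condition in the diff; epoch is 0-based, filenames use epoch+1
    return ((epoch_save_frequency > 0 and epoch % epoch_save_frequency == 0)
            or epoch == max_epochs - 1)
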

@@ -25,11 +25,24 @@ model_type = 'RWKV'
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"
datafile_encoding = 'utf-8'
# datafile = u"D:\\NLP-Data\\ww100M.txt"
# datafile = u"D:\\NLP-Data\\__2019.txt"
# datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
# datafile = u"V:\\NLP\\enwik8-shift-300.bpe"
# datafile_encoding = 'utf-16'
# datafile = u"V:\\NLP\\simplebooks-shift-utf32.word"
# datafile_encoding = 'utf-32'
datafile_type = 0 # use 0 for char-level English. use 1 for Chinese. only affects some RWKV hyperparameters
#################################### VERY IMPORTANT ####################################
epoch_save_frequency = 10 # 0 = never, 1 = every 'epoch', 2 = every two 'epochs', etc.
epoch_save_path = 'trained-'
batch_size = 48 # if you see "CUDA out of memory", reduce this.
# if you have a good GPU, increase this.
# use GPU-Z to find the highest value for your VRAM.
########################################################################################
model_level = 'character' # 'character' (recommended) or 'word'
ctx_len = 256 # context length
@@ -39,11 +52,9 @@ n_embd = n_head * 64
n_attn = n_embd
n_ffn = n_embd
batch_size = 64
n_epoch = 50 # the 'epoch' here is actually very short (and of fixed length)
lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
lr_final = 2e-4
lr_final = 1e-5
betas = (0.9, 0.999) if model_type == 'RWKV' else (0.9, 0.99)
eps = 1e-8
@@ -55,6 +66,7 @@ epoch_length_fixed = 10000 # make an 'epoch' very short
rwkv_emb_scale = 0.4 # scale of initial embedding. 0.4 is a good choice
rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long-ctx char-level English
rwkv_tiny_head = 1 # 1 is good enough. 8 is slow
# n_side_proj = 512 # extra 'side projection', quite useful for BPE models
########################################################################################################
# Load data
@@ -76,6 +88,15 @@ class Dataset(Dataset):
# for u in unique:
# print(u, end=' ')
# print('\n\n')
xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
with open('vocab.json', "w", encoding="utf-16") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
data_size, vocab_size = len(data), len(unique)
print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
self.stoi = { ch:i for i,ch in enumerate(unique) }
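
The new Dataset code above is what produces vocab.json: it enumerates the distinct characters in the same order used for stoi/itos and dumps the index-to-character map as UTF-16 JSON, so inference code can rebuild the vocabulary without re-reading the training text. A self-contained sketch of the write and a possible read-back (the read-back is not part of this commit; note that json turns the integer keys into strings):

import json

unique = ['\n', ' ', 'a', 'b', 'c']             # illustrative; the real list comes from the data

itos = {i: ch for i, ch in enumerate(unique)}   # same mapping the xx / xxObj loop builds
with open('vocab.json', 'w', encoding='utf-16') as f:
    f.write(json.dumps(itos, ensure_ascii=False))

with open('vocab.json', 'r', encoding='utf-16') as f:
    itos = {int(k): v for k, v in json.loads(f.read()).items()}
stoi = {ch: i for i, ch in itos.items()}
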
@@ -108,7 +129,7 @@ model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_typ
print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0)
warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()
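
Finally, the wiring: epoch_save_frequency and epoch_save_path are threaded into TrainerConfig, and final_tokens (the point at which the cosine schedule bottoms out at lr_final) is derived from the dataset size. A rough back-of-the-envelope, assuming len(train_dataset) equals epoch_length_fixed (the fixed-length 'epochs' the comments describe); with warmup_tokens = 0 the schedule is pure cosine decay:

n_epoch, epoch_length_fixed, ctx_len = 50, 10000, 256
final_tokens = n_epoch * epoch_length_fixed * ctx_len
print(f"{final_tokens:,}")  # 128,000,000 tokens, over which the LR falls from 8e-4 to 1e-5
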
