########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math

# import wandb  # comment this if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')

logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

log_file = open("mylog.txt", "a")

class TrainerConfig:
    max_epochs = 10
    batch_size = 64
    learning_rate = 4e-4
    betas = (0.9, 0.99)
    eps = 1e-8
    grad_norm_clip = 1.0
    lr_decay = True  # linear warmup followed by exponential decay to lr_final
    warmup_tokens = 0
    final_tokens = 0
    epoch_save_frequency = 0
    epoch_save_path = 'trained-'
    num_workers = 0  # for DataLoader

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

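
# NOTE: TrainerConfig defines no class-level default for lr_final, but the
# lr_decay branch in Trainer.train() reads config.lr_final, so it has to be
# passed as a keyword override (see the hypothetical usage sketch at the end
# of this file).
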
class Trainer:

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.avg_loss = -1
        self.steps = 0

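        # NOTE: the wandb branch below only runs if wandb is already in
        # sys.modules; to actually use it, uncomment the `import wandb` line at
        # the top of this file so the name `wandb` is bound in this module.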
        if 'wandb' in sys.modules:
            cfg = model.config
            for k in config.__dict__:
                setattr(cfg, k, config.__dict__[k])  # combine cfg
            wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
                       datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'),
                       config=cfg, save_code=False)

        self.device = 'cpu'
        if torch.cuda.is_available():  # take over whatever gpus are on the system
            self.device = torch.cuda.current_device()

    def get_run_name(self):
        raw_model = self.model.module if hasattr(
            self.model, "module") else self.model
        cfg = raw_model.config
        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
            cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
        return run_name

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)
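        # NOTE: configure_optimizers() is expected to be defined on the model
        # (presumably model.py in the RWKV-LM repo) and to return a torch
        # optimizer built from config.learning_rate, config.betas and config.eps.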

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset

            if config.num_workers > 0:
                loader = DataLoader(data, shuffle=False, pin_memory=True,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)
            else:
                loader = DataLoader(data, shuffle=False,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)

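            # NOTE (assumption based on the RWKV-LM repo, not this file alone):
            # shuffle stays False because the companion Dataset samples a random
            # ctx_len-sized window in __getitem__, so the loader itself does not
            # need to shuffle.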
            pbar = tqdm(enumerate(loader), total=len(
                loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)

            for it, (x, y) in pbar:
                x = x.to(self.device)  # place data on the correct device
                y = y.to(self.device)

                with torch.set_grad_enabled(is_train):
                    _, loss = model(x, y)  # forward the model

                if is_train:  # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()

                    if config.grad_norm_clip > 0:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), config.grad_norm_clip)

                    optimizer.step()

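                    # NOTE: the schedule below ramps the LR linearly from lr_final
                    # up to learning_rate over the first warmup_tokens tokens, then
                    # decays it exponentially back toward lr_final as the token
                    # count approaches final_tokens.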
                    if config.lr_decay:  # decay the learning rate based on our progress
                        # number of tokens processed this step (i.e. label is not -100)
                        self.tokens += (y >= 0).sum()
                        lr_final_factor = config.lr_final / config.learning_rate
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = lr_final_factor + \
                                (1 - lr_final_factor) * float(self.tokens) / \
                                float(config.warmup_tokens)
                            progress = 0
                        else:
                            # exponential learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / \
                                float(max(1, config.final_tokens - config.warmup_tokens))
                            if progress >= 1:
                                lr_mult = lr_final_factor
                            else:
                                lr_mult = math.exp(math.log(lr_final_factor) * pow(progress, 1))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate
                        progress = 0  # keep the progress-bar text below well-defined

                    now_loss = loss.item()  # report progress
                    self.lr = lr

                    if 'wandb' in sys.modules:
                        wandb.log({"loss": now_loss},
                                  step=self.steps * self.config.batch_size)
                    self.steps += 1

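                    # NOTE: avg_loss is the cumulative mean of the loss over the
                    # current mini-epoch (factor = 1/(it+1) makes the update a
                    # running average).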
                    if self.avg_loss < 0:
                        self.avg_loss = now_loss
                    else:
                        factor = 1 / (it + 1)
                        self.avg_loss = self.avg_loss * \
                            (1.0 - factor) + now_loss * factor
                    pbar.set_description(
                        f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")

        self.tokens = 0  # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train')

            log_file.write(
                f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
            log_file.flush()

            if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
                # DataParallel wrappers keep raw model object in .module
                raw_model = self.model.module if hasattr(
                    self.model, "module") else self.model
                torch.save(raw_model.state_dict(),
                           self.config.epoch_save_path + str(epoch+1) + '.pth')
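

# Usage sketch (a hypothetical driver loosely mirroring train.py in the RWKV-LM
# repo; GPT, GPTConfig and Dataset are assumptions, not defined in this file):
#
#   train_dataset = Dataset(open('data.txt').read(), ctx_len=1024)
#   model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len,
#                         model_type='RWKV', n_layer=6, n_embd=512)).cuda()
#   tconf = TrainerConfig(max_epochs=100, batch_size=12, learning_rate=6e-4,
#                         lr_decay=True, lr_final=1e-5, warmup_tokens=0,
#                         final_tokens=100 * len(train_dataset) * 1024,
#                         epoch_save_frequency=10, num_workers=0)
#   Trainer(model, train_dataset, None, tconf).train()
#
# The model is expected to expose .config, .configure_optimizers(config), and a
# forward(x, y) that returns (logits, loss).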