diff --git a/src/trainer.py b/src/trainer.py
index 56af8e1..5f88fcc 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -8,8 +8,8 @@ from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data.dataloader import DataLoader

 logger = logging.getLogger(__name__)
-print('logging to wandb... (comment it if you don\'t have wandb)')
-import wandb # comment it if you don't have wandb
+# print('logging to wandb... (comment it if you don\'t have wandb)')
+# import wandb # comment this if you don't have wandb

 class TrainerConfig:
     max_epochs = 10
@@ -22,7 +22,8 @@ class TrainerConfig:
     lr_decay = False # linear warmup followed by cosine decay
     warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper
     final_tokens = 260e9 # at which point do we reach lr_final
-    ckpt_path = None
+    epoch_save_frequency = 0
+    epoch_save_path = 'trained-'
     num_workers = 0 # for DataLoader

     def __init__(self, **kwargs):
@@ -56,11 +57,6 @@ class Trainer:
         run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
         return run_name

-    def save_checkpoint(self): # DataParallel wrappers keep raw model object in .module attribute
-        raw_model = self.model.module if hasattr(self.model, "module") else self.model
-        logger.info("saving %s", self.config.ckpt_path)
-        torch.save(raw_model.state_dict(), self.config.ckpt_path)
-
     def train(self):
         model, config = self.model, self.config
         raw_model = model.module if hasattr(self.model, "module") else model
@@ -77,12 +73,11 @@
             pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)

             for it, (x, y) in pbar:
-
                 x = x.to(self.device) # place data on the correct device
                 y = y.to(self.device)

                 with torch.set_grad_enabled(is_train):
-                    logits, loss = model(x, y) # forward the model
+                    _, loss = model(x, y) # forward the model
                     loss = loss.mean() # collapse all losses if they are scattered on multiple gpus

                 if is_train: # backprop and update the parameters
@@ -94,14 +89,15 @@

                     if config.lr_decay: # decay the learning rate based on our progress
                         self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
+                        lr_final_factor = config.lr_final / config.learning_rate
                         if self.tokens < config.warmup_tokens:
                             # linear warmup
-                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
+                            lr_mult = lr_final_factor + (1 - lr_final_factor) * float(self.tokens) / float(config.warmup_tokens)
                             progress = 0
                         else:
                             # cosine learning rate decay
                             progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
-                            lr_final_factor = config.lr_final / config.learning_rate
+                            # progress = min(progress * 1.1, 1.0) # more fine-tuning with low LR
                             lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1
                         lr = config.learning_rate * lr_mult
                         for param_group in optimizer.param_groups:
@@ -118,20 +114,17 @@

                     if self.avg_loss < 0:
                         self.avg_loss = now_loss
                     else:
-                        factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1))
+                        # factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1))
+                        factor = 1 / (it + 1)
                         self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
                     pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")

-        best_loss = float('inf')
-        self.tokens = 0 # counter used for learning rate decay
-        for epoch in range(config.max_epochs):
+        while True:
+            self.tokens = 0 # counter used for learning rate decay
+            for epoch in range(config.max_epochs):

-            run_epoch('train')
-            if self.test_dataset is not None:
-                test_loss = run_epoch('test')
-
-            # supports early stopping based on the test loss, or just save always if no test set is provided
-            good_model = self.test_dataset is None or test_loss < best_loss
-            if self.config.ckpt_path is not None and good_model:
-                best_loss = test_loss
-                self.save_checkpoint()
+                run_epoch('train')
+
+                if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
+                    raw_model = self.model.module if hasattr(self.model, "module") else self.model # DataParallel wrappers keep raw model object in .module
+                    torch.save(raw_model, self.config.epoch_save_path + str(epoch+1) + '.pth')
diff --git a/train.py b/train.py
index 3e71ef1..61947dd 100644
--- a/train.py
+++ b/train.py
@@ -25,11 +25,24 @@ model_type = 'RWKV'
 datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"
 datafile_encoding = 'utf-8'
 # datafile = u"D:\\NLP-Data\\ww100M.txt"
+# datafile = u"D:\\NLP-Data\\__2019.txt"
 # datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
+# datafile = u"V:\\NLP\\enwik8-shift-300.bpe"
 # datafile_encoding = 'utf-16'
+# datafile = u"V:\\NLP\\simplebooks-shift-utf32.word"
+# datafile_encoding = 'utf-32'

 datafile_type = 0 # use 0 for char-level english. use 1 for chinese. only affects some RWKV hyperparametrs

+#################################### VERY IMPORTANT ####################################
+epoch_save_frequency = 10 # 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
+epoch_save_path = 'trained-'
+
+batch_size = 48 # if you see "CUDA out of memory", reduce this.
+                # if you have good GPU, increase this.
+                # use GPU-Z to find the highest value for your VRAM.
+########################################################################################
+
 model_level = 'character' # 'character' (recommended) or 'word'

 ctx_len = 256 # context length
@@ -39,11 +52,9 @@ n_embd = n_head * 64
 n_attn = n_embd
 n_ffn = n_embd

-batch_size = 64
-
 n_epoch = 50 # the 'epoch' here is actually very short (and of fixed length)
 lr_init = 8e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr
-lr_final = 2e-4
+lr_final = 1e-5

 betas = (0.9, 0.999) if model_type == 'RWKV' else (0.9, 0.99)
 eps = 1e-8
@@ -55,6 +66,7 @@ epoch_length_fixed = 10000 # make an 'epoch' very short
 rwkv_emb_scale = 0.4 # scale of initial embedding. 0.4 is a good choice
 rwkv_tiny_attn = 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
 rwkv_tiny_head = 1 # 1 is good enough. 8 is slow
+# n_side_proj = 512 # extra 'side projection', quite useful for BPE models

 ########################################################################################################
 # Load data
@@ -76,6 +88,15 @@ class Dataset(Dataset):
         # for u in unique:
         #     print(u, end=' ')
         # print('\n\n')
+
+        xx = 0
+        xxObj = {}
+        for u in unique:
+            xxObj[xx] = u
+            xx += 1
+        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
+            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
+
         data_size, vocab_size = len(data), len(unique)
         print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
         self.stoi = { ch:i for i,ch in enumerate(unique) }
@@ -108,7 +129,7 @@ model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_typ
 print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn)

 tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay,
                       learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
-                      warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0)
+                      warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
 trainer = Trainer(model, train_dataset, None, tconf)
 trainer.train()
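Note (not part of the patch): the behavioural core of the trainer change is the new learning-rate schedule, which now warms up from lr_final to lr_init and then follows a half-cosine back down to lr_final instead of decaying by a fixed cosine to near zero. Below is a minimal standalone sketch of that schedule for reference; the helper name lr_schedule is illustrative only, and the parameter defaults are the config values visible in the patch (train.py overrides warmup_tokens with 0 and computes final_tokens from the dataset at runtime).

import math

def lr_schedule(tokens, learning_rate=8e-4, lr_final=1e-5, warmup_tokens=0, final_tokens=260e9):
    # mirrors the logic added in src/trainer.py above
    lr_final_factor = lr_final / learning_rate
    if tokens < warmup_tokens:
        # linear warmup from lr_final up to learning_rate
        lr_mult = lr_final_factor + (1 - lr_final_factor) * float(tokens) / float(warmup_tokens)
    else:
        # half-cosine decay from learning_rate down to lr_final
        progress = float(tokens - warmup_tokens) / float(max(1, final_tokens - warmup_tokens))
        lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress)
    return learning_rate * lr_mult

if __name__ == '__main__':
    for t in (0, 65e9, 130e9, 195e9, 260e9):
        print(f"{t:.1e} tokens -> lr {lr_schedule(t):.2e}")

With warmup_tokens=0 the warmup branch is never taken, so the multiplier starts at 1.0 (lr = lr_init) and reaches lr_final/lr_init exactly at final_tokens.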
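Also not part of the patch: the new epoch-save path pickles the whole model object (torch.save(raw_model, ...)) rather than a state_dict, so reloading a checkpoint later is a plain torch.load and requires the model's source (src.model) to be importable. A minimal sketch, with the filename chosen as an example of the epoch_save_path + str(epoch+1) + '.pth' pattern:

import torch

# 'trained-50.pth' is an example name produced by the save logic above
model = torch.load('trained-50.pth', map_location='cpu')
model.eval()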