diff --git a/RWKV-v3/src/trainer.py b/RWKV-v3/src/trainer.py
index 3a1dbc6..418d72e 100644
--- a/RWKV-v3/src/trainer.py
+++ b/RWKV-v3/src/trainer.py
@@ -125,11 +125,12 @@ class Trainer:
                         float(config.warmup_tokens)
                     progress = 0
                 else:
-                    # cosine learning rate decay
-                    progress = float(self.tokens - config.warmup_tokens) / float(
-                        max(1, config.final_tokens - config.warmup_tokens))
-                    lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor /
-                                                             2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1
+                    # exponential learning rate decay
+                    progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
+                    if progress >= 1:
+                        lr_mult = lr_final_factor
+                    else:
+                        lr_mult = math.exp(math.log(lr_final_factor) * pow(progress, 1))
                 lr = config.learning_rate * lr_mult
                 for param_group in optimizer.param_groups:
                     param_group['lr'] = lr
diff --git a/RWKV-v3/train.py b/RWKV-v3/train.py
index 6f1b26d..a8f42dc 100644
--- a/RWKV-v3/train.py
+++ b/RWKV-v3/train.py
@@ -4,11 +4,11 @@
 
 import os
 
-if True: # True False ---> Set to False if you don't understand it
-    print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
-    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-    import src.utils
-    src.utils.set_seed(42) # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples)
+# if False: # True False ---> Set to False if you don't understand it
+#     print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
+#     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+#     import src.utils
+#     src.utils.set_seed(42) # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples)
 
 import logging
 import datetime
@@ -53,6 +53,16 @@ model_type = 'RWKV'
 batch_size = 12
 
 ### Step 4: set learning rate, number of mini-epochs #######################################################
+# By default we are using exponential LR decay.
+#
+# Here are my suggestions for training a good model.
+# Let's say you will train an L6-D512 model.
+# 1) Set lr_init = lr_final = 8e-4. Let it run for some mini-epochs, until the improvement in loss becomes slow.
+# 2) Ctrl+C to stop the run.
+# 3) Set lr_init = 8e-4, lr_final = 1e-5, warmup_tokens = ctx_len * batch_size * 50, betas = (0.9, 0.999).
+# 4) Search for "torch.load" here and modify it to load the partially-trained model. Continue the training.
+#
+# For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4.
 
 lr_init = 8e-4 # we can use larger lr because of preLN
 lr_final = 1e-5
@@ -68,7 +78,7 @@ epoch_save_path = 'trained-'
 ########################################################################################################
 
 grad_norm_clip = 1.0
-warmup_tokens = 0
+warmup_tokens = ctx_len * batch_size * 0
 
 betas = (0.9, 0.99)
 eps = 4e-9
@@ -91,7 +101,7 @@ if __name__ == '__main__':
     model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
                 n_layer=n_layer, n_embd=n_embd)).cuda()
 
-    ### load a trained model
+    ### ---> load a trained model <---
     # m2 = torch.load('trained-61.pth')
     # model.load_state_dict(m2)
 
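
Note on the new schedule (not part of the patch): the snippet below is a minimal standalone sketch of the decay curve introduced in src/trainer.py, so it can be inspected or plotted outside the Trainer. The helper name lr_schedule is invented for illustration, the definition lr_final_factor = lr_final / lr_init is an assumption (the patch uses lr_final_factor but does not show how the Trainer derives it), and the warmup branch is simplified to a plain linear ramp rather than copying the Trainer's warmup code.

import math

def lr_schedule(tokens, lr_init, lr_final, warmup_tokens, final_tokens):
    # Hypothetical helper, not part of the Trainer class.
    # Assumption: lr_final_factor = lr_final / lr_init.
    lr_final_factor = lr_final / lr_init
    if tokens < warmup_tokens:
        # simplified linear warmup, for illustration only
        return lr_init * float(tokens) / float(max(1, warmup_tokens))
    progress = float(tokens - warmup_tokens) / float(max(1, final_tokens - warmup_tokens))
    if progress >= 1:
        lr_mult = lr_final_factor
    else:
        # exp(log(f) * p) == f ** p, i.e. the multiplier moves geometrically
        # from 1.0 (progress = 0) down to lr_final_factor (progress = 1)
        lr_mult = math.exp(math.log(lr_final_factor) * progress)
    return lr_init * lr_mult

With the defaults in train.py (lr_init = 8e-4, lr_final = 1e-5), this curve is already down to the geometric mean sqrt(8e-4 * 1e-5) ≈ 8.9e-5 at the halfway point of the decay, whereas the removed cosine schedule would still be near lr_init / 2 there; the run therefore spends most of its token budget at comparatively low learning rates.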