@@ -8,8 +8,8 @@ from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
logger = logging.getLogger(__name__)
print('logging to wandb... (comment it if you don\'t have wandb)')
import wandb # comment it if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')
# import wandb # comment this if you don't have wandb
class TrainerConfig:
max_epochs = 10
@@ -22,7 +22,8 @@ class TrainerConfig:
lr_decay = False # linear warmup followed by cosine decay
warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper
final_tokens = 260e9 # at which point do we reach lr_final
ckpt_path = None
epoch_save_frequency = 0 # save a checkpoint every N epochs (0 = only after the final epoch)
epoch_save_path = 'trained-' # filename prefix for the saved checkpoints
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
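# Usage sketch (not part of the diff): __init__(**kwargs) presumably copies each keyword
# argument onto the instance (its body sits outside this hunk), so the new fields are set
# the same way as the existing ones. The values below are illustrative assumptions, not
# the project's defaults:
#
# tconf = TrainerConfig(max_epochs=100, learning_rate=6e-4, lr_final=1e-5, lr_decay=True,
#                       warmup_tokens=375e6, final_tokens=260e9,
#                       epoch_save_frequency=10, epoch_save_path='trained-')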
@@ -56,11 +57,6 @@ class Trainer:
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
return run_name
def save_checkpoint(self): # DataParallel wrappers keep raw model object in .module attribute
raw_model = self.model.module if hasattr(self.model, "module") else self.model
logger.info("saving %s", self.config.ckpt_path)
torch.save(raw_model.state_dict(), self.config.ckpt_path)
def train(self):
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
@@ -77,12 +73,11 @@ class Trainer:
pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
for it, (x, y) in pbar:
x = x.to(self.device) # place data on the correct device
y = y.to(self.device)
with torch.set_grad_enabled(is_train):
logits, loss = model(x, y) # forward the model
_, loss = model(x, y) # forward the model
loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
if is_train: # backprop and update the parameters
@@ -94,14 +89,15 @@ class Trainer:
if config.lr_decay: # decay the learning rate based on our progress
self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
lr_final_factor = config.lr_final / config.learning_rate
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
lr_mult = lr_final_factor + (1 - lr_final_factor) * float(self.tokens) / float(config.warmup_tokens)
progress = 0
else:
# cosine learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
lr_final_factor = config.lr_final / config.learning_rate
# progress = min(progress * 1.1, 1.0) # more fine-tuning with low LR
lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
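# Shape of the new schedule, as a standalone sketch derived from the two formulas above
# (lr_multiplier and its arguments are assumed names, not part of the codebase):
#
# def lr_multiplier(tokens, warmup_tokens, final_tokens, lr_final_factor):
#     if tokens < warmup_tokens:
#         # ramps linearly from lr_final_factor (i.e. lr_final) up to 1.0 (i.e. learning_rate)
#         return lr_final_factor + (1 - lr_final_factor) * tokens / warmup_tokens
#     progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
#     # cosine anneal: 1.0 at progress=0 (cos 0 = 1), lr_final_factor at progress=1 (cos pi = -1)
#     return (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress)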
@@ -118,20 +114,17 @@ class Trainer:
if self.avg_loss < 0:
self.avg_loss = now_loss
else:
factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1))
# factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1))
factor = 1 / (it + 1)
self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
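# Note on the replaced smoothing factor (derivation, not part of the diff): with
# factor = 1 / (it + 1), the update avg_loss = avg_loss * (1 - factor) + now_loss * factor
# keeps avg_loss equal to the plain arithmetic mean of the losses seen so far in the epoch
# (factor = 1 at it = 0 re-seeds it each epoch), whereas the old
# max(1.0 / 300, 1.0 / math.sqrt(it + 1)) factor behaved like an exponential moving average
# that weighted recent iterations more heavily.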
best_loss = float('inf')
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
while True:
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
run_epoch('train')
if self.test_dataset is not None:
test_loss = run_epoch('test')
# supports early stopping based on the test loss, or just save always if no test set is provided
good_model = self.test_dataset is None or test_loss < best_loss
if self.config.ckpt_path is not None and good_model:
best_loss = test_loss
self.save_checkpoint()
run_epoch('train')
if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
raw_model = self.model.module if hasattr(self.model, "module") else self.model # DataParallel wrappers keep raw model object in .module
torch.save(raw_model, self.config.epoch_save_path + str(epoch + 1) + '.pth')
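# Loading sketch (usage assumption, not part of the diff): the whole module is saved with
# torch.save(raw_model, ...) rather than a state_dict, so a checkpoint such as 'trained-100.pth'
# (an example name built from epoch_save_path plus the epoch number) can be restored directly,
# provided the model class is importable:
#
# model = torch.load('trained-100.pth', map_location='cpu')
# model.eval()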