diff --git a/RWKV-v4/train.py b/RWKV-v4/train.py index f13e7d5..6d04b74 100644 --- a/RWKV-v4/train.py +++ b/RWKV-v4/train.py @@ -97,10 +97,10 @@ MODEL_NAME = epoch_save_path + str(EPOCH_BEGIN) ######################################################################################################## -if LOAD_MODEL and EPOCH_BEGIN > 0: # we are not saving gradients. so let's have some warmup if we load a model - warmup_tokens = ctx_len * batch_size * 50 +if LOAD_MODEL and EPOCH_BEGIN > 0: # we are not saving gradients, so let's have some warmup if we load a model + warmup_tokens = 50 * ctx_len * batch_size // NUM_GPUS else: - warmup_tokens = ctx_len * batch_size * 0 + warmup_tokens = 0 betas = (0.9, 0.99) eps = 1e-8