|
|
|
|
@ -111,11 +111,12 @@ if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
import os, warnings, math, datetime, sys, time
|
|
|
|
|
import os, warnings, math, datetime, sys, time, importlib
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
import deepspeed
|
|
|
|
|
if "deepspeed" in args.strategy:
|
|
|
|
|
import deepspeed
|
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
|
from pytorch_lightning import seed_everything
|
|
|
|
|
|
|
|
|
|
@ -223,7 +224,7 @@ if __name__ == "__main__":
|
|
|
|
|
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
|
|
|
|
|
#
|
|
|
|
|
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
|
|
|
|
|
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
|
|
|
|
|
# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
|
|
|
|
|
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
|
|
|
|
|
#
|
|
|
|
|
############################################################################
|
|
|
|
|
|