main
BlinkDL 3 years ago
parent 8e99ac1138
commit 6d4dec7288

@ -42,6 +42,7 @@ class MyDataset(Dataset):
assert self.samples_per_epoch == 40320 assert self.samples_per_epoch == 40320
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
dataset_slot = self.data_size // args.ctx_len dataset_slot = self.data_size // args.ctx_len
if args.my_pile_stage != 4:
assert MaybeIsPrime(args.magic_prime) assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime % 3 == 2 assert args.magic_prime % 3 == 2
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1 assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1

@ -2,7 +2,7 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
######################################################################################################## ########################################################################################################
-import os, math, gc
+import os, math, gc, importlib
import torch import torch
# torch._C._jit_set_profiling_executor(True) # torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True) # torch._C._jit_set_profiling_mode(True)
@ -11,6 +11,7 @@ from torch.nn import functional as F
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy from pytorch_lightning.strategies import DeepSpeedStrategy
if importlib.util.find_spec('deepspeed'):
import deepspeed import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

@ -111,10 +111,11 @@ if __name__ == "__main__":
######################################################################################################## ########################################################################################################
-import os, warnings, math, datetime, sys, time
+import os, warnings, math, datetime, sys, time, importlib
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
if "deepspeed" in args.strategy:
import deepspeed import deepspeed
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
@ -223,7 +224,7 @@ if __name__ == "__main__":
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
# #
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer # Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
-# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
+# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer # Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
# #
############################################################################ ############################################################################

Loading…
Cancel
Save