diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index 44e4789..3a0eb3b 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -42,9 +42,10 @@ class MyDataset(Dataset): assert self.samples_per_epoch == 40320 rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") dataset_slot = self.data_size // args.ctx_len - assert MaybeIsPrime(args.magic_prime) - assert args.magic_prime % 3 == 2 - assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1 + if args.my_pile_stage != 4: + assert MaybeIsPrime(args.magic_prime) + assert args.magic_prime % 3 == 2 + assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1 elif args.data_type == "numpy": self.data = np.load(args.data_file).astype("int") self.vocab_size = args.vocab_size diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py index 111d052..b79f96d 100644 --- a/RWKV-v4neo/src/model.py +++ b/RWKV-v4neo/src/model.py @@ -2,7 +2,7 @@ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM ######################################################################################################## -import os, math, gc +import os, math, gc, importlib import torch # torch._C._jit_set_profiling_executor(True) # torch._C._jit_set_profiling_mode(True) @@ -11,8 +11,9 @@ from torch.nn import functional as F import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from pytorch_lightning.strategies import DeepSpeedStrategy -import deepspeed -from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam +if importlib.util.find_spec('deepspeed'): + import deepspeed + from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index 0460879..eede9b1 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -111,11 +111,12 @@ if __name__ == "__main__": ######################################################################################################## - import os, warnings, math, datetime, sys, time + import os, warnings, math, datetime, sys, time, importlib import numpy as np import torch from torch.utils.data import DataLoader - import deepspeed + if "deepspeed" in args.strategy: + import deepspeed import pytorch_lightning as pl from pytorch_lightning import seed_everything @@ -223,7 +224,7 @@ if __name__ == "__main__": # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} # # Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer -# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions) +# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions) # Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer # ############################################################################