main
BlinkDL 3 years ago
parent 8e99ac1138
commit 6d4dec7288

@ -42,9 +42,10 @@ class MyDataset(Dataset):
assert self.samples_per_epoch == 40320 assert self.samples_per_epoch == 40320
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
dataset_slot = self.data_size // args.ctx_len dataset_slot = self.data_size // args.ctx_len
assert MaybeIsPrime(args.magic_prime) if args.my_pile_stage != 4:
assert args.magic_prime % 3 == 2 assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1 assert args.magic_prime % 3 == 2
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
elif args.data_type == "numpy": elif args.data_type == "numpy":
self.data = np.load(args.data_file).astype("int") self.data = np.load(args.data_file).astype("int")
self.vocab_size = args.vocab_size self.vocab_size = args.vocab_size

@ -2,7 +2,7 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
######################################################################################################## ########################################################################################################
import os, math, gc import os, math, gc, importlib
import torch import torch
# torch._C._jit_set_profiling_executor(True) # torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True) # torch._C._jit_set_profiling_mode(True)
@ -11,8 +11,9 @@ from torch.nn import functional as F
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed if importlib.util.find_spec('deepspeed'):
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam

@ -111,11 +111,12 @@ if __name__ == "__main__":
######################################################################################################## ########################################################################################################
import os, warnings, math, datetime, sys, time import os, warnings, math, datetime, sys, time, importlib
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import deepspeed if "deepspeed" in args.strategy:
import deepspeed
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
@ -223,7 +224,7 @@ if __name__ == "__main__":
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
# #
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer # Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions) # Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer # Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
# #
############################################################################ ############################################################################

Loading…
Cancel
Save