BlinkDL committed 3 years ago (branch: main)
commit f79d082053, parent 8bf7061705

@@ -17,7 +17,7 @@ class MyDataset(Dataset):
if args.data_type == "binidx":
self.vocab_size = args.vocab_size
print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
if args.data_file.endswith('/'):
d_all = []
@@ -25,12 +25,12 @@ class MyDataset(Dataset):
if p.endswith(".idx"):
d_all += [p[:-4]]
d_all.sort()
print(d_all)
rank_zero_info(d_all)
exit(0)
else:
self.data = MMapIndexedDataset(args.data_file)
self.data_size = len(self.data._bin_buffer) // 2
print(f"Data has {self.data_size} tokens.")
rank_zero_info(f"Data has {self.data_size} tokens.")
if args.my_qa_mask == 1:
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
@@ -40,7 +40,7 @@ class MyDataset(Dataset):
# assert self.data_size == 332115325534 and self.vocab_size == 50277
self.samples_per_epoch = args.epoch_steps * args.real_bsz
assert self.samples_per_epoch == 40320
print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
dataset_slot = self.data_size // args.ctx_len
assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime % 3 == 2
@@ -48,15 +48,15 @@ class MyDataset(Dataset):
elif args.data_type == "numpy":
self.data = np.load(args.data_file).astype("int")
self.vocab_size = args.vocab_size
print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data)
print(f"Data has {self.data_size} tokens.")
rank_zero_info(f"Data has {self.data_size} tokens.")
elif args.data_type == "uint16":
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
self.vocab_size = args.vocab_size
print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = self.data.shape[0]
print(f"Data has {self.data_size} samples.")
rank_zero_info(f"Data has {self.data_size} samples.")
elif args.data_type == "wds_img":
self.vocab_size = -1
self.data_size = -1
@@ -64,7 +64,7 @@ class MyDataset(Dataset):
self.error_count = 0
else:
if args.data_type == "dummy":
print("Building dummy data...")
rank_zero_info("Building dummy data...")
self.data = ""
for i in range(100000):
aa = (i) % 10000
@@ -73,13 +73,13 @@ class MyDataset(Dataset):
self.data += f".{aa}+{bb}={cc}."
else:
self.data = open(args.data_file, "r", encoding=args.data_type).read()
print("Building token list...")
rank_zero_info("Building token list...")
unique = sorted(list(set(self.data)))
self.vocab_size = len(unique)
# print()
# rank_zero_info()
# for u in unique:
# print(u, end=' ')
# print('\n\n')
# rank_zero_info('\n\n')
xx = 0
xxObj = {}
for u in unique:
@@ -88,7 +88,7 @@ class MyDataset(Dataset):
with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
self.data_size = len(self.data)
print("Data has %d tokens, %d vocab size." % (self.data_size, self.vocab_size))
rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}

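Note on the print -> rank_zero_info change in MyDataset.__init__: under DDP every GPU process builds the dataset, so plain print() repeats each banner once per rank, while rank_zero_info only emits on global rank 0. One caveat: it forwards its arguments to Python logging rather than joining them like print, so the f-string call sites are the safe pattern. A minimal standalone sketch of the same idea (assumes pytorch_lightning is installed; the helper below is illustrative, not from the repo):

    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

    rank_zero_info(f"Data has {1000000} tokens.")  # logged once, not once per GPU

    @rank_zero_only
    def write_vocab(path, payload):
        # the same guard works for file writes, so only one process touches disk
        with open(path, "w", encoding="utf-16le") as f:
            f.write(payload)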
@@ -1,9 +1,17 @@
import os, math, time, datetime
import os, math, time, datetime, subprocess
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
def my_save(dd, ff):
if '14b-run1' not in ff:
torch.save(dd, ff)
else:
fn = ff.split('/')[-1]
fff = '/dev/shm/' + fn
torch.save(dd, fff)
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b/{fn} --quiet", shell=True)
class train_callback(pl.Callback):
def __init__(self, args):
@@ -100,7 +108,7 @@ class train_callback(pl.Callback):
if args.magic_prime > 0:
if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict()
torch.save(
my_save(
to_save_dict,
f"{args.proj_dir}/rwkv-final.pth",
)
@@ -128,7 +136,7 @@ class train_callback(pl.Callback):
else:
to_save_dict = pl_module.state_dict()
try:
torch.save(
my_save(
to_save_dict,
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
)

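Note on the new my_save() in the trainer: checkpoints for the '14b-run1' job are first written to RAM-backed /dev/shm and then handed to a background `aws s3 mv`, so the training loop is not blocked by a multi-GB upload to s3://rwkv-14b; all other runs keep the plain torch.save path. A standalone sketch of that stage-then-upload pattern (the function name and the argument-list Popen form are my additions; the original runs the same command through a shell and does not wait for it):

    import subprocess
    import torch

    def save_and_upload(state_dict, dest_path, bucket="rwkv-14b"):
        # stage the checkpoint in /dev/shm (tmpfs), then move it to S3 in a
        # background process so training can continue immediately
        fn = dest_path.split("/")[-1]
        tmp = f"/dev/shm/{fn}"
        torch.save(state_dict, tmp)
        subprocess.Popen(["aws", "s3", "mv", tmp, f"s3://{bucket}/{fn}", "--quiet"])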
@@ -5,8 +5,9 @@
if __name__ == "__main__":
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
print("########## work in progress ##########")
rank_zero_info("########## work in progress ##########")
########################################################################################################
#
@@ -101,7 +102,7 @@ if __name__ == "__main__":
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
@@ -115,7 +116,6 @@ if __name__ == "__main__":
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
@@ -138,6 +138,7 @@ if __name__ == "__main__":
args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
os.environ["RWKV_MY_TESTING"] = args.my_testing
if args.data_type == "wds_img":
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
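Note on --my_testing: it changes from an int flag to a free-form string and is exported as RWKV_MY_TESTING before the model code is imported, so modules that never see argparse can still branch on the experiment tag. A minimal sketch of that producer/consumer pattern (the tag value and function name below are illustrative assumptions, not from the repo):

    import os

    # training-script side: publish the flag before importing the model module
    os.environ["RWKV_MY_TESTING"] = "some-experiment-tag"

    # model-module side: read it back at import time
    MY_TESTING = os.environ.get("RWKV_MY_TESTING", "")

    def use_experimental_path():
        return "some-experiment-tag" in MY_TESTING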
@@ -276,11 +277,11 @@ if __name__ == "__main__":
generate_init_weight(model, init_weight_name) # save initial weights
args.load_model = init_weight_name
print(f"########## Loading {args.load_model}... ##########")
rank_zero_info(f"########## Loading {args.load_model}... ##########")
try:
load_dict = torch.load(args.load_model, map_location="cpu")
except:
print(f"Bad checkpoint {args.load_model}")
rank_zero_info(f"Bad checkpoint {args.load_model}")
if args.my_pile_stage >= 2: # try again using another checkpoint
max_p = args.my_pile_prev_p
if max_p == -1:
@@ -288,7 +289,7 @@ if __name__ == "__main__":
else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
args.epoch_begin = max_p + 1
print(f"Trying {args.load_model}")
rank_zero_info(f"Trying {args.load_model}")
load_dict = torch.load(args.load_model, map_location="cpu")
if args.load_partial == 1:
@@ -302,6 +303,16 @@ if __name__ == "__main__":
args,
callbacks=[train_callback(args)],
)
if trainer.global_rank == 0:
for n in model.state_dict():
shape = model.state_dict()[n].shape
shape = [i for i in shape if i != 1]
if len(shape) > 1:
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
else:
print(f"{str(shape[0]).ljust(5)} {n}")
if "deepspeed" in args.strategy:
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000

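Note on the new rank-0 block ahead of the deepspeed bucket setup: it prints every parameter's shape (dropping size-1 dimensions) exactly once, which gives a quick sanity check of the model layout in the training log. The same dump as a standalone sketch (the two-layer toy model is only a stand-in for the real RWKV module):

    import torch.nn as nn

    model = nn.Sequential(nn.Embedding(50277, 768), nn.Linear(768, 50277))
    for n, p in model.state_dict().items():
        shape = [i for i in p.shape if i != 1]  # squeeze out singleton dims
        if len(shape) > 1:
            print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
        else:
            print(f"{str(shape[0]).ljust(5)} {n}")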