main
BlinkDL 3 years ago
parent 935d8d3e87
commit cf340264dc

@ -49,6 +49,58 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b"MMIDIDX\x00\x00"
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, "wb")
# Write Magic string so we can check the file format then opening it again.
self._file.write(cls._HDR_MAGIC)
# Write version number
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", 1))
# Little endian unsigned 8 Bit integer
self._file.write(struct.pack("<B", code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", len(sizes)))
# Little endian unsigned 64 Bit integer
self._file.write(struct.pack("<Q", len(doc_idx)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order="C"))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order="C"))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order="C"))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, "rb") as stream:
magic_test = stream.read(9)

@ -2,7 +2,7 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json, math, random
import json, math, random, os, sys
import numpy as np
import torch
from torch.utils.data import Dataset
@ -16,21 +16,31 @@ class MyDataset(Dataset):
self.args = args
if args.data_type == "binidx":
self.data = MMapIndexedDataset(args.data_file)
self.vocab_size = args.vocab_size
print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
if args.data_file.endswith('/'):
d_all = []
for p in os.listdir(args.data_file):
if p.endswith(".idx"):
d_all += [p[:-4]]
d_all.sort()
print(d_all)
exit(0)
else:
self.data = MMapIndexedDataset(args.data_file)
self.data_size = len(self.data._bin_buffer) // 2
print(f"Data has {self.data_size} tokens.")
if args.my_pile_stage > 0:
assert self.data_size == 332115325534 and self.vocab_size == 50277
# assert self.data_size == 332115325534 and self.vocab_size == 50277
self.samples_per_epoch = args.epoch_steps * args.real_bsz
assert self.samples_per_epoch == 40320
print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
dataset_slot = self.data_size // args.ctx_len
assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime % 3 == 2
assert args.magic_prime / dataset_slot > 0.999999 and args.magic_prime / dataset_slot <= 1
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
elif args.data_type == "numpy":
self.data = np.load(args.data_file).astype("int")
self.vocab_size = args.vocab_size

@ -97,6 +97,14 @@ class train_callback(pl.Callback):
if kt_s > 0:
lll["kt/s"] = kt_s
trainer.my_wandb.log(lll, step=int(real_step))
if args.magic_prime > 0:
if int(real_step) == int(args.magic_prime // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict()
torch.save(
to_save_dict,
f"{args.proj_dir}/rwkv-final.pth",
)
def on_train_epoch_start(self, trainer, pl_module):
args = self.args

@ -99,6 +99,7 @@ if __name__ == "__main__":
parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
@ -145,6 +146,7 @@ if __name__ == "__main__":
os.makedirs(args.proj_dir)
if args.my_pile_stage > 0:
magic_prime_bak = args.magic_prime
if args.ctx_len == 1024:
args.magic_prime = 324331313
args.epoch_count = 8043
@ -162,6 +164,9 @@ if __name__ == "__main__":
elif args.ctx_len == 4096:
args.my_pile_shift = 768
if magic_prime_bak > 0:
args.magic_prime = magic_prime_bak
args.epoch_steps = 40320 // args.real_bsz
assert args.epoch_steps * args.real_bsz == 40320
if args.my_pile_stage == 2:

Loading…
Cancel
Save