|
|
|
@ -2,7 +2,7 @@
|
|
|
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
|
|
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
|
|
|
########################################################################################################
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
import json, math, random
|
|
|
|
import json, math, random, os, sys
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
@ -16,21 +16,31 @@ class MyDataset(Dataset):
|
|
|
|
self.args = args
|
|
|
|
self.args = args
|
|
|
|
|
|
|
|
|
|
|
|
if args.data_type == "binidx":
|
|
|
|
if args.data_type == "binidx":
|
|
|
|
self.data = MMapIndexedDataset(args.data_file)
|
|
|
|
|
|
|
|
self.vocab_size = args.vocab_size
|
|
|
|
self.vocab_size = args.vocab_size
|
|
|
|
print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
|
|
|
print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if args.data_file.endswith('/'):
|
|
|
|
|
|
|
|
d_all = []
|
|
|
|
|
|
|
|
for p in os.listdir(args.data_file):
|
|
|
|
|
|
|
|
if p.endswith(".idx"):
|
|
|
|
|
|
|
|
d_all += [p[:-4]]
|
|
|
|
|
|
|
|
d_all.sort()
|
|
|
|
|
|
|
|
print(d_all)
|
|
|
|
|
|
|
|
exit(0)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.data = MMapIndexedDataset(args.data_file)
|
|
|
|
self.data_size = len(self.data._bin_buffer) // 2
|
|
|
|
self.data_size = len(self.data._bin_buffer) // 2
|
|
|
|
print(f"Data has {self.data_size} tokens.")
|
|
|
|
print(f"Data has {self.data_size} tokens.")
|
|
|
|
|
|
|
|
|
|
|
|
if args.my_pile_stage > 0:
|
|
|
|
if args.my_pile_stage > 0:
|
|
|
|
assert self.data_size == 332115325534 and self.vocab_size == 50277
|
|
|
|
# assert self.data_size == 332115325534 and self.vocab_size == 50277
|
|
|
|
self.samples_per_epoch = args.epoch_steps * args.real_bsz
|
|
|
|
self.samples_per_epoch = args.epoch_steps * args.real_bsz
|
|
|
|
assert self.samples_per_epoch == 40320
|
|
|
|
assert self.samples_per_epoch == 40320
|
|
|
|
print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
|
|
|
|
print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
|
|
|
|
dataset_slot = self.data_size // args.ctx_len
|
|
|
|
dataset_slot = self.data_size // args.ctx_len
|
|
|
|
assert MaybeIsPrime(args.magic_prime)
|
|
|
|
assert MaybeIsPrime(args.magic_prime)
|
|
|
|
assert args.magic_prime % 3 == 2
|
|
|
|
assert args.magic_prime % 3 == 2
|
|
|
|
assert args.magic_prime / dataset_slot > 0.999999 and args.magic_prime / dataset_slot <= 1
|
|
|
|
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
|
|
|
|
elif args.data_type == "numpy":
|
|
|
|
elif args.data_type == "numpy":
|
|
|
|
self.data = np.load(args.data_file).astype("int")
|
|
|
|
self.data = np.load(args.data_file).astype("int")
|
|
|
|
self.vocab_size = args.vocab_size
|
|
|
|
self.vocab_size = args.vocab_size
|
|
|
|
|