You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
70 lines
2.8 KiB
Python
70 lines
2.8 KiB
Python
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
|
|
|
|
import json
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from pytorch_lightning.utilities import rank_zero_info
|
|
from .binidx import MMapIndexedDataset
|
|
|
|
|
|
class MyDataset(Dataset):
    """Random-sampling training dataset over one of three storage formats.

    The payload is selected by ``args.data_type``:
      * ``"binidx"`` — a memory-mapped ``MMapIndexedDataset`` token file;
      * ``"numpy"``  — an integer array loaded via ``np.load``;
      * anything else — a plain text file, where ``args.data_type`` itself is
        used as the text *encoding* and tokenization is character-level.

    ``__getitem__`` ignores its index and draws a uniformly random window of
    ``args.ctx_len + 1`` tokens, returning the usual shifted (x, y) pair for
    next-token prediction.
    """

    def __init__(self, args):
        self.args = args

        if args.data_type == "binidx":
            self.data = MMapIndexedDataset(args.data_file)
            self.vocab_size = args.vocab_size
            print("current vocab size =", self.vocab_size, "(make sure it's correct)")
            # NOTE(review): assumes 2 bytes per token in the mmap buffer —
            # confirm this matches the dtype used by the binidx writer.
            self.data_size = len(self.data._bin_buffer) // 2
            print(f"data has {self.data_size} tokens.")
        elif args.data_type == "numpy":
            self.data = np.load(args.data_file).astype("int")
            self.vocab_size = args.vocab_size
            print("current vocab size =", self.vocab_size, "(make sure it's correct)")
            self.data_size = len(self.data)
            print(f"data has {self.data_size} tokens.")
        else:
            # Character-level mode: args.data_type doubles as the file encoding.
            self.data = open(args.data_file, "r", encoding=args.data_type).read()
            print("building token list...", end=" ")
            unique = sorted(list(set(self.data)))
            self.vocab_size = len(unique)
            # Persist the id -> char table so inference can invert the mapping.
            xx = 0
            xxObj = {}
            for u in unique:
                xxObj[xx] = u
                xx += 1
            with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
                vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
            self.data_size = len(self.data)
            print("data has %d tokens, %d unique." % (self.data_size, self.vocab_size))
            self.stoi = {ch: i for i, ch in enumerate(unique)}
            self.itos = {i: ch for i, ch in enumerate(unique)}

    def __len__(self):
        # An "epoch" is a fixed number of optimizer steps, not a full pass over
        # the corpus: steps * device count * micro-batch samples per epoch.
        return self.args.epoch_steps * int(self.args.devices) * self.args.micro_bsz

    def __getitem__(self, idx):
        """Return a random (x, y) next-token pair of length ``args.ctx_len``.

        ``idx`` is deliberately ignored ("we are cheating"): a random offset
        into the corpus is drawn instead, so each fetch is an independent
        uniform sample.
        """
        args = self.args
        ctx_len = args.ctx_len
        req_len = ctx_len + 1  # one extra token for the shifted target
        # NOTE(review): uses NumPy's *global* RNG — DataLoader workers forked
        # with identical seeds will draw identical offsets; verify worker
        # seeding upstream.
        i = np.random.randint(0, self.data_size - req_len)

        # Dispatch on the declared data type (the same condition __init__ used
        # to build self.data) rather than string-matching str(type(self.data)),
        # which is fragile reflection on type names.
        if args.data_type == "binidx":
            dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
        elif args.data_type == "numpy":
            dix = self.data[i : i + req_len]
        else:
            dix = [self.stoi[s] for s in self.data[i : i + req_len]]

        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y