You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
RWKV-LM/RWKV-v4neo/src/dataset.py

70 lines
2.8 KiB
Python

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from .binidx import MMapIndexedDataset
class MyDataset(Dataset):
    """Dataset that serves randomly positioned token windows from one corpus.

    The corpus is loaded according to ``args.data_type``:

    * ``"binidx"``: an ``MMapIndexedDataset`` opened from ``args.data_file``.
    * ``"numpy"``: an array read with ``np.load`` from ``args.data_file``.
    * anything else: ``args.data_file`` is read as text, using
      ``args.data_type`` as the encoding name, and every distinct character
      becomes one token. The id -> character table is also written to
      ``{args.proj_dir}/vocab.json``.

    Attributes set by ``__init__``:
        args: the configuration object passed in.
        data: the loaded corpus (indexed dataset, array, or string).
        data_size: number of tokens in ``data``.
        vocab_size: ``args.vocab_size`` for "binidx"/"numpy", otherwise the
            number of distinct characters found in the text.
        stoi, itos: character <-> id tables (text corpora only).
    """

    def __init__(self, args):
        """Load the corpus described by ``args``.

        Reads ``args.data_type`` and ``args.data_file``; additionally reads
        ``args.vocab_size`` for "binidx"/"numpy" data and ``args.proj_dir``
        for text data.
        """
        super().__init__()
        self.args = args

        if args.data_type == "binidx":
            self.data = MMapIndexedDataset(args.data_file)
            self.vocab_size = args.vocab_size
            print("current vocab size =", self.vocab_size, "(make sure it's correct)")
            # NOTE(review): halving the raw buffer length presumably assumes
            # 2 bytes per token -- confirm against the binidx writer.
            self.data_size = len(self.data._bin_buffer) // 2
            print(f"data has {self.data_size} tokens.")
        elif args.data_type == "numpy":
            self.data = np.load(args.data_file).astype("int")
            self.vocab_size = args.vocab_size
            print("current vocab size =", self.vocab_size, "(make sure it's correct)")
            self.data_size = len(self.data)
            print(f"data has {self.data_size} tokens.")
        else:
            # Any other data_type is treated as a text encoding name.
            # Use a context manager so the file handle is closed promptly.
            with open(args.data_file, "r", encoding=args.data_type) as data_file:
                self.data = data_file.read()
            print("building token list...", end=" ")
            unique = sorted(set(self.data))
            self.vocab_size = len(unique)
            self.data_size = len(self.data)
            self.stoi = {ch: i for i, ch in enumerate(unique)}
            self.itos = dict(enumerate(unique))
            # Persist the id -> character table; json.dumps turns the integer
            # keys into strings.
            with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
                vocab_file.write(json.dumps(self.itos, ensure_ascii=False))
            print("data has %d tokens, %d unique." % (self.data_size, self.vocab_size))

    def __len__(self):
        """Return the number of samples served per epoch.

        This is ``epoch_steps * devices * micro_bsz`` taken from ``args``; it
        does not depend on the size of the corpus.
        """
        return self.args.epoch_steps * int(self.args.devices) * self.args.micro_bsz

    def __getitem__(self, idx):
        """Return one ``(x, y)`` pair cut from a random position in the corpus.

        ``idx`` is ignored: each call draws a fresh start offset with
        ``np.random.randint``.

        Returns:
            Two ``torch.long`` tensors of length ``args.ctx_len``; ``y`` is
            the same window as ``x`` shifted one token to the right.

        Raises:
            ValueError: if the corpus holds fewer than ``ctx_len + 1`` tokens.
        """
        ctx_len = self.args.ctx_len
        req_len = ctx_len + 1  # one extra token so that y can be x shifted by one

        last_start = self.data_size - req_len
        if last_start < 0:
            raise ValueError(
                f"data has {self.data_size} tokens but ctx_len={ctx_len} "
                f"requires at least {req_len}"
            )
        # randint's upper bound is exclusive; the +1 keeps the final window
        # (the one ending on the last token) reachable.
        i = np.random.randint(0, last_start + 1)

        data = self.data
        if isinstance(data, np.ndarray):
            dix = data[i : i + req_len]
        elif isinstance(data, str):
            # Character-level corpus: map each character to its id.
            dix = [self.stoi[s] for s in data[i : i + req_len]]
        else:
            # Remaining storage kind produced by __init__: MMapIndexedDataset.
            dix = data.get(idx=0, offset=i, length=req_len).astype(int)

        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y