supports megatron bin+idx format

main
BlinkDL 3 years ago
parent aa67870849
commit f79137b524

@@ -10,7 +10,6 @@ import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -18,27 +17,36 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
### Step 1: set model ##################################################################################
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' or 'fp16'
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
# your trained model
MODEL_NAME = 'trained-1'
WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
### Step 2: set vocab & context ########################################################################
# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
CHAR_MODE = True # True False
RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # True False - show softmax output
if CHAR_MODE:
    ### example 1: char-level model
    MODEL_NAME = 'trained-500' # your trained model
    WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
    # --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
    # --> all unknown tokens in your context will be denoted by it <--
    UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
    context = "\nIn the" # your prompt
else:
    ### example 2: BPE-level model
    MODEL_NAME = 'trained-7773'
    WORD_NAME = ['model-vocab.json', 'model-merges.txt'] # [vocab, merge]
    UNKNOWN_CHAR = None
    context = 'A'
### Step 2: set context ################################################################################
### Step 3: other config ###############################################################################
context = "\nIn the" # ==> this is your prompt
RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # True False - show softmax output
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500
@@ -50,6 +58,7 @@ top_p_newline = 0.9
########################################################################################################
print(f'Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
@@ -63,7 +72,10 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    t_begin = time.time_ns()
    src_len = len(context)
    if tokenizer.charMode:
        ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
    else:
        ctx = tokenizer.tokenizer.encode(context)
    print(('-' * 30) + context, end='')
    model.clear()
@@ -94,7 +106,10 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
        char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                       top_p_usual=top_p, top_p_newline=top_p_newline)
        char = char.item()
        if tokenizer.charMode:
            print(tokenizer.itos[int(char)], end='', flush=True)
        else:
            print(tokenizer.tokenizer.decode(int(char)), end='', flush=True)
        ctx += [char]
    t_end = time.time_ns()
    print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')

@@ -0,0 +1,216 @@
import os
import torch
import numpy as np
import shutil
import struct
from functools import lru_cache
from itertools import accumulate
def print_rank_0(*message):
    """If distributed is initialized print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(*message, flush=True)
    else:
        print(*message, flush=True)
def _warmup_mmap_file(path):
    pass
    # with open(path, "rb") as stream:
    #     while stream.read(100 * 1024 * 1024):
    #         pass
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float64,  # np.float is a deprecated alias for the builtin float (64-bit)
    7: np.double,
    8: np.uint16,
}
def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)
def index_file_path(prefix_path):
    return prefix_path + ".idx"

def data_file_path(prefix_path):
    return prefix_path + ".bin"
class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b"MMIDIDX\x00\x00"
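        # on-disk .idx layout, exactly as parsed in __init__ below:
        #   9 bytes  magic "MMIDIDX\x00\x00"
        #   8 bytes  version (<Q, must be 1)
        #   1 byte   dtype code (<B, key into the dtypes table above)
        #   8 bytes  sequence count (<Q)
        #   8 bytes  document count (<Q)
        # then: int32 sizes[len], int64 pointers[len], int64 doc_idx[doc_count]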
        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                # Little endian unsigned 64 Bit integer
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version
                # Little endian unsigned 8 Bit integer
                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize
                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print_rank_0("    warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print_rank_0("    reading sizes...")
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            print_rank_0("    reading pointers...")
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )
            print_rank_0("    reading document index...")
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len
    def __init__(self, path, skip_warmup=False):
        super().__init__()
        self._path = None
        self._index = None
        self._bin_buffer = None
        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state, skip_warmup=True)  # state is the path; skip warmup on unpickle

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print_rank_0("    warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print_rank_0("    creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        print_rank_0("    creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
            )
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError(
                    "Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
            )
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
        )
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(
            data_file_path(path)
        )
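
A minimal usage sketch for the reader above; the dataset prefix is hypothetical, any Megatron-preprocessed .bin/.idx pair works:

from src.binidx import MMapIndexedDataset

data = MMapIndexedDataset('./my-gpt_seq_document')  # opens prefix.bin + prefix.idx
print(len(data), 'sequences')                       # sequence count stored in the .idx
print(data[0][:16])                                 # first 16 token ids of sequence 0
print(data.get(idx=0, offset=8, length=4))          # 4 tokens starting at token 8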

@@ -12,7 +12,7 @@ from deepspeed.ops.adam import FusedAdam
logger = logging.getLogger(__name__)
RWKV_HEAD_QK_DIM = 256
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
########################################################################################################

@@ -9,7 +9,7 @@ import math, os
from torch.nn import functional as F
import torch.nn as nn
RWKV_HEAD_QK_DIM = 256
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
DEBUG_TIME = False # True False - show trained time-coeffs

@@ -17,6 +16,16 @@ from torch.utils.data import Dataset
class Dataset(Dataset):
    def __init__(self, data, ctx_len, epoch_length_fixed):
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.data = data

        if 'MMapIndexedDataset' in str(type(self.data)):
            self.vocab_size = 253 # your vocab_size
            print('current vocab size = ', self.vocab_size, "(make sure it's correct)")
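            # the // 2 below assumes 2 bytes per token, i.e. the .bin stores
            # uint16 ids (dtype code 8 in src/binidx.py); check it matches your data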
            self.data_size = len(self.data._bin_buffer) // 2
            self.item_cnt = len(self.data)
        else:
            print('building token list...', end=' ')
            unique = sorted(list(set(data)))
            # print()
@@ -36,19 +46,21 @@ class Dataset(Dataset):
            print('data has %d tokens, %d unique.' % (data_size, vocab_size))
            self.stoi = {ch: i for i, ch in enumerate(unique)}
            self.itos = {i: ch for i, ch in enumerate(unique)}
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
            self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return self.epoch_length_fixed // NUM_GPUS
    def __getitem__(self, idx):
        # cheat: pick a random spot in dataset
        if 'MMapIndexedDataset' in str(type(self.data)):
            i = np.random.randint(0, self.data_size - (self.ctx_len + 1))
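            # idx=0 below: the whole .bin is treated as one long token stream,
            # and a random (ctx_len + 1)-token window is cut from it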
            dix = self.data.get(idx=0, offset=i, length=self.ctx_len + 1).astype(int)
        else:
            i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
            chunk = self.data[i:i+self.ctx_len+1]
            dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
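
For intuition, a minimal standalone sketch of the shift-by-one sampling above; the toy array and sizes are illustrative, not from the repo:

import numpy as np
import torch

data = np.arange(100)                    # stand-in token stream
ctx_len = 8
i = np.random.randint(0, len(data) - (ctx_len + 1))
chunk = data[i:i + ctx_len + 1]          # ctx_len + 1 consecutive tokens
x = torch.tensor(chunk[:-1], dtype=torch.long)  # model input
y = torch.tensor(chunk[1:], dtype=torch.long)   # target: input shifted by one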
@@ -56,6 +68,12 @@ class Dataset(Dataset):
class TOKENIZER():
    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        if 'list' in str(type(WORD_NAME)):
            self.charMode = False
            from transformers import GPT2TokenizerFast
            self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
        else:
            self.charMode = True
            with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)
@@ -84,10 +102,13 @@ class TOKENIZER():
        probs = F.softmax(torch.tensor(out), dim=-1)

        if self.charMode:
            if self.itos[lastChar] == '\n':
                top_p = top_p_newline
            else:
                top_p = top_p_usual
        else:
            top_p = top_p_usual

        sorted_probs, s_index = torch.sort(probs, descending=True)
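
The remainder of sample_logits is not shown in this hunk. For context, a hedged sketch of how top-p truncation typically proceeds from sorted_probs; an illustration, not this file's verbatim code:

import numpy as np
import torch
import torch.nn.functional as F

def sample_top_p(out, top_p=0.8, temperature=1.0):
    # hypothetical standalone version of the nucleus-sampling step
    probs = F.softmax(torch.tensor(out), dim=-1)
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1).numpy()
    cutoff = float(sorted_probs[np.argmax(cumulative > top_p)])  # smallest kept prob
    probs[probs < cutoff] = 0                  # zero out the low-probability tail
    if temperature != 1.0:
        probs = probs.pow(1.0 / temperature)   # reshape before sampling
    return torch.multinomial(probs, num_samples=1)[0]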

@@ -35,6 +35,7 @@ import logging, types
from src.utils import Dataset
import torch
import numpy as np
from src.binidx import MMapIndexedDataset # for the Megatron-LM 'binidx' format
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -46,8 +47,10 @@ torch.backends.cuda.matmul.allow_tf32 = True
### Step 1: set training data ##########################################################################
datafile = "../data/enwik8" # your data
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'
datafile_encoding = 'utf-8' # 'utf-8' 'utf-16le' 'binidx'
# datafile = './my-gpt_seq_document'
# datafile_encoding = 'binidx'
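# for reference, such a .bin/.idx pair is typically produced with Megatron-LM's
# tools/preprocess_data.py using --dataset-impl mmap, roughly as below
# (flag names are from Megatron-LM's README; verify against your checkout):
# python tools/preprocess_data.py --input corpus.jsonl --output-prefix my-gpt \
#     --vocab gpt2-vocab.json --merge-file gpt2-merges.txt \
#     --dataset-impl mmap --tokenizer-type GPT2BPETokenizer --append-eod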
### Step 2: set model size #############################################################################
@@ -65,7 +68,7 @@ model_type = 'RWKV'
### Step 3: set batch size #############################################################################
# if you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU.
batch_size = 12
batch_size = 12 * NUM_GPUS
assert (batch_size % NUM_GPUS == 0)
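# with the default this keeps 12 samples per GPU; the assert above guarantees
# the global batch divides evenly across NUM_GPUS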
### Step 4: set learning rate, number of mini-epochs #######################################################
@@ -109,8 +112,11 @@ num_workers = 1 # DataLoader worker. I only tested num_workers = 1
########################################################################################################
print('loading data... ' + datafile)
train_dataset = Dataset(open(
if datafile_encoding != 'binidx':
    train_dataset = Dataset(open(
        datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)
else:
    train_dataset = Dataset(MMapIndexedDataset(datafile), ctx_len, epoch_length_fixed)
########################################################################################################
# Train model
