diff --git a/RWKV-v4/run.py b/RWKV-v4/run.py
index ef2ea47..cd557c7 100644
--- a/RWKV-v4/run.py
+++ b/RWKV-v4/run.py
@@ -10,7 +10,6 @@ import copy
 import torch
 from torch.nn import functional as F
 from src.utils import TOKENIZER, Dataset
-from src.model_run import RWKV_RNN
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -18,27 +17,36 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
 ### Step 1: set model ##################################################################################
 
-os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
+os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' or 'fp16'
 
 ctx_len = 1024
 n_layer = 6
 n_embd = 512
-model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
+model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
 
-# your trained model
-MODEL_NAME = 'trained-1'
-WORD_NAME = 'vocab' # the .json vocab (generated by train.py
+### Step 2: set vocab & context ########################################################################
 
-# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
-# --> all unknown tokens in your context will be denoted by it <--
-UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
+CHAR_MODE = True # True False
 
-RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
-DEBUG_DEBUG = False # True False - show softmax output
+if CHAR_MODE:
+    ### example 1: char-level model
+    MODEL_NAME = 'trained-500' # your trained model
+    WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
+    # --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
+    # --> all unknown tokens in your context will be denoted by it <--
+    UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
+    context = "\nIn the" # your prompt
+else:
+    ### example 2: BPE-level model
+    MODEL_NAME = 'trained-7773'
+    WORD_NAME = ['model-vocab.json', 'model-merges.txt'] # [vocab, merge]
+    UNKNOWN_CHAR = None
+    context = 'A'
 
-### Step 2: set context ################################################################################
+### Step 3: other config ###############################################################################
 
-context = "\nIn the" # ==> this is your prompt
+RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
+DEBUG_DEBUG = False # True False - show softmax output
 
 NUM_TRIALS = 999
 LENGTH_PER_TRIAL = 500
@@ -50,6 +58,7 @@ top_p_newline = 0.9
 ########################################################################################################
 
 print(f'Loading {MODEL_NAME}...')
+from src.model_run import RWKV_RNN
 model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
 tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
 
@@ -63,7 +72,10 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
     t_begin = time.time_ns()
 
     src_len = len(context)
-    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
+    if tokenizer.charMode:
+        ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
+    else:
+        ctx = tokenizer.tokenizer.encode(context)
     print(('-' * 30) + context, end='')
 
     model.clear()
@@ -94,7 +106,10 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
         char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, top_p_usual=top_p, top_p_newline=top_p_newline)
         char = char.item()
-        print(tokenizer.itos[int(char)], end='', flush=True)
+        if tokenizer.charMode:
+            print(tokenizer.itos[int(char)], end='', flush=True)
+        else:
+            print(tokenizer.tokenizer.decode(int(char)), end='', flush=True)
         ctx += [char]
 
     t_end = time.time_ns()
     print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')
diff --git a/RWKV-v4/src/binidx.py b/RWKV-v4/src/binidx.py
new file mode 100644
index 0000000..43fefaa
--- /dev/null
+++ b/RWKV-v4/src/binidx.py
@@ -0,0 +1,216 @@
+from lib2to3.pgen2 import token
+import os
+import torch
+import numpy as np
+import shutil
+import struct
+from functools import lru_cache
+from itertools import accumulate
+
+def print_rank_0(*message):
+    """If distributed is initialized print only on rank 0."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 0:
+            print(*message, flush=True)
+    else:
+        print(*message, flush=True)
+
+def _warmup_mmap_file(path):
+    pass
+    # with open(path, "rb") as stream:
+    #     while stream.read(100 * 1024 * 1024):
+    #         pass
+
+dtypes = {
+    1: np.uint8,
+    2: np.int8,
+    3: np.int16,
+    4: np.int32,
+    5: np.int64,
+    6: np.float,
+    7: np.double,
+    8: np.uint16,
+}
+
+def code(dtype):
+    for k in dtypes.keys():
+        if dtypes[k] == dtype:
+            return k
+    raise ValueError(dtype)
+
+def index_file_path(prefix_path):
+    return prefix_path + ".idx"
+
+def data_file_path(prefix_path):
+    return prefix_path + ".bin"
+
+class MMapIndexedDataset(torch.utils.data.Dataset):
+    class Index(object):
+        _HDR_MAGIC = b"MMIDIDX\x00\x00"
+
+        def __init__(self, path, skip_warmup=False):
+            with open(path, "rb") as stream:
+                magic_test = stream.read(9)
+                assert self._HDR_MAGIC == magic_test, (
+                    "Index file doesn't match expected format. "
+                    "Make sure that --dataset-impl is configured properly."
+                )
+                # Little endian unsigned 64 Bit integer
+                version = struct.unpack("
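
The .idx header that Index.__init__ begins parsing has a fixed layout: a 9-byte magic b"MMIDIDX\x00\x00" followed by a little-endian uint64 version. A minimal sketch of probing that header, assuming the Megatron-LM/fairseq binidx layout this file appears to be ported from (the dtype-code byte after the version is an assumption, not visible in the hunk above):

import struct

def probe_idx_header(path):
    # Read only the fixed-size header fields of a binidx .idx file.
    with open(path, "rb") as stream:
        assert stream.read(9) == b"MMIDIDX\x00\x00", "not a binidx index file"
        (version,) = struct.unpack("<Q", stream.read(8))    # little-endian uint64
        (dtype_code,) = struct.unpack("<B", stream.read(1)) # assumed: a key into the dtypes table
        return version, dtype_code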