diff --git a/RWKV-v4neo/run.py b/RWKV-v4neo/run.py index 86f909a..dc43838 100644 --- a/RWKV-v4neo/run.py +++ b/RWKV-v4neo/run.py @@ -62,17 +62,17 @@ elif TOKEN_MODE == "pile": # n_embd = 1024 # ctx_len = 1024 - # MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040' - # n_layer = 24 - # n_embd = 2048 - # ctx_len = 1024 - - MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20220925-4537' - n_layer = 32 - n_embd = 2560 + MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-1B5-20220903-8040' + n_layer = 24 + n_embd = 2048 ctx_len = 1024 -os.environ["RWKV_FLOAT_MODE"] = "fp32" # currently only supprts fp32 + # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20220925-4537' + # n_layer = 32 + # n_embd = 2560 + # ctx_len = 1024 + +os.environ["RWKV_FLOAT_MODE"] = "fp32" # currently only supprts fp32 (it can do bf16 and fp16. just wait a bit... busy these days) os.environ["RWKV_RUN_DEVICE"] = "cpu" # 'cpu' (already very fast) or 'cuda' model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre' @@ -83,12 +83,13 @@ model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre' # context = 'A' # context = "\nIn the" # context = '\nSugar:' -# context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese." +context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese." -context = "\n深圳是" # test Chinese -context = "\n東京は" # test Japanese +# context = "\n深圳是" # test Chinese +# context = "\n東京は" # test Japanese -# context = ''' # A good prompt for chatbot +###### A good prompt for chatbot ###### +# context = ''' # The following is a conversation between a highly knowledgeable and intelligent AI assistant, called RWKV, and a human user, called User. In the following interactions, User and RWKV will converse in natural language, and RWKV will do its best to answer User’s questions. RWKV was built to be respectful, polite and inclusive. It knows a lot, and always tells the truth. The conversation begins. # User: OK RWKV, I’m going to start by quizzing you with a few warm-up questions. Who is currently the president of the USA? @@ -120,7 +121,7 @@ DEBUG_DEBUG = False # True False --> show softmax output ######################################################################################################## -print(f"Loading {MODEL_NAME}...") +print(f'\nUsing {os.environ["RWKV_RUN_DEVICE"].upper()}. Loading {MODEL_NAME}...') from src.model_run import RWKV_RNN model = RWKV_RNN( @@ -138,25 +139,25 @@ else: src_len = len(ctx) src_ctx = ctx.copy() -print("\nYour prompt has " + str(src_len) + " tokens.") +print("Your prompt has " + str(src_len) + " tokens.") print( - "\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n" + "\nNote: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n" ) -# time_slot = {} -# time_ref = time.time_ns() +time_slot = {} +time_ref = time.time_ns() -# def record_time(name): -# if name not in time_slot: -# time_slot[name] = 1e20 -# tt = (time.time_ns() - time_ref) / 1e9 -# if tt < time_slot[name]: -# time_slot[name] = tt +def record_time(name): + if name not in time_slot: + time_slot[name] = 1e20 + tt = (time.time_ns() - time_ref) / 1e9 + if tt < time_slot[name]: + time_slot[name] = tt for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): - # time_ref = time.time_ns() + print(("-" * 50) + '\n' + context, end="") - print(("-" * 50) + context, end="") + time_ref = time.time_ns() ctx = src_ctx.copy() model.clear() @@ -172,11 +173,9 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): else: model.load(init_state) - # record_time('model_pre') + record_time('preprocess') out_last = src_len for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)): - # time_ref = time.time_ns() - x = ctx[: i + 1] x = x[-ctx_len:] @@ -184,14 +183,12 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): out = copy.deepcopy(init_state.out) else: out = model.forward(x) - # record_time('model_run') if DEBUG_DEBUG: print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy())) if TOKEN_MODE == "pile": out[0] = -999999999 # disable <|endoftext|> - time_ref = time.time_ns() ttt = tokenizer.sample_logits( out, x, @@ -211,9 +208,8 @@ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): print(char, end="", flush=True) out_last = i+1 - # record_time('model_sampling') - print() + record_time('total') # print(f'\n\n{time_slot}\n\n') - # print( - # f"\n--- preprocess {round((t_mid - t_begin) / (10 ** 9), 2)}s, generation {round((t_end - t_mid) / (10 ** 9), 2)}s", end = '' - # ) + print( + f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = '' + )