From dc26998708a12660f5648ae6664718b2cae56312 Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Sun, 15 Jan 2023 14:41:49 +0000 Subject: [PATCH] torch jit --- RWKV-v4neo/run.py | 8 +++++--- RWKV-v4neo/src/model_run.py | 9 +++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/RWKV-v4neo/run.py b/RWKV-v4neo/run.py index d53d819..240fe1c 100644 --- a/RWKV-v4neo/run.py +++ b/RWKV-v4neo/run.py @@ -21,10 +21,12 @@ args = types.SimpleNamespace() # Do this first: pip install torchdynamo ######################################################################################################## -args.RUN_DEVICE = "cpu" # 'cpu' (already very fast) // 'cuda' -args.FLOAT_MODE = "fp32" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate) +args.RUN_DEVICE = "cpu" # 'cpu' (already fast) // 'cuda' +args.FLOAT_MODE = "fp32" # fp32 (good for CPU) // fp16 (good for GPU, does not work for CPU) // bf16 (less accurate, but works for CPU) + # if args.RUN_DEVICE == "cuda": # os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output +os.environ["RWKV_JIT_ON"] = '1' # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!! TOKEN_MODE = "pile" WORD_NAME = [ @@ -85,7 +87,7 @@ context = "\nIn a shocking finding, scientist discovered a herd of dragons livin # context = "\n深圳是" # test Chinese # context = "\n東京は" # test Japanese -###### A good prompt for chatbot ###### +# ###### A good prompt for chatbot ###### # context = ''' # The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answer User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins. diff --git a/RWKV-v4neo/src/model_run.py b/RWKV-v4neo/src/model_run.py index 479db5e..2516e50 100644 --- a/RWKV-v4neo/src/model_run.py +++ b/RWKV-v4neo/src/model_run.py @@ -18,12 +18,13 @@ MyFunction = __nop # import torchdynamo # MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output -# try torch jit --> faster!! -MyModule = torch.jit.ScriptModule -MyFunction = torch.jit.script_method +# try torch jit --> faster for fp32, slower for fp16 (why?) +if os.environ["RWKV_JIT_ON"] == "1": + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method RWKV_HEAD_QK_DIM = 0 -print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') +print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n') DEBUG_TIME = False # True False - show trained time-coeffs