diff --git a/RWKV-v4neo/run.py b/RWKV-v4neo/run.py index d53d819..240fe1c 100644 --- a/RWKV-v4neo/run.py +++ b/RWKV-v4neo/run.py @@ -21,10 +21,12 @@ args = types.SimpleNamespace() # Do this first: pip install torchdynamo ######################################################################################################## -args.RUN_DEVICE = "cpu" # 'cpu' (already very fast) // 'cuda' -args.FLOAT_MODE = "fp32" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate) +args.RUN_DEVICE = "cpu" # 'cpu' (already fast) // 'cuda' +args.FLOAT_MODE = "fp32" # fp32 (good for CPU) // fp16 (good for GPU, does not work for CPU) // bf16 (less accurate, but works for CPU) + # if args.RUN_DEVICE == "cuda": # os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output +os.environ["RWKV_JIT_ON"] = '1' # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!! TOKEN_MODE = "pile" WORD_NAME = [ @@ -85,7 +87,7 @@ context = "\nIn a shocking finding, scientist discovered a herd of dragons livin # context = "\n深圳是" # test Chinese # context = "\n東京は" # test Japanese -###### A good prompt for chatbot ###### +# ###### A good prompt for chatbot ###### # context = ''' # The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answer User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins. diff --git a/RWKV-v4neo/src/model_run.py b/RWKV-v4neo/src/model_run.py index 479db5e..2516e50 100644 --- a/RWKV-v4neo/src/model_run.py +++ b/RWKV-v4neo/src/model_run.py @@ -18,12 +18,13 @@ MyFunction = __nop # import torchdynamo # MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output -# try torch jit --> faster!! -MyModule = torch.jit.ScriptModule -MyFunction = torch.jit.script_method +# try torch jit --> faster for fp32, slower for fp16 (why?) +if os.environ["RWKV_JIT_ON"] == "1": + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method RWKV_HEAD_QK_DIM = 0 -print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') +print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n') DEBUG_TIME = False # True False - show trained time-coeffs