From 75929cbbbae38114dc96607aafad5741507d07f4 Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Sun, 15 Jan 2023 14:29:06 +0000 Subject: [PATCH] torch jit (xx% faster inference) --- RWKV-v4neo/src/model_run.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/RWKV-v4neo/src/model_run.py b/RWKV-v4neo/src/model_run.py index 0e0291c..479db5e 100644 --- a/RWKV-v4neo/src/model_run.py +++ b/RWKV-v4neo/src/model_run.py @@ -9,14 +9,19 @@ from torch.nn import functional as F import torch.nn as nn from typing import List, Dict -# try: -# import torchdynamo -# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output -# except: +MyModule = nn.Module def __nop(ob): return ob MyFunction = __nop +# # try torchdynamo +# import torchdynamo +# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output + +# try torch jit --> faster!! +MyModule = torch.jit.ScriptModule +MyFunction = torch.jit.script_method + RWKV_HEAD_QK_DIM = 0 print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') @@ -26,7 +31,7 @@ RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer ############################################################################################################ -class RWKV_RNN(nn.Module): +class RWKV_RNN(MyModule): def __init__(self, args): super().__init__() @@ -113,7 +118,7 @@ class RWKV_RNN(nn.Module): # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp @MyFunction - def FF(self, x, state, i, time_mix_k, time_mix_r, kw, vw, rw): + def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw): if self.FLOAT_MODE == "bf16": xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k) xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r) @@ -134,7 +139,7 @@ class RWKV_RNN(nn.Module): return r * kv @MyFunction - def SA(self, x, state, i, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): + def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): if self.FLOAT_MODE == "bf16": xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k) xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)