torch jit (xx% faster inference)

main
BlinkDL 3 years ago
parent 819f2730b2
commit 75929cbbba

@ -9,14 +9,19 @@ from torch.nn import functional as F
import torch.nn as nn
from typing import List, Dict
# try:
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# except:
MyModule = nn.Module
def __nop(ob):
return ob
MyFunction = __nop
# # try torchdynamo
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# try torch jit --> faster!!
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
@ -26,7 +31,7 @@ RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
############################################################################################################
class RWKV_RNN(nn.Module):
class RWKV_RNN(MyModule):
def __init__(self, args):
super().__init__()
@ -113,7 +118,7 @@ class RWKV_RNN(nn.Module):
# state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
@MyFunction
def FF(self, x, state, i, time_mix_k, time_mix_r, kw, vw, rw):
def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
@ -134,7 +139,7 @@ class RWKV_RNN(nn.Module):
return r * kv
@MyFunction
def SA(self, x, state, i, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)

Loading…
Cancel
Save