torch jit (xx% faster inference)

main
BlinkDL 3 years ago
parent 819f2730b2
commit 75929cbbba

@ -9,14 +9,19 @@ from torch.nn import functional as F
import torch.nn as nn import torch.nn as nn
from typing import List, Dict from typing import List, Dict
# try: MyModule = nn.Module
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# except:
def __nop(ob): def __nop(ob):
return ob return ob
MyFunction = __nop MyFunction = __nop
# # try torchdynamo
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# try torch jit --> faster!!
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
RWKV_HEAD_QK_DIM = 0 RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
@ -26,7 +31,7 @@ RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
############################################################################################################ ############################################################################################################
class RWKV_RNN(nn.Module): class RWKV_RNN(MyModule):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
@ -113,7 +118,7 @@ class RWKV_RNN(nn.Module):
# state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
@MyFunction @MyFunction
def FF(self, x, state, i, time_mix_k, time_mix_r, kw, vw, rw): def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
if self.FLOAT_MODE == "bf16": if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k) xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r) xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
@ -134,7 +139,7 @@ class RWKV_RNN(nn.Module):
return r * kv return r * kv
@MyFunction @MyFunction
def SA(self, x, state, i, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
if self.FLOAT_MODE == "bf16": if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k) xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v) xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)

Loading…
Cancel
Save