@@ -9,14 +9,19 @@ from torch.nn import functional as F
 import torch.nn as nn
 from typing import List, Dict
 
-# try:
-# import torchdynamo
-# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
-# except:
+MyModule = nn.Module
 def __nop(ob):
     return ob
 MyFunction = __nop
+
+# # try torchdynamo
+# import torchdynamo
+# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
+
+# try torch jit --> faster!!
+MyModule = torch.jit.ScriptModule
+MyFunction = torch.jit.script_method
 
 RWKV_HEAD_QK_DIM = 0
 print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
 
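Note on this hunk: the patch drops the commented-out torchdynamo path (marked buggy) and instead routes the model through TorchScript by rebinding the MyModule/MyFunction aliases, keeping the plain nn.Module/__nop pair as the eager fallback. Below is a minimal sketch of that alias pattern in isolation; the TinyFF class and its contents are illustrative assumptions, not part of the patch, using the same legacy torch.jit.ScriptModule/script_method API the patch uses.

# Minimal sketch (illustrative, not from the patch): the same class runs
# eagerly or JIT-compiled depending on which pair of aliases is active.
import torch
import torch.nn as nn

MyModule = torch.jit.ScriptModule      # eager fallback: nn.Module
MyFunction = torch.jit.script_method   # eager fallback: the __nop identity

class TinyFF(MyModule):
    def __init__(self, dim):
        super().__init__()
        self.key = nn.Linear(dim, dim, bias=False)

    @MyFunction
    def forward(self, x):
        # compiled ahead of time when the TorchScript aliases are active
        return torch.relu(self.key(x))

m = TinyFF(4)
print(m(torch.ones(2, 4)).shape)  # torch.Size([2, 4])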
@@ -26,7 +31,7 @@ RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
 
 ############################################################################################################
 
-class RWKV_RNN(nn.Module):
+class RWKV_RNN(MyModule):
     def __init__(self, args):
         super().__init__()
 
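Note: this one-line change is what makes the aliases from the first hunk take effect. With MyModule bound to torch.jit.ScriptModule, the @MyFunction methods of RWKV_RNN are compiled by TorchScript; with the fallback bindings, the class degrades gracefully to a plain nn.Module with no other code changes.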
@@ -113,7 +118,7 @@ class RWKV_RNN(nn.Module):
     # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
 
     @MyFunction
-    def FF(self, x, state, i, time_mix_k, time_mix_r, kw, vw, rw):
+    def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
         if self.FLOAT_MODE == "bf16":
             xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
             xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
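Note: the added `i:int` annotation is required by TorchScript, which assumes unannotated parameters are Tensors; the layer index must be a Python int so that `5*i+0` addresses the per-layer state slots as plain integer indexing. A hedged sketch of the pattern, with an illustrative function name and shapes that are not from the patch:

# Illustrative only: why the layer index needs ":int" under TorchScript.
import torch

@torch.jit.script
def ffn_xx_slot(state, i: int):
    # 5 state slots per layer; slot 5*i+0 is ffn_xx for layer i
    return state[5 * i + 0]

state = torch.zeros(10, 8)          # 2 layers x 5 slots, hidden size 8
print(ffn_xx_slot(state, 1).shape)  # torch.Size([8])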
@@ -134,7 +139,7 @@ class RWKV_RNN(nn.Module):
         return r * kv
 
     @MyFunction
-    def SA(self, x, state, i, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
+    def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
         if self.FLOAT_MODE == "bf16":
             xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
             xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
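For context, the xk/xv lines in this hunk are RWKV's token shift: each projection input is a per-channel interpolation between the current embedding x and the previous token's embedding held in state[5*i+1], cast back to bfloat16 in this branch. A small numeric sketch with invented values, in fp32 for clarity:

# Token-shift mixing in isolation (illustrative values, not from the patch).
import torch

x = torch.tensor([1.0, 2.0])             # current token's embedding
prev = torch.tensor([0.0, 4.0])          # previous token, from state[5*i+1]
time_mix_k = torch.tensor([0.25, 0.75])  # learned per-channel mix

xk = x * time_mix_k + prev * (1 - time_mix_k)
print(xk)  # tensor([0.2500, 2.5000])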