@@ -3,16 +3,13 @@
 ########################################################################################################

 import types
 import copy
 import torch
-import math, os
+import math, os, gc
 from torch.nn import functional as F
 import torch.nn as nn

 def __nop(ob):
     return ob

 MyModule = nn.Module
 MyFunction = __nop
 # MyModule = torch.jit.ScriptModule
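 # (__nop keeps @MyFunction a no-op decorator; switching MyModule to
 # torch.jit.ScriptModule, with a matching TorchScript method decorator,
 # would compile LN/FF/SA below)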
@@ -25,7 +22,7 @@ DEBUG_TIME = False # True False - show trained time-coeffs
 ############################################################################################################

-class RWKV_RNN(MyModule): # this is running in FP32 at this moment
+class RWKV_RNN(MyModule):
     def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
         super().__init__()
@@ -35,20 +32,50 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
         self.n_embd = n_embd
         self.ctx_len = ctx_len

-        self.w = types.SimpleNamespace()
-        w = torch.load(MODEL_NAME + '.pth', map_location='cpu')
+        w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))
-        for x in w.keys():
-            w[x] = w[x].float()
+        # refine weights and send to correct device
+        keys = list(w.keys())
+        if 'pos_emb_x' in keys:
+            w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:]
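+        # (the line above folds the factorized pos_emb_x / pos_emb_y pair into
+        # one dense positional-embedding table; the last row is dropped, giving
+        # shape (ctx_len, n_embd))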
+        keys = list(w.keys())
+        print_need_newline = False
+        for x in keys:
             if '.time_' in x:
                 w[x] = w[x].squeeze()
+                if DEBUG_TIME:
+                    print(x, w[x].numpy())
             if '.time_decay' in x:
+                w[x] = w[x].float()
                 w[x] = -torch.exp(w[x])
-            if 'pos_emb_x' in x:
-                self.w.pos_emb = (w['pos_emb_x'] + w['pos_emb_y']).reshape(ctx_len+1, -1)[:-1,:]
-            if DEBUG_TIME and '.time_' in x:
-                print(x, w[x].squeeze().cpu().numpy())
+            elif '.time_first' in x:
+                w[x] = w[x].float()
+            else:
+                if os.environ["RWKV_FLOAT_MODE"] == "fp32":
+                    w[x] = w[x].float()
+                elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+                    w[x] = w[x].bfloat16()
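+            # note: time_decay / time_first were forced to fp32 above; only the
+            # remaining weights follow the RWKV_FLOAT_MODE setting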
+            w[x].requires_grad = False
+            if RUN_DEVICE == 'cuda' and x != 'emb.weight':
+                w[x] = w[x].cuda()
+
+            if ('blocks.' not in x) or ('blocks.0.' in x):
+                if print_need_newline:
+                    print('\n', end='')
+                    print_need_newline = False
+                print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device)
+            else:
+                print_need_newline = True
+                print('.', end='', flush=True)
+
+        # store weights in self.w
+        keys = list(w.keys())
+        self.w = types.SimpleNamespace()
         for x in keys:
             xx = x.split('.')
             here = self.w
             for i in range(len(xx)):
@@ -67,41 +94,26 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
                     setattr(here, xx[i], types.SimpleNamespace())
                 here = getattr(here, xx[i])

-        self.clear()
         self.eval()

-    def clear(self):
-        self.xx = {}
-        self.aa = {}
-        self.bb = {}
-        self.pp = {}
-        self.hk = None
-
-    def save(self, target):
-        target.xx = copy.deepcopy(self.xx)
-        target.aa = copy.deepcopy(self.aa)
-        target.bb = copy.deepcopy(self.bb)
-        target.pp = copy.deepcopy(self.pp)
-        target.hk = copy.deepcopy(self.hk)
-
-    def load(self, target):
-        self.xx = copy.deepcopy(target.xx)
-        self.aa = copy.deepcopy(target.aa)
-        self.bb = copy.deepcopy(target.bb)
-        self.pp = copy.deepcopy(target.pp)
-        self.hk = copy.deepcopy(target.hk)
+        gc.collect()
+        torch.cuda.empty_cache()

     @MyFunction
-    def LN(self, xx, w):
-        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)
+    def LN(self, x, w):
+        return F.layer_norm(x, (self.n_embd,), weight=w.weight, bias=w.bias)

+    # state: ffn_xx att_xx att_aa att_bb att_pp
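+    # state layout, 5 rows per layer:
+    #   5*i+0 = prev x for the FF token-shift, 5*i+1 = prev x for the SA
+    #   token-shift, 5*i+2 = WKV numerator (aa), 5*i+3 = WKV denominator (bb),
+    #   5*i+4 = running max exponent (pp) for the log-space recurrence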
     @MyFunction
-    def FF(self, xx, w, name):
-        if name not in self.xx:
-            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
-        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
-        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
-        self.xx[name] = xx
+    def FF(self, x, w, state, i):
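+        # token-shift: xk / xr blend the current x with the previous token's x
+        # (state row 5*i+0), per channel, via the learned time_mix coefficients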
+        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
+            xk = x * w.time_mix_k + state[5*i+0].bfloat16() * (1 - w.time_mix_k)
+            xr = x * w.time_mix_r + state[5*i+0].bfloat16() * (1 - w.time_mix_r)
+            state[5*i+0] = x.float()
+        else:
+            xk = x * w.time_mix_k + state[5*i+0] * (1 - w.time_mix_k)
+            xr = x * w.time_mix_r + state[5*i+0] * (1 - w.time_mix_r)
+            state[5*i+0] = x

         r = torch.sigmoid(w.receptance.weight @ xr)
         k = torch.square(torch.relu(w.key.weight @ xk))
@@ -110,90 +122,92 @@ class RWKV_RNN(MyModule): # this is running in FP32 at this moment
         return r * kv

     @MyFunction
-    def SA(self, xx, w, name):
-        if name not in self.xx:
-            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
-            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
-            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
-            self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30
-        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
-        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
-        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
-        self.xx[name] = xx
+    def SA(self, x, w, state, i):
+        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
+            xk = x * w.time_mix_k + state[5*i+1].bfloat16() * (1 - w.time_mix_k)
+            xv = x * w.time_mix_v + state[5*i+1].bfloat16() * (1 - w.time_mix_v)
+            xr = x * w.time_mix_r + state[5*i+1].bfloat16() * (1 - w.time_mix_r)
+            state[5*i+1] = x.float()
+        else:
+            xk = x * w.time_mix_k + state[5*i+1] * (1 - w.time_mix_k)
+            xv = x * w.time_mix_v + state[5*i+1] * (1 - w.time_mix_v)
+            xr = x * w.time_mix_r + state[5*i+1] * (1 - w.time_mix_r)
+            state[5*i+1] = x

         r = torch.sigmoid(w.receptance.weight @ xr)
         k = w.key.weight @ xk
         v = w.value.weight @ xv

-        pp = self.pp[name]
-        aa = self.aa[name]
-        bb = self.bb[name]
-        ww = w.time_first + k
-        p = torch.maximum(pp, ww)
-        e1 = torch.exp(pp - p)
-        e2 = torch.exp(ww - p)
-        a = e1 * aa + e2 * v
-        b = e1 * bb + e2
-        ww = pp + w.time_decay
-        p = torch.maximum(ww, k)
-        e1 = torch.exp(ww - p)
-        e2 = torch.exp(k - p)
-        self.aa[name] = e1 * aa + e2 * v
-        self.bb[name] = e1 * bb + e2
-        self.pp[name] = p
-        rwkv = r * a / b
+        if os.environ["RWKV_FLOAT_MODE"] == "bf16":
+            kk = k.float()
+            vv = v.float()
+            aa = state[5*i+2]
+            bb = state[5*i+3]
+            pp = state[5*i+4]
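+            # log-space trick: pp tracks the running max exponent, so every
+            # exp() below sees a non-positive argument and cannot overflow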
+            ww = w.time_first + kk
+            p = torch.maximum(pp, ww)
+            e1 = torch.exp(pp - p)
+            e2 = torch.exp(ww - p)
+            a = e1 * aa + e2 * vv
+            b = e1 * bb + e2
+            ww = pp + w.time_decay
+            p = torch.maximum(ww, kk)
+            e1 = torch.exp(ww - p)
+            e2 = torch.exp(kk - p)
+            state[5*i+2] = e1 * aa + e2 * vv
+            state[5*i+3] = e1 * bb + e2
+            state[5*i+4] = p
+            rwkv = r * (a / b).bfloat16()
+        else:
+            aa = state[5*i+2]
+            bb = state[5*i+3]
+            pp = state[5*i+4]
+            ww = w.time_first + k
+            p = torch.maximum(pp, ww)
+            e1 = torch.exp(pp - p)
+            e2 = torch.exp(ww - p)
+            a = e1 * aa + e2 * v
+            b = e1 * bb + e2
+            ww = pp + w.time_decay
+            p = torch.maximum(ww, k)
+            e1 = torch.exp(ww - p)
+            e2 = torch.exp(k - p)
+            state[5*i+2] = e1 * aa + e2 * v
+            state[5*i+3] = e1 * bb + e2
+            state[5*i+4] = p
+            rwkv = r * a / b

         return w.output.weight @ rwkv

-    def forward(self, ctx, preprocess_only=False):
+    def forward(self, ctx, state, preprocess_only=False):
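+        # ctx is the full token-id list, but only ctx[-1] is embedded here; the
+        # history is carried in `state` (pass state=None for a fresh sequence)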
         with torch.no_grad():
             w = self.w

             x = w.emb.weight[ctx[-1]]
             if self.RUN_DEVICE == 'cuda':
                 x = x.cuda()
             try:
                 pos_emb = w.pos_emb[len(ctx)-1]
                 x = x + pos_emb
             except:
                 pass

+            if state == None:
+                state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.RUN_DEVICE)
+                for i in range(self.n_layer):
+                    state[5*i+4] -= 1e30
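+                    # pp starts at -1e30 (roughly log 0): an empty history, so
+                    # the first token fully determines the WKV numerator/denominator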
             for i in range(self.n_layer):
                 if i == 0:
                     x = self.LN(x, w.blocks[i].ln0)
                 if i == 0 and self.model_type == 'RWKV-ffnPre':
-                    x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
+                    x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, state, i)
                 else:
-                    x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
-                x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')
-            x = self.LN(x, w.ln_out)
-            if RWKV_HEAD_QK_DIM > 0:
-                if self.hk == None:
-                    self.hk = (w.head_k.weight @ x).unsqueeze(0)
-                else:
-                    self.hk = torch.cat([self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
-                if self.hk.shape[0] > self.ctx_len:
-                    self.hk = self.hk[-self.ctx_len:, :]
+                    x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, state, i)
+                x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, state, i)

-                if preprocess_only:
-                    return None
+            if preprocess_only:
+                return state

-                q = w.head_q.weight @ x
-                x = w.head.weight @ x
-                x = x
-                c = (self.hk @ q) / RWKV_HEAD_QK_DIM
-                for i in range(len(c)):
-                    x[ctx[i]] += c[i]
-            else:
-                if preprocess_only:
-                    return None
-                x = w.head.weight @ x
-                x = x
+            x = self.LN(x, w.ln_out)
+            x = w.head.weight @ x

-            return x
+            return x.float(), state
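
# Usage sketch for the new stateful API (MODEL_NAME and `tokens` are
# placeholders):
#   model = RWKV_RNN(MODEL_NAME, 'cuda', 'RWKV', n_layer, n_embd, ctx_len)
#   state = None
#   for i in range(len(tokens) - 1):  # feed the prompt, keeping only the state
#       state = model.forward(tokens[:i+1], state, preprocess_only=True)
#   logits, state = model.forward(tokens, state)  # logits for sampling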