@@ -34,25 +34,25 @@ if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
             ctx.C = C
             assert T <= T_MAX
             assert B * C % min(C, 1024) == 0
-            if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
-                w = -torch.exp(w.float().contiguous())
-                u = u.float().contiguous()
-                k = k.float().contiguous()
-                v = v.float().contiguous()
-            else:
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
                 w = -torch.exp(w.contiguous())
                 u = u.contiguous()
                 k = k.contiguous()
                 v = v.contiguous()
+            else:
+                w = -torch.exp(w.float().contiguous())
+                u = u.float().contiguous()
+                k = k.float().contiguous()
+                v = v.float().contiguous()
             ctx.save_for_backward(w, u, k, v)
             y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
             wkv_cuda.forward(B, T, C, w, u, k, v, y)
-            if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
+                return y
+            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                 return y.half()
             elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                 return y.bfloat16()
-            elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-                return y
         @staticmethod
         def backward(ctx, gy):
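The forward-pass hunk above swaps the exact `== 'fp32'` comparison for a substring test, so any float mode whose name contains `'32'` (fp32 and, presumably, a tf32 setting) skips the cast to float32 on the way in and returns `y` uncast on the way out. A minimal sketch of the input-side dispatch this implies; `to_kernel_dtype` is a hypothetical helper, not a function from the patch:

```python
import os
import torch

# Hypothetical helper illustrating the new dispatch: modes containing '32'
# ('fp32', presumably 'tf32') pass through untouched, while fp16/bf16
# inputs are upcast so the CUDA kernel always computes in float32.
def to_kernel_dtype(t: torch.Tensor) -> torch.Tensor:
    if '32' in os.environ['RWKV_FLOAT_MODE']:
        return t.contiguous()          # already 32-bit: no cast needed
    return t.float().contiguous()      # fp16/bf16: upcast for the kernel
```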
@@ -62,22 +62,22 @@ if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
             assert T <= T_MAX
             assert B * C % min(C, 1024) == 0
             w, u, k, v = ctx.saved_tensors
-            gw = torch.zeros((B, C), device='cuda')
-            gu = torch.zeros((B, C), device='cuda')
-            gk = torch.zeros((B, T, C), device='cuda')
-            gv = torch.zeros((B, T, C), device='cuda')
-            if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
-                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
-            else:
+            gw = torch.zeros((B, C), device='cuda').contiguous()
+            gu = torch.zeros((B, C), device='cuda').contiguous()
+            gk = torch.zeros((B, T, C), device='cuda').contiguous()
+            gv = torch.zeros((B, T, C), device='cuda').contiguous()
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
                 wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
+            else:
+                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
             gw = torch.sum(gw, dim=0)
             gu = torch.sum(gu, dim=0)
-            if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
+                return (None, None, None, gw, gu, gk, gv)
+            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                 return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
             elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                 return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
-            elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-                return (None, None, None, gw, gu, gk, gv)
     def RUN_CUDA(B, T, C, w, u, k, v):
         return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
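The backward hunk applies the same `'32'` substring test to decide whether `gy` needs a float32 cast, and now allocates the gradient buffers with an explicit `.contiguous()` before the kernel writes into them. The three leading `None`s in the return tuple line up with `B`, `T`, `C`, the non-tensor arguments of `WKV.apply`, which take no gradient. A hypothetical smoke test for the round trip, assuming bf16 mode, made-up shapes, and a kernel compiled with `T_MAX >= 16`:

```python
import os
import torch

os.environ['RWKV_FLOAT_MODE'] = 'bf16'
B, T, C = 2, 16, 64  # satisfies T <= T_MAX and B * C % min(C, 1024) == 0

# Assumption: in bf16 mode the whole model, including w and u, lives in bf16.
w = torch.randn(C, device='cuda', dtype=torch.bfloat16, requires_grad=True)
u = torch.randn(C, device='cuda', dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, T, C, device='cuda', dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, T, C, device='cuda', dtype=torch.bfloat16, requires_grad=True)

y = RUN_CUDA(B, T, C, w, u, k, v)  # forward upcasts to fp32, returns bf16
y.sum().backward()                 # grads come back bf16 via the elif branch
assert k.grad.dtype == torch.bfloat16
```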
@@ -222,7 +222,13 @@ class RWKV_GPT(nn.Module):
             c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
-            c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float()
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
+                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size)
+            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half()
+            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16()
             x = self.head(x) + c
         else:
             x = self.head(x)
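The last hunk drops the unconditional `.float()` upcast on the copy-head projection. `F.one_hot` always returns an int64 tensor, so the fp16/bf16 branches cast it to match `c`, keeping the matmul in the run's compute dtype instead of forcing fp32. A dtype-generic sketch of the same idea; `B`, `T`, and `V` are made-up illustration values, not from the repo:

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 4, 10                       # toy batch, sequence, vocab sizes
idx = torch.randint(0, V, (B, T))        # token ids
c = torch.randn(B, T, T)                 # masked copy scores, any float dtype
one_hot = F.one_hot(idx, num_classes=V)  # int64 by default
c = c @ one_hot.to(c.dtype)              # cast to c's dtype before the matmul
```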