@@ -26,39 +26,74 @@ from torch.utils.cpp_extension import load
 wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                 verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
 
-class WKV(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, B, T, C, w, u, k, v):
-        ctx.B = B
-        ctx.T = T
-        ctx.C = C
-        assert T <= T_MAX
-        assert B * C % min(C, 1024) == 0
-        w = -torch.exp(w.float().contiguous())
-        u = u.float().contiguous()
-        k = k.float().contiguous()
-        v = v.float().contiguous()
-        ctx.save_for_backward(w, u, k, v)
-        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
-        wkv_cuda.forward(B, T, C, w, u, k, v, y)
-        return y.half()
-
-    @staticmethod
-    def backward(ctx, gy):
-        B = ctx.B
-        T = ctx.T
-        C = ctx.C
-        assert T <= T_MAX
-        assert B * C % min(C, 1024) == 0
-        w, u, k, v = ctx.saved_tensors
-        gw = torch.zeros((B, C), device='cuda')
-        gu = torch.zeros((B, C), device='cuda')
-        gk = torch.zeros((B, T, C), device='cuda')
-        gv = torch.zeros((B, T, C), device='cuda')
-        wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
-        gw = torch.sum(gw, dim=0)
-        gu = torch.sum(gu, dim=0)
-        return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
+if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+    class WKV(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, B, T, C, w, u, k, v):
+            ctx.B = B
+            ctx.T = T
+            ctx.C = C
+            assert T <= T_MAX
+            assert B * C % min(C, 1024) == 0
+            w = -torch.exp(w.float().contiguous())
+            u = u.float().contiguous()
+            k = k.float().contiguous()
+            v = v.float().contiguous()
+            ctx.save_for_backward(w, u, k, v)
+            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
+            wkv_cuda.forward(B, T, C, w, u, k, v, y)
+            return y.half()
+
+        @staticmethod
+        def backward(ctx, gy):
+            B = ctx.B
+            T = ctx.T
+            C = ctx.C
+            assert T <= T_MAX
+            assert B * C % min(C, 1024) == 0
+            w, u, k, v = ctx.saved_tensors
+            gw = torch.zeros((B, C), device='cuda')
+            gu = torch.zeros((B, C), device='cuda')
+            gk = torch.zeros((B, T, C), device='cuda')
+            gv = torch.zeros((B, T, C), device='cuda')
+            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
+            gw = torch.sum(gw, dim=0)
+            gu = torch.sum(gu, dim=0)
+            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
+elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+    class WKV(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, B, T, C, w, u, k, v):
+            ctx.B = B
+            ctx.T = T
+            ctx.C = C
+            assert T <= T_MAX
+            assert B * C % min(C, 1024) == 0
+            w = -torch.exp(w.float().contiguous())
+            u = u.float().contiguous()
+            k = k.float().contiguous()
+            v = v.float().contiguous()
+            ctx.save_for_backward(w, u, k, v)
+            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
+            wkv_cuda.forward(B, T, C, w, u, k, v, y)
+            return y.bfloat16()
+
+        @staticmethod
+        def backward(ctx, gy):
+            B = ctx.B
+            T = ctx.T
+            C = ctx.C
+            assert T <= T_MAX
+            assert B * C % min(C, 1024) == 0
+            w, u, k, v = ctx.saved_tensors
+            gw = torch.zeros((B, C), device='cuda')
+            gu = torch.zeros((B, C), device='cuda')
+            gk = torch.zeros((B, T, C), device='cuda')
+            gv = torch.zeros((B, T, C), device='cuda')
+            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
+            gw = torch.sum(gw, dim=0)
+            gu = torch.sum(gu, dim=0)
+            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
 
 def RUN_CUDA(B, T, C, w, u, k, v):
     return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
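Both branches share the same structure: the inputs are upcast to fp32 inside `forward`, the CUDA kernel runs in fp32, and only the tensors handed back to autograd are cast to the active precision. A minimal usage sketch follows; the `model` import path, the tensor shapes, and the explicit environment assignment are illustrative assumptions, not part of the patch.

# Minimal usage sketch (not part of the patch). Assumes a CUDA device is available
# and that the file above is importable as `model` (hypothetical path).
import os
os.environ['RWKV_FLOAT_MODE'] = 'bf16'   # must be set before the module is imported

import torch
from model import RUN_CUDA               # hypothetical import

B, T, C = 2, 8, 64                        # batch, sequence length (<= T_MAX), channels
w = torch.randn(C, device='cuda', dtype=torch.bfloat16)   # per-channel decay parameter
u = torch.randn(C, device='cuda', dtype=torch.bfloat16)   # per-channel bonus parameter
k = torch.randn(B, T, C, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, T, C, device='cuda', dtype=torch.bfloat16)

# WKV.forward upcasts w, u, k, v to fp32, runs the kernel in fp32,
# and casts the result back, so y comes out in bf16 here.
y = RUN_CUDA(B, T, C, w, u, k, v)
print(y.dtype, y.shape)                   # torch.bfloat16, (2, 8, 64)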
@@ -336,7 +371,12 @@ class GPT(nn.Module):
             k = self.head_k(x)[:, :T, :]
             c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
-            c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
+            if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
+            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16()
             x = self.head(x) + c
         else:
             x = self.head(x)
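The mode-dependent cast in the copy head is needed because `F.one_hot` returns an int64 tensor, and a matmul between a half or bfloat16 `c` and an integer one-hot matrix fails with a dtype mismatch. A small standalone sketch of that constraint; the shapes and vocab size here are illustrative assumptions.

# Standalone sketch of why the one_hot cast must follow RWKV_FLOAT_MODE.
import torch
import torch.nn.functional as F

idx = torch.randint(0, 10, (1, 4), device='cuda')             # token ids, assumed vocab size 10
c = torch.rand(1, 4, 4, device='cuda', dtype=torch.bfloat16)  # copy-attention weights

onehot = F.one_hot(idx, num_classes=10)                       # int64, shape (1, 4, 10)
# c @ onehot               # would raise a dtype mismatch error (bf16 vs int64)
out = c @ onehot.bfloat16()                                   # matches the bf16 branch above
print(out.dtype, out.shape)                                   # torch.bfloat16, (1, 4, 10)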