@@ -18,84 +18,85 @@ DEBUG_TIME = False # True False - show trained time-coeffs
 # CUDA Kernel
 ########################################################################################################
-T_MAX = 4096  # increase this if your ctx_len is long
-# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
-from torch.utils.cpp_extension import load
-wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
-                verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
-if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
-    class WKV(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, B, T, C, w, u, k, v):
-            ctx.B = B
-            ctx.T = T
-            ctx.C = C
-            assert T <= T_MAX
-            assert B * C % min(C, 1024) == 0
-            w = -torch.exp(w.float().contiguous())
-            u = u.float().contiguous()
-            k = k.float().contiguous()
-            v = v.float().contiguous()
-            ctx.save_for_backward(w, u, k, v)
-            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
-            wkv_cuda.forward(B, T, C, w, u, k, v, y)
-            return y.half()
-        @staticmethod
-        def backward(ctx, gy):
-            B = ctx.B
-            T = ctx.T
-            C = ctx.C
-            assert T <= T_MAX
-            assert B * C % min(C, 1024) == 0
-            w, u, k, v = ctx.saved_tensors
-            gw = torch.zeros((B, C), device='cuda')
-            gu = torch.zeros((B, C), device='cuda')
-            gk = torch.zeros((B, T, C), device='cuda')
-            gv = torch.zeros((B, T, C), device='cuda')
-            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
-            gw = torch.sum(gw, dim=0)
-            gu = torch.sum(gu, dim=0)
-            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
-elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
-    class WKV(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, B, T, C, w, u, k, v):
-            ctx.B = B
-            ctx.T = T
-            ctx.C = C
-            assert T <= T_MAX
-            assert B * C % min(C, 1024) == 0
-            w = -torch.exp(w.float().contiguous())
-            u = u.float().contiguous()
-            k = k.float().contiguous()
-            v = v.float().contiguous()
-            ctx.save_for_backward(w, u, k, v)
-            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
-            wkv_cuda.forward(B, T, C, w, u, k, v, y)
-            return y.bfloat16()
-        @staticmethod
-        def backward(ctx, gy):
-            B = ctx.B
-            T = ctx.T
-            C = ctx.C
-            assert T <= T_MAX
-            assert B * C % min(C, 1024) == 0
-            w, u, k, v = ctx.saved_tensors
-            gw = torch.zeros((B, C), device='cuda')
-            gu = torch.zeros((B, C), device='cuda')
-            gk = torch.zeros((B, T, C), device='cuda')
-            gv = torch.zeros((B, T, C), device='cuda')
-            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
-            gw = torch.sum(gw, dim=0)
-            gu = torch.sum(gu, dim=0)
-            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
-def RUN_CUDA(B, T, C, w, u, k, v):
-    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
+if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
+    T_MAX = 4096  # increase this if your ctx_len is long
+    # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
+    from torch.utils.cpp_extension import load
+    wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
+                    verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
+    if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+        class WKV(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, B, T, C, w, u, k, v):
+                ctx.B = B
+                ctx.T = T
+                ctx.C = C
+                assert T <= T_MAX
+                assert B * C % min(C, 1024) == 0
+                w = -torch.exp(w.float().contiguous())
+                u = u.float().contiguous()
+                k = k.float().contiguous()
+                v = v.float().contiguous()
+                ctx.save_for_backward(w, u, k, v)
+                y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
+                wkv_cuda.forward(B, T, C, w, u, k, v, y)
+                return y.half()
+            @staticmethod
+            def backward(ctx, gy):
+                B = ctx.B
+                T = ctx.T
+                C = ctx.C
+                assert T <= T_MAX
+                assert B * C % min(C, 1024) == 0
+                w, u, k, v = ctx.saved_tensors
+                gw = torch.zeros((B, C), device='cuda')
+                gu = torch.zeros((B, C), device='cuda')
+                gk = torch.zeros((B, T, C), device='cuda')
+                gv = torch.zeros((B, T, C), device='cuda')
+                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
+                gw = torch.sum(gw, dim=0)
+                gu = torch.sum(gu, dim=0)
+                return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
+    elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+        class WKV(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, B, T, C, w, u, k, v):
+                ctx.B = B
+                ctx.T = T
+                ctx.C = C
+                assert T <= T_MAX
+                assert B * C % min(C, 1024) == 0
+                w = -torch.exp(w.float().contiguous())
+                u = u.float().contiguous()
+                k = k.float().contiguous()
+                v = v.float().contiguous()
+                ctx.save_for_backward(w, u, k, v)
+                y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
+                wkv_cuda.forward(B, T, C, w, u, k, v, y)
+                return y.bfloat16()
+            @staticmethod
+            def backward(ctx, gy):
+                B = ctx.B
+                T = ctx.T
+                C = ctx.C
+                assert T <= T_MAX
+                assert B * C % min(C, 1024) == 0
+                w, u, k, v = ctx.saved_tensors
+                gw = torch.zeros((B, C), device='cuda')
+                gu = torch.zeros((B, C), device='cuda')
+                gk = torch.zeros((B, T, C), device='cuda')
+                gv = torch.zeros((B, T, C), device='cuda')
+                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
+                gw = torch.sum(gw, dim=0)
+                gu = torch.sum(gu, dim=0)
+                return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
+    def RUN_CUDA(B, T, C, w, u, k, v):
+        return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
 ############################################################################################################
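
A minimal usage sketch of the RUN_CUDA wrapper defined above (not part of the diff; it assumes the environment variables are set before the model module is imported, so the kernel gets compiled, and the import path src.model is an assumption). Shapes follow the gradients in backward(): w and u are per-channel vectors of shape (C,), k and v are (B, T, C) sequences, and the inputs must satisfy T <= T_MAX and B * C % min(C, 1024) == 0.

import os
os.environ['RWKV_RUN_DEVICE'] = 'cuda'   # must be set before importing the module above
os.environ['RWKV_FLOAT_MODE'] = 'bf16'
import torch
from src.model import RUN_CUDA           # hypothetical path; adjust to wherever the code above lives

B, T, C = 8, 1024, 512                   # T <= T_MAX, and B * C divisible by min(C, 1024)
w = torch.zeros(C)                       # raw time-decay; forward() applies -exp(w) itself
u = torch.zeros(C)                       # per-channel bonus for the current token
k = torch.randn(B, T, C, dtype=torch.bfloat16)
v = torch.randn(B, T, C, dtype=torch.bfloat16)
y = RUN_CUDA(B, T, C, w, u, k, v)        # (B, T, C) output, cast to the selected float mode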