@@ -18,14 +18,15 @@ DEBUG_TIME = False  # True False - show trained time-coeffs
# CUDA Kernel
########################################################################################################
T_MAX = 4096  # increase this if your ctx_len is long
if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
    # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
    from torch.utils.cpp_extension import load
    wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                    verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
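
# --- a hedged reference sketch, not part of the kernel: the comment above is the
# whole trick. The kernel is compiled with a hard -DTmax limit, so a longer ctx
# must be fed in slices with the WKV state carried across them. The state triple
# (aa, bb, pp) and the log-space update follow the usual RWKV recurrence;
# `wkv_reference` and its variable names are illustrative only, and w is assumed
# to already be the negative per-channel decay the kernel expects.
def wkv_reference(w, u, k, v, state=None):
    # w, u: (C,) with w < 0; k, v: (T, C) for one slice
    T, C = k.shape
    if state is None:
        aa = torch.zeros(C)            # weighted sum of past values
        bb = torch.zeros(C)            # sum of past weights
        pp = torch.full((C,), -1e38)   # running max exponent, keeps exp() finite
    else:
        aa, bb, pp = state
    y = torch.empty(T, C)
    for t in range(T):
        ww = u + k[t]                                    # current token gets the u bonus
        p = torch.maximum(pp, ww)
        e1, e2 = torch.exp(pp - p), torch.exp(ww - p)
        y[t] = (e1 * aa + e2 * v[t]) / (e1 * bb + e2)
        ww = pp + w                                      # decay the accumulated state
        p = torch.maximum(ww, k[t])
        e1, e2 = torch.exp(ww - p), torch.exp(k[t] - p)
        aa, bb, pp = e1 * aa + e2 * v[t], e1 * bb + e2, p
    return y, (aa, bb, pp)

# slicing a 100-step ctx into chunks of 32 matches a single long pass:
# C = 8; w = -torch.exp(torch.randn(C)); u = torch.randn(C)
# k, v = torch.randn(100, C), torch.randn(100, C)
# y_full, _ = wkv_reference(w, u, k, v)
# state, ys = None, []
# for k_c, v_c in zip(k.split(32), v.split(32)):
#     y_c, state = wkv_reference(w, u, k_c, v_c, state)
#     ys.append(y_c)
# assert torch.allclose(y_full, torch.cat(ys))
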
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    class WKV(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
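            # the usual fp16 pattern for this body (a hedged sketch, not the
            # exact code, which the hunk does not show): stash B/T/C on ctx,
            # upcast w, u, k, v to contiguous fp32 and negate-exp the decay
            # (w = -torch.exp(w.float())) since the kernel accumulates in float,
            # launch wkv_cuda.forward(B, T, C, w, u, k, v, y) into a
            # preallocated y, save the fp32 tensors for backward, and return
            # y.half() so the surrounding mixed-precision graph stays fp16.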
@@ -59,7 +60,7 @@ if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
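            # autograd.Function.backward returns one gradient per forward input:
            # B, T, C are plain ints, so their slots are None. w and u are shared
            # across the batch, which is why the per-sample gradients are reduced
            # with torch.sum over dim 0 above, and everything is cast back to
            # half to match the fp16 graph.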
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
    class WKV(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
@@ -94,7 +95,7 @@ elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            gu = torch.sum(gu, dim=0)
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
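
# a hedged usage sketch (assumes RWKV_RUN_DEVICE='cuda' and the compiled kernel;
# shapes follow RWKV's TimeMix: w and u are per-channel vectors, k and v are (B, T, C)):
# B, T, C = 2, 64, 128
# assert T <= T_MAX                  # the kernel is compiled with a hard -DTmax limit
# w = torch.randn(C)                 # per-channel time_decay parameter
# u = torch.randn(C)                 # per-channel time_first ("bonus") parameter
# k, v = torch.randn(B, T, C), torch.randn(B, T, C)
# y = RUN_CUDA(B, T, C, w, u, k, v)  # -> (B, T, C); WKV.apply moves args to cuda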
############################################################################################################