@@ -51,6 +51,7 @@ class RWKV_TimeMix(nn.Module):
         v = self.value(x)
         r = self.receptance(x)
+        k = torch.clamp(k, max=30) # clamp crazy values
         k = torch.exp(k)
         sum_k = torch.cumsum(k, dim=1)
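A note on the added clamp: since `k` is exponentiated directly on the next line, any logit much above ~88 overflows float32 to `inf`, and the `cumsum` used downstream as a denominator then turns that into `NaN`. A minimal standalone sketch of the failure and the fix, in plain PyTorch with a made-up toy tensor:

```python
import torch

# A toy key tensor with one extreme value, as float32.
k = torch.tensor([[1.0, 5.0, 100.0, 2.0]])

# Without the clamp: exp(100) overflows float32 to inf, and the
# cumulative sum used as a denominator propagates inf/nan.
bad = torch.exp(k)
print(bad / torch.cumsum(bad, dim=1))   # nan appears once inf/inf occurs

# With the clamp from this hunk: exp(30) ~ 1.07e13, comfortably finite.
safe = torch.exp(torch.clamp(k, max=30))
print(safe / torch.cumsum(safe, dim=1))  # finite everywhere
```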
@@ -261,20 +262,6 @@ class MHA_pro(nn.Module):
 # The GPT Model with our blocks
 ########################################################################################################
-class LabelSmoothingCrossEntropy(nn.Module): # can avoid nan loss
-    def __init__(self, smoothing=0.0):
-        super().__init__()
-        self.confidence = 1.0 - smoothing
-        self.smoothing = smoothing
-    def forward(self, pred, target):
-        pred = pred.log_softmax(dim=-1)
-        with torch.no_grad():
-            true_dist = torch.zeros_like(pred)
-            true_dist.fill_(self.smoothing / (pred.size(-1) - 1))
-            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
-        return torch.mean(torch.sum(-true_dist * pred, dim=-1))
 class RMSNorm(nn.Module):
     def __init__(self, d):
         super().__init__()
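For reference, the class removed above puts `smoothing / (K - 1)` of probability mass on each wrong class and `1 - smoothing` on the true one, while PyTorch's built-in `label_smoothing` argument to `F.cross_entropy` (available since 1.10) uses `smoothing / K` off-target and `1 - smoothing + smoothing / K` on-target, so the two agree only for small smoothing values. A quick check, assuming nothing beyond stock PyTorch and arbitrary toy shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K = 10
pred = torch.randn(4, K)            # raw logits
target = torch.randint(0, K, (4,))
s = 5e-5                            # the smoothing value this diff had used

# The removed class: (1 - s) on the true class, s / (K - 1) elsewhere.
logp = pred.log_softmax(dim=-1)
true_dist = torch.full_like(logp, s / (K - 1))
true_dist.scatter_(1, target.unsqueeze(1), 1.0 - s)
manual = (-true_dist * logp).sum(dim=-1).mean()

# Built-in smoothing: s / K everywhere plus (1 - s) on the true class.
builtin = F.cross_entropy(pred, target, label_smoothing=s)

print(manual.item(), builtin.item())  # close, but not identical conventions
```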
@@ -379,7 +366,7 @@ class GPT(nn.Module):
                     curve = curve - torch.mean(curve) + 1 # normalize mean to 1
                     mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
                     ww[k][h] = (1 - mix_strength) + curve * mix_strength
-                    # special tweak because of time_shift
+                    # special tweaks because of time_shift
                     ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 2] * 2 + 1) / 3
                     ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] + 1) / 2
                     ww[k][h][self.config.ctx_len - 1] = 1
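The comment fix in this hunk is cosmetic, but the surrounding logic is easy to misread: `mix_strength` ramps linearly from 1 on head 0 down to -0.2 on the last head, so the late heads actually invert the curve rather than flatten it. A tiny standalone sketch of the ramp and the three endpoint tweaks; the head count, `ctx_len`, and the `curve` ramp are made-up stand-ins, not the real initialization:

```python
import torch

n_head, ctx_len = 4, 8
curve = torch.linspace(0.5, 1.5, ctx_len)  # stand-in for the real decay curve
curve = curve - torch.mean(curve) + 1      # normalize mean to 1, as in the diff

for h in range(n_head):
    mix_strength = 1 - 1.2 * h / (n_head - 1)   # 1.0, 0.6, 0.2, -0.2
    w = (1 - mix_strength) + curve * mix_strength
    # the three "special tweaks": blend the last positions toward 1
    w[ctx_len - 3] = (w[ctx_len - 2] * 2 + 1) / 3
    w[ctx_len - 2] = (w[ctx_len - 2] + 1) / 2
    w[ctx_len - 1] = 1
    print(h, mix_strength, w)
```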
@@ -450,6 +437,6 @@ class GPT(nn.Module):
         loss = None
         if targets is not None:
-            loss = LabelSmoothingCrossEntropy(smoothing=5e-5)(x.view(-1, x.size(-1)), targets.view(-1))
+            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
         return x, loss
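With the clamp from the first hunk keeping the logits finite, plain `F.cross_entropy` should no longer hit the NaNs that the smoothing class was commented as working around. A shape sanity check for the flattening convention on the replaced line; batch, sequence, and vocab sizes here are arbitrary:

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 16, 100                      # batch, sequence, vocab (arbitrary)
x = torch.randn(B, T, V)                  # logits as produced by the model
targets = torch.randint(0, V, (B, T))

# cross_entropy wants (N, C) logits and (N,) class indices,
# hence the .view(-1, ...) flattening in the diff.
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
print(loss.shape, loss.item())            # scalar mean over B*T positions
```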