@@ -17,39 +17,41 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in
     for m in module.modules():
         if not isinstance(m, (nn.Linear, nn.Embedding)):
             continue
         with torch.no_grad():
             name = '[unknown weight]'
             for name, parameter in module.named_parameters(): # find the name of the weight
                 if id(m.weight) == id(parameter):
                     break
 
             shape = m.weight.data.shape
             gain = 1.0 # positive: gain for orthogonal, negative: std for normal
             scale = 1.0 # extra scale for gain
 
             if isinstance(m, nn.Linear):
                 if m.bias is not None:
                     m.bias.data.zero_()
                 if shape[0] > shape[1]:
                     gain = math.sqrt(shape[0] / shape[1])
                 if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
                     scale = config.rwkv_emb_scale
 
             if isinstance(m, nn.Embedding):
                 gain = math.sqrt(max(shape[0], shape[1]))
                 if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
                     scale = config.rwkv_emb_scale
 
             if hasattr(m, 'scale_init'):
                 scale = m.scale_init
 
             print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
 
             gain *= scale
-            if gain > 0:
-                nn.init.orthogonal_(m.weight, gain=gain)
-            else:
-                nn.init.normal_(m.weight, mean=0, std=-gain)
+            if gain == 0:
+                nn.init.zeros_(m.weight) # zero init is great for some RWKV matrices
+            elif gain > 0:
+                nn.init.orthogonal_(m.weight, gain=gain)
+            else:
+                nn.init.normal_(m.weight, mean=0, std=-gain)
 
 class RWKV_TimeMix(nn.Module):
     def __init__(self, config, layer_id):
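
The initialization dispatch above is now three-way: gain == 0 zero-fills the weight, gain > 0 keeps the orthogonal init, and a negative gain is reused as the std of a normal init (hence std=-gain). A module opts into the zero branch by setting scale_init = 0, which RWKV_Init picks up via hasattr and multiplies into gain. Below is a minimal sketch of that path, assuming RWKV_Init from this file is in scope; ToyConfig and its values are hypothetical stand-ins for the real training config.

# Hypothetical toy setup; only RWKV_Init itself comes from this patch.
import torch
import torch.nn as nn

class ToyConfig:                      # assumed stand-in for the training config
    vocab_size = 100
    n_embd = 16
    rwkv_emb_scale = 0.4              # assumed value, not from the patch

toy = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 100, bias=False))
toy[1].scale_init = 0                 # picked up via hasattr(m, 'scale_init')
RWKV_Init(toy, ToyConfig())

print(toy[1].weight.abs().max())      # tensor(0.) -> took the new zeros_ branch
print(toy[0].weight.shape)            # the embedding still went through orthogonal_
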
@@ -95,7 +97,7 @@ class RWKV_TimeMix(nn.Module):
 
         self.key.scale_init = 0
         self.receptance.scale_init = 0
-        self.output.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
+        self.output.scale_init = 0
 
     def forward(self, x):
         B, T, C = x.size()
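
The second change: the output projection of RWKV_TimeMix is now zero-initialized outright (scale_init = 0) instead of being decayed by depth, so every time-mix block contributes nothing at step 0 and each residual connection starts as the identity. A sketch of why that is safe, using a hypothetical minimal residual block rather than the real RWKV_TimeMix:

import torch
import torch.nn as nn

class ToyBlock(nn.Module):            # hypothetical stand-in, not RWKV_TimeMix
    def __init__(self, n_embd=16):
        super().__init__()
        self.mix = nn.Linear(n_embd, n_embd, bias=False)
        self.output = nn.Linear(n_embd, n_embd, bias=False)
        nn.init.zeros_(self.output.weight)  # same effect as scale_init = 0

    def forward(self, x):
        return x + self.output(torch.relu(self.mix(x)))  # residual add

x = torch.randn(2, 8, 16)
assert torch.equal(ToyBlock()(x), x)  # the block is exactly the identity at init
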
@@ -145,7 +147,7 @@ class RWKV_ChannelMix(nn.Module):
         self.receptance = nn.Linear(config.n_embd, config.n_embd)
 
         self.receptance.scale_init = 0
-        self.weight.scale_init = 1 / pow(1+layer_id, config.rwkv_layer_decay) # reduce initial weight in higher layers
+        self.weight.scale_init = 0
 
     def forward(self, x):
         B, T, C = x.size()
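
RWKV_ChannelMix gets the same treatment on self.weight. The practical difference from the old rule is a slow polynomial decay over depth versus a hard zero; a quick numeric comparison (the old formula is from the removed lines, the rwkv_layer_decay value is assumed for illustration):

rwkv_layer_decay = 0.5                # assumed value for illustration
for layer_id in range(4):
    old_scale = 1 / pow(1 + layer_id, rwkv_layer_decay)
    print(layer_id, round(old_scale, 3), '-> new:', 0)
# 0 1.0 -> new: 0
# 1 0.707 -> new: 0
# 2 0.577 -> new: 0
# 3 0.5 -> new: 0
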