@@ -12,17 +12,74 @@ logger = logging.getLogger(__name__)
 ########################################################################################################
 # RWKV: RWKV Time-mix + RWKV Channel-mix
 ########################################################################################################
+#
+# fancy initialization of lin & emb layers, for faster convergence
+# note it will change ALL lin & emb layers in the module (including token emb & final projection)
+#
+def RWKV_Init(module, config):
+    for m in module.modules():
+        if not isinstance(m, (nn.Linear, nn.Embedding)):
+            continue
+
+        name = '[unknown weight]'
+        for name, parameter in module.named_parameters(): # find the name of the weight
+            if id(m.weight) == id(parameter):
+                break
+
+        shape = m.weight.data.shape
+        gain = 1.0 # positive: gain for orthogonal, negative: std for normal
+        scale = 1.0 # extra scale for gain
+
+        if isinstance(m, nn.Linear):
+            if m.bias is not None:
+                m.bias.data.zero_()
+            if shape[0] > shape[1]:
+                gain = math.sqrt(shape[0] / shape[1])
+            if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
+                scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese
+
+        if isinstance(m, nn.Embedding):
+            gain = math.sqrt(max(shape[0], shape[1]))
+            if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
+                scale = 0.4 # 0.4 is a safe choice, 0.8 is better for chinese
+
+        if hasattr(m, 'scale_init'):
+            scale = m.scale_init
+
+        print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
+
+        gain *= scale
+        if gain > 0:
+            nn.init.orthogonal_(m.weight, gain=gain)
+        else:
+            nn.init.normal_(m.weight, mean=0, std=-gain)
 class RWKV_TimeMix(nn.Module):
     def __init__(self, config, layer_id):
         super().__init__()
-        assert config.n_embd % config.n_head == 0
+        assert config.n_attn % config.n_head == 0
         self.layer_id = layer_id
         self.ctx_len = config.ctx_len
         self.n_head = config.n_head
-        self.head_size = config.n_embd // config.n_head
+        self.head_size = config.n_attn // config.n_head
 
-        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
+        with torch.no_grad(): # build initial time_w curves for better convergence
+            ww = torch.zeros(config.n_head, config.ctx_len)
+            curve = torch.tensor([0.9 ** (config.ctx_len - 1 - i) for i in range(config.ctx_len)])
+            curve = curve * 2 + 0.7
+            for h in range(config.n_head):
+                if config.n_head > 1:
+                    mix_strength = 1 - 1.2 * h / (config.n_head - 1) # mix_strength from 1 to -0.2
+                else:
+                    mix_strength = 0.5
+                ww[h] = (1 - mix_strength) + curve * mix_strength
+                # special tweaks because of time_shift
+                ww[h][config.ctx_len - 3] = (ww[h][config.ctx_len - 3] * 2 + 1) / 3
+                ww[h][config.ctx_len - 2] = (ww[h][config.ctx_len - 2] * 1 + 2) / 3
+                ww[h][config.ctx_len - 1] = 1
+                # print(h, mix_strength, ww[h])
+        self.time_w = nn.Parameter(ww)
         self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
         self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
         self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
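# A minimal standalone sketch (illustration only, not part of the patch above) of what the new
# time_w initialization produces; ctx_len=8 and n_head=4 are made-up toy values, and the loop
# simply mirrors the added __init__ code so the resulting per-head curves can be printed.
import torch

ctx_len, n_head = 8, 4
ww = torch.zeros(n_head, ctx_len)
curve = torch.tensor([0.9 ** (ctx_len - 1 - i) for i in range(ctx_len)]) # rises toward the newest position
curve = curve * 2 + 0.7
for h in range(n_head):
    mix_strength = 1 - 1.2 * h / (n_head - 1) if n_head > 1 else 0.5     # from 1.0 down to -0.2 across heads
    ww[h] = (1 - mix_strength) + curve * mix_strength
    ww[h][ctx_len - 3] = (ww[h][ctx_len - 3] * 2 + 1) / 3                # soften the last few weights (time_shift)
    ww[h][ctx_len - 2] = (ww[h][ctx_len - 2] * 1 + 2) / 3
    ww[h][ctx_len - 1] = 1
    print(h, [round(float(x), 2) for x in ww[h]])                        # early heads decay toward the past, late heads flatten or mildly invert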
@@ -30,11 +87,15 @@ class RWKV_TimeMix(nn.Module):
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
-        self.key = nn.Linear(config.n_embd, config.n_embd)
-        self.value = nn.Linear(config.n_embd, config.n_embd)
-        self.receptance = nn.Linear(config.n_embd, config.n_embd)
+        self.key = nn.Linear(config.n_embd, config.n_attn)
+        self.value = nn.Linear(config.n_embd, config.n_attn)
+        self.receptance = nn.Linear(config.n_embd, config.n_attn)
 
-        self.output = nn.Linear(config.n_embd, config.n_embd)
+        self.output = nn.Linear(config.n_attn, config.n_embd)
+
+        self.key.scale_init = 0
+        self.receptance.scale_init = 0
+        self.output.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results
 
     def forward(self, x):
         B, T, C = x.size()
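# A minimal sketch (illustration only, not part of the patch) of how the scale_init attributes above
# are consumed by RWKV_Init: scale_init = 0 drives gain to 0, so nn.init.normal_(std=-gain) zeroes the
# weight, i.e. key/receptance start as zero matrices, while output shrinks with depth. layer_id = 3 is
# a made-up value for the demo.
layer_id = 3
for name, scale in [('key', 0), ('receptance', 0), ('output', 1 / pow(1 + layer_id, 0.5))]:
    gain = 1.0 * scale                                        # RWKV_Init does gain *= scale
    if gain > 0:
        print(name, '-> orthogonal init, gain =', round(gain, 3))
    else:
        print(name, '-> normal init, std =', abs(gain))       # std 0 means the weight starts as all zeros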
@@ -57,7 +118,7 @@ class RWKV_TimeMix(nn.Module):
         kv = (k * v).view(B, T, self.n_head, self.head_size)
 
-        wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, C)
+        wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, -1)
 
         rwkv = torch.sigmoid(r) * wkv / sum_k
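# A minimal shape check (illustration only, not part of the patch) for the einsum above, with toy
# sizes: after this change n_attn can differ from n_embd, so the flattened result must be viewed as
# (B, T, -1) = (B, T, n_attn) rather than (B, T, C) where C = n_embd.
import torch

B, T, n_head, head_size = 2, 5, 4, 8             # toy sizes; n_attn = n_head * head_size = 32
w = torch.rand(n_head, T, T)                     # per-head time-weighting matrix
kv = torch.rand(B, T, n_head, head_size)
wkv = torch.einsum('htu,buhc->bthc', w, kv).contiguous().view(B, T, -1)
print(wkv.shape)                                 # torch.Size([2, 5, 32])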
@@ -69,12 +130,15 @@ class RWKV_ChannelMix(nn.Module):
         self.layer_id = layer_id
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
-        hidden_sz = 5 * config.n_embd // 2 # can use smaller hidden_sz because of R
+        hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of R
         self.key = nn.Linear(config.n_embd, hidden_sz)
         self.value = nn.Linear(config.n_embd, hidden_sz)
         self.weight = nn.Linear(hidden_sz, config.n_embd)
         self.receptance = nn.Linear(config.n_embd, config.n_embd)
+
+        self.receptance.scale_init = 0
+        self.weight.scale_init = 1 / pow(1+layer_id, 0.5) # 0.5 ~ 0.7 gives similar results
 
     def forward(self, x):
         B, T, C = x.size()
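# A rough parameter-count comparison (illustration only, toy numbers, not part of the patch) for the
# "smaller hidden_sz because of R" comment above: with n_embd = n_ffn = 768, key/value/weight of width
# 5*n_ffn//2 plus the square receptance land close to a standard 4x GPT feed-forward.
n_embd = n_ffn = 768
hidden_sz = 5 * n_ffn // 2                                     # 1920
rwkv_ffn = 2 * n_embd * hidden_sz + hidden_sz * n_embd + n_embd * n_embd
gpt_ffn = 2 * n_embd * (4 * n_embd)
print(hidden_sz, rwkv_ffn, gpt_ffn)                            # 1920 5013504 4718592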
@@ -125,24 +189,24 @@ class MHA_rotary(nn.Module):
     def __init__(self, config, layer_id, time_shift = False):
         super().__init__()
         self.layer_id = layer_id
-        assert config.n_embd % config.n_head == 0
+        assert config.n_attn % config.n_head == 0
         self.n_head = config.n_head
         self.ctx_len = config.ctx_len
-        self.head_size = config.n_embd // config.n_head
+        self.head_size = config.n_attn // config.n_head
 
         if time_shift:
             self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
-        self.query = nn.Linear(config.n_embd, config.n_embd)
-        self.key = nn.Linear(config.n_embd, config.n_embd)
-        self.value = nn.Linear(config.n_embd, config.n_embd)
+        self.query = nn.Linear(config.n_embd, config.n_attn)
+        self.key = nn.Linear(config.n_embd, config.n_attn)
+        self.value = nn.Linear(config.n_embd, config.n_attn)
 
         self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.rotary_ndims = int(self.head_size * 0.5)
         self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
 
-        self.output = nn.Linear(config.n_embd, config.n_embd)
+        self.output = nn.Linear(config.n_attn, config.n_embd)
 
     def forward(self, x):
         B, T, C = x.size()
@@ -166,7 +230,7 @@ class MHA_rotary(nn.Module):
         att = F.softmax(att, dim = -1) # softmax
 
         x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
-        x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
+        x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
 
         x = self.output(x)
         return x
@@ -179,7 +243,7 @@ class GeGLU(torch.nn.Module):
         if time_shift:
             self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
-        hidden_sz = 3 * config.n_embd
+        hidden_sz = 3 * config.n_ffn
         self.key = nn.Linear(config.n_embd, hidden_sz)
         self.value = nn.Linear(config.n_embd, hidden_sz)
         self.weight = nn.Linear(hidden_sz, config.n_embd)
@@ -202,10 +266,10 @@ class MHA_pro(nn.Module):
     def __init__(self, config, layer_id):
         super().__init__()
         self.layer_id = layer_id
-        assert config.n_embd % config.n_head == 0
+        assert config.n_attn % config.n_head == 0
        self.n_head = config.n_head
         self.ctx_len = config.ctx_len
-        self.head_size = config.n_embd // config.n_head
+        self.head_size = config.n_attn // config.n_head
 
         self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
         self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
@@ -214,16 +278,16 @@ class MHA_pro(nn.Module):
         self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
 
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
-        self.query = nn.Linear(config.n_embd, config.n_embd)
-        self.key = nn.Linear(config.n_embd, config.n_embd)
-        self.value = nn.Linear(config.n_embd, config.n_embd)
+        self.query = nn.Linear(config.n_embd, config.n_attn)
+        self.key = nn.Linear(config.n_embd, config.n_attn)
+        self.value = nn.Linear(config.n_embd, config.n_attn)
 
         self.rotary_ndims = int(self.head_size * 0.5)
         self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
 
         self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads
 
-        self.output = nn.Linear(config.n_embd, config.n_embd)
+        self.output = nn.Linear(config.n_attn, config.n_embd)
 
     def forward(self, x):
         B, T, C = x.size()
@@ -248,12 +312,12 @@ class MHA_pro(nn.Module):
 
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
         att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask
         att = F.softmax(att, dim = -1) # softmax
         att = att * w # time-weighting
         att = self.head_mix(att) # talking heads
 
         x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
-        x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
+        x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
 
         x = self.output(x) * self.time_gamma[:T, :]
         return x
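# A minimal sketch (illustration only, not part of the patch) of the "talking heads" step above: the
# 1x1 Conv2d treats the head dimension as channels, so it linearly mixes the (T, T) attention maps of
# different heads without touching the time axes. Sizes below are toy values.
import torch
import torch.nn as nn

B, n_head, T = 2, 4, 6
att = torch.softmax(torch.rand(B, n_head, T, T), dim=-1)   # per-head attention maps
head_mix = nn.Conv2d(n_head, n_head, kernel_size=1, bias=False)
print(head_mix(att).shape)                                  # torch.Size([2, 4, 6, 6]) -- same shape, heads mixed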
@@ -338,43 +402,11 @@ class GPT(nn.Module):
         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         self.ctx_len = config.ctx_len
 
-        self.apply(self._init_weights)
-
-        if self.config.model_type == 'RWKV':
-            ww = self.state_dict()
-            for k in ww:
-                if 'tok_emb' in k:
-                    if self.config.vocab_size > self.config.n_embd:
-                        ww[k] *= math.sqrt(self.config.vocab_size)
-                    else:
-                        ww[k] *= math.sqrt(self.config.n_embd)
-                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
-                elif 'head.weight' in k:
-                    ww[k] *= 0.4 # 0.4 is a safe choice // 0.8 might be better for chinese
-                elif 'blocks.' in k:
-                    block_id = int(k.split('.')[1])
-                    if 'receptance.weight' in k:
-                        ww[k] *= 0 # init with zero matrix
-                    elif 'attn.key.weight' in k:
-                        ww[k] *= 0 # init with zero matrix
-                    elif 'attn.output.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
-                    elif 'mlp.weight.weight' in k:
-                        ww[k] *= 1 / pow(1+block_id, 0.5) # 0.5 ~ 0.7 gives similar results
-                    elif 'attn.time_w' in k:
-                        curve = torch.tensor([0.9 ** (self.config.ctx_len - 1 - i) for i in range(self.config.ctx_len)])
-                        curve = curve * 2 + 0.7
-                        for h in range(self.config.n_head):
-                            if self.config.n_head > 1:
-                                mix_strength = 1 - 1.2 * h / (self.config.n_head - 1) # mix_strength from 1 to -0.2
-                            else:
-                                mix_strength = 0.5
-                            ww[k][h] = (1 - mix_strength) + curve * mix_strength
-                            # special tweaks because of time_shift
-                            ww[k][h][self.config.ctx_len - 3] = (ww[k][h][self.config.ctx_len - 3] * 2 + 1) / 3
-                            ww[k][h][self.config.ctx_len - 2] = (ww[k][h][self.config.ctx_len - 2] * 1 + 2) / 3
-                            ww[k][h][self.config.ctx_len - 1] = 1
-                            # print(k, h, mix_strength, ww[k][h])
+        if self.config.model_type == 'RWKV': # improve orthogonal weight init
+            RWKV_Init(self, config)
+        else:
+            self.apply(self._init_weights)
 
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
@@ -383,15 +415,7 @@ class GPT(nn.Module):
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
-            if self.config.model_type == 'RWKV':
-                gain = 1.0
-                if isinstance(module, nn.Linear):
-                    if module.weight.data.shape[0] > module.weight.data.shape[1]:
-                        gain = math.sqrt(module.weight.data.shape[0] / module.weight.data.shape[1])
-                nn.init.orthogonal_(module.weight, gain=gain)
-            else:
-                module.weight.data.normal_(mean=0.0, std=0.01)
+            module.weight.data.normal_(mean=0.0, std=0.01)
         if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
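# A minimal sketch (illustration only, not part of the patch) contrasting the two _init_weights paths:
# the removed RWKV branch used orthogonal init (roughly unit-norm rows), while the remaining default is
# a much smaller normal(std=0.01); RWKV models now get their scaled orthogonal init from RWKV_Init.
import torch
import torch.nn as nn

w = nn.Linear(64, 64)
nn.init.orthogonal_(w.weight, gain=1.0)
print('orthogonal std ~', round(w.weight.std().item(), 3))   # ~0.125 for a 64x64 weight
nn.init.normal_(w.weight, mean=0.0, std=0.01)
print('normal std ~', round(w.weight.std().item(), 3))       # ~0.01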