RWKV: now faster and less params

main 0.01
BlinkDL 4 years ago
parent 546114c6a5
commit 3b9005ea11

@ -68,9 +68,10 @@ class RWKV_ChannelMix(nn.Module):
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,0))
self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
hidden_sz = 5 * config.n_embd // 2 # can use smaller hidden_sz because of R
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
self.receptance = nn.Linear(config.n_embd, config.n_embd)
def forward(self, x):
@ -166,9 +167,10 @@ class GeGLU(torch.nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
hidden_sz = 3 * config.n_embd
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
def forward(self, x):
k = self.key(x)

Loading…
Cancel
Save