|
|
|
|
@@ -68,9 +68,10 @@ class RWKV_ChannelMix(nn.Module):
|
|
|
|
|
self.layer_id = layer_id
|
|
|
|
|
self.time_shift = nn.ZeroPad2d((0,0,1,0))
|
|
|
|
|
|
|
|
|
|
self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
|
|
|
|
|
self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
|
|
|
|
|
self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
|
|
|
|
|
hidden_sz = 5 * config.n_embd // 2 # can use smaller hidden_sz because of R
|
|
|
|
|
self.key = nn.Linear(config.n_embd, hidden_sz)
|
|
|
|
|
self.value = nn.Linear(config.n_embd, hidden_sz)
|
|
|
|
|
self.weight = nn.Linear(hidden_sz, config.n_embd)
|
|
|
|
|
self.receptance = nn.Linear(config.n_embd, config.n_embd)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
@@ -166,9 +167,10 @@ class GeGLU(torch.nn.Module):
|
|
|
|
|
def __init__(self, config, layer_id):
    """Build the GeGLU feed-forward sub-layer.

    Creates the three projections used by a gated (GeGLU-style) MLP:
    a ``key`` and ``value`` projection from ``n_embd`` up to the hidden
    width, and a ``weight`` projection back down to ``n_embd``.

    Args:
        config: model config; only ``config.n_embd`` (embedding width,
            an int) is read here.
        layer_id: index of this layer within the model stack.
    """
    super().__init__()
    self.layer_id = layer_id

    # Hidden width of the feed-forward expansion (3x the embedding size).
    # NOTE(review): the original block also created key/value/weight once
    # with `3 * config.n_embd` before re-creating them below with
    # `hidden_sz` (the same size) — those first three Linear layers were
    # dead allocations and have been removed.
    hidden_sz = 3 * config.n_embd
    self.key = nn.Linear(config.n_embd, hidden_sz)
    self.value = nn.Linear(config.n_embd, hidden_sz)
    self.weight = nn.Linear(hidden_sz, config.n_embd)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
k = self.key(x)
|
|
|
|
|
|