diff --git a/src/model.py b/src/model.py
index f1c4719..7cebae9 100644
--- a/src/model.py
+++ b/src/model.py
@@ -68,9 +68,10 @@ class RWKV_ChannelMix(nn.Module):
         self.layer_id = layer_id
         self.time_shift = nn.ZeroPad2d((0,0,1,0))
 
-        self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
-        self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
-        self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
+        hidden_sz = 5 * config.n_embd // 2 # can use smaller hidden_sz because of R
+        self.key = nn.Linear(config.n_embd, hidden_sz)
+        self.value = nn.Linear(config.n_embd, hidden_sz)
+        self.weight = nn.Linear(hidden_sz, config.n_embd)
         self.receptance = nn.Linear(config.n_embd, config.n_embd)
 
     def forward(self, x):
@@ -166,9 +167,10 @@ class GeGLU(torch.nn.Module):
     def __init__(self, config, layer_id):
         super().__init__()
         self.layer_id = layer_id
-        self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
-        self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
-        self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
+        hidden_sz = 3 * config.n_embd
+        self.key = nn.Linear(config.n_embd, hidden_sz)
+        self.value = nn.Linear(config.n_embd, hidden_sz)
+        self.weight = nn.Linear(hidden_sz, config.n_embd)
 
     def forward(self, x):
         k = self.key(x)
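The comment in the first hunk ("can use smaller hidden_sz because of R") refers to the receptance gate that `RWKV_ChannelMix` has and `GeGLU` lacks: because the block's output is gated by `sigmoid(self.receptance(x))`, the hidden projection can be narrower than GeGLU's `3 * n_embd`. The sketch below is illustrative only, assuming a GeGLU-style `weight(act(k) * v)` projection followed by the sigmoid receptance gate; the activation choice and the time-shift handling in `src/model.py` are not shown in this hunk and may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelMixSketch(nn.Module):
    """Minimal sketch of a receptance-gated channel mix with the new hidden_sz.

    Assumptions (not taken from the diff): GELU as the hidden activation and
    no time-shift, just to show where the R gate enters the computation.
    """
    def __init__(self, n_embd):
        super().__init__()
        hidden_sz = 5 * n_embd // 2                 # narrower than GeGLU's 3 * n_embd
        self.key = nn.Linear(n_embd, hidden_sz)
        self.value = nn.Linear(n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, n_embd)
        self.receptance = nn.Linear(n_embd, n_embd)  # the extra R projection

    def forward(self, x):
        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)
        kv = self.weight(F.gelu(k) * v)             # assumed activation; hidden projection
        return torch.sigmoid(r) * kv                # R gate is why hidden_sz can shrink
```

On parameter count (ignoring biases): with `hidden_sz = 5 * n_embd // 2` the key, value, and weight matrices hold about 7.5 * n_embd^2 weights, plus n_embd^2 for the receptance gate, roughly 8.5 * n_embd^2 in total, compared with 9 * n_embd^2 for the ungated `3 * n_embd` GeGLU variant, so the two blocks stay close in size.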