|
|
|
@@ -278,7 +278,7 @@ class Block(nn.Module):
|
|
|
|
self.ln0 = nn.LayerNorm(config.n_embd)
|
|
|
|
self.ln0 = nn.LayerNorm(config.n_embd)
|
|
|
|
|
|
|
|
|
|
|
|
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
|
|
|
|
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
|
|
|
|
self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
|
|
|
|
self.ffnPre = RWKV_ChannelMix(config, 0)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self.att = RWKV_TimeMix(config, layer_id)
|
|
|
|
self.att = RWKV_TimeMix(config, layer_id)
|
|
|
|
|
|
|
|
|
|
|
|
|