no message

main
BlinkDL 3 years ago
parent 6ff859db80
commit 2f33901c10

@ -57,6 +57,11 @@ elif TOKEN_MODE == 'pile':
# n_embd = 1024 # n_embd = 1024
# ctx_len = 1024 # ctx_len = 1024
# MODEL_NAME = 'RWKV-4-Pile-1B5-20220903-8040'
# n_layer = 24
# n_embd = 2048
# ctx_len = 1024
os.environ['RWKV_FLOAT_MODE'] = 'fp32' # 'bf16' / 'fp16' / 'fp32' (note: only using fp32 at this moment) os.environ['RWKV_FLOAT_MODE'] = 'fp32' # 'bf16' / 'fp16' / 'fp32' (note: only using fp32 at this moment)
os.environ['RWKV_RUN_DEVICE'] = 'cpu' # 'cpu' (already very fast) or 'cuda' os.environ['RWKV_RUN_DEVICE'] = 'cpu' # 'cpu' (already very fast) or 'cuda'
model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'

@ -278,7 +278,7 @@ class Block(nn.Module):
self.ln0 = nn.LayerNorm(config.n_embd) self.ln0 = nn.LayerNorm(config.n_embd)
if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
self.ffnPre = RWKV_ChannelMix(config, layer_id+1000) self.ffnPre = RWKV_ChannelMix(config, 0)
else: else:
self.att = RWKV_TimeMix(config, layer_id) self.att = RWKV_TimeMix(config, layer_id)

Loading…
Cancel
Save