diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index 3a0eb3b..ee5ce8e 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -167,6 +167,8 @@ class MyDataset(Dataset): magic_prime = 162165671 elif args.ctx_len == 4096: magic_prime = 81082817 + elif args.ctx_len == 8192: + magic_prime = 40541399 data = self.data_pile else: ii = ii // 2 diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index eede9b1..873bd52 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -166,13 +166,11 @@ if __name__ == "__main__": elif args.ctx_len == 4096: args.magic_prime = 81082817 args.epoch_count = 2010 + elif args.ctx_len == 8192: + args.magic_prime = 40541399 + args.epoch_count = 1005 if args.my_pile_shift < 0: - if args.ctx_len == 1024: - args.my_pile_shift = 0 - elif args.ctx_len == 2048: - args.my_pile_shift = 512 - elif args.ctx_len == 4096: - args.my_pile_shift = 768 + args.my_pile_shift = 0 if magic_prime_bak > 0: args.magic_prime = magic_prime_bak