diff --git a/RWKV-v4/train.py b/RWKV-v4/train.py index ef1b6cf..e84c5c1 100644 --- a/RWKV-v4/train.py +++ b/RWKV-v4/train.py @@ -217,7 +217,18 @@ if __name__ == '__main__': "gradient_clipping": 1.0, "gradient_accumulation_steps": 1, } - + if NUM_GPUS == 1: + DEEPSPEED_CFG['zero_optimization'] = { + "stage":1, # saves some VRAM + "contiguous_gradients":False, + "overlap_comm":False, + "allgather_partitions":False, + "reduce_scatter":False, + "allgather_bucket_size":200000000, + "reduce_bucket_size":200000000, + "sub_group_size":1000000000000 + } + if os.environ['RWKV_FLOAT_MODE'] == 'fp16': DEEPSPEED_CFG["fp16"] = { "fp16": True,