# Patch for RWKV-v4/train.py: a single hunk (@@ -67,8 +67,8 @@).
# Comment-only change: the removed and added lines carry the same statements
# (when RWKV_NUM_GPUS is 1, RWKV_DEEPSPEED is set to '0'); only the trailing
# comments are reworded, now noting that DeepSpeed can sometimes save VRAM even
# for 1 GPU training. The override itself is still applied unconditionally.
# NOTE(review): this copy of the patch has its line breaks collapsed onto one
# line, so the hunk's blank context lines and indentation are not visible here;
# presumably it must be restored from the original commit before it can be
# applied - TODO confirm.
diff --git a/RWKV-v4/train.py b/RWKV-v4/train.py index cf1ccc4..f6945dd 100644 --- a/RWKV-v4/train.py +++ b/RWKV-v4/train.py @@ -67,8 +67,8 @@ os.environ['RWKV_FLOAT_MODE'] = 'bf16' os.environ['RWKV_DEEPSPEED'] = '1' # Use DeepSpeed? 0 = False, 1 = True -if int(os.environ['RWKV_NUM_GPUS']) == 1: # turn off DeepSpeed for 1 GPU training - os.environ['RWKV_DEEPSPEED'] = '0' +if int(os.environ['RWKV_NUM_GPUS']) == 1: # Usually you don't need DeepSpeed for 1 GPU training. + os.environ['RWKV_DEEPSPEED'] = '0' # However, sometimes DeepSpeed saves VRAM even for 1 GPU training. So you shall try it. os.environ['USE_WANDB'] = '0' # wandb logging. 0 = False, 1 = True