@@ -63,6 +63,11 @@ os.environ['RWKV_NUM_GPUS'] = '1' # num of GPUs to use
 os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future) or 'fp32'
+os.environ['RWKV_DEEPSPEED'] = '1' # Use DeepSpeed? 0 = False, 1 = True
+if int(os.environ['RWKV_NUM_GPUS']) == 1 and os.environ['RWKV_FLOAT_MODE'] == 'fp32': # the only case where DeepSpeed is worse
+    os.environ['RWKV_DEEPSPEED'] = '0'
 os.environ['USE_WANDB'] = '0' # wandb logging. 0 = False, 1 = True
 ########################################################################################################
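For reference, the three float modes map directly onto the PyTorch Lightning precision values this patch passes to Trainer further down. A minimal standalone sketch of that mapping (the helper name resolve_precision is hypothetical, not part of the patch):

import os

# Hypothetical helper illustrating the flag-to-precision mapping used later in
# this patch; 16 / 'bf16' / 32 are the literal precision arguments it passes to Trainer.
def resolve_precision():
    mode = os.environ['RWKV_FLOAT_MODE']
    assert mode in ('fp16', 'bf16', 'fp32'), f'unknown RWKV_FLOAT_MODE: {mode}'
    return {'fp16': 16, 'bf16': 'bf16', 'fp32': 32}[mode]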
@@ -74,7 +79,7 @@ LOAD_MODEL = False # shall we load the #EPOCH_BEGIN model and continue the train
 n_layer = 6
 n_embd = 512
-ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is very long
+ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is longer
 model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' (sometimes better)
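The ctx_len comment refers to the T_MAX constant in src/model.py: the custom WKV CUDA kernel is compiled for a fixed maximum sequence length, so T_MAX must be at least ctx_len. A hedged sketch of that constraint (the exact code in src/model.py may differ):

T_MAX = 1024  # compile-time sequence-length bound for the CUDA kernel in src/model.py

ctx_len = 1024
assert ctx_len <= T_MAX, 'increase T_MAX in src/model.py before raising ctx_len'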
@@ -187,69 +192,77 @@ if __name__ == '__main__':
     m_cfg.LOAD_MODEL = LOAD_MODEL
     m_cfg.MODEL_NAME = MODEL_NAME
-    from pytorch_lightning.strategies import DeepSpeedStrategy
-
-    DEEPSPEED_CFG = {
-        "zero_allow_untested_optimizer": True,
-        "zero_optimization": {
-            "stage": 2,
-            "contiguous_gradients": True,
-            "overlap_comm": True,
-            "allgather_partitions": True,
-            "reduce_scatter": True,
-            "allgather_bucket_size": 200000000,
-            "reduce_bucket_size": 200000000,
-            "sub_group_size": 1000000000000
-        },
-        "activation_checkpointing": {
-            "partition_activations": False,
-            "cpu_checkpointing": False,
-            "contiguous_memory_optimization": False,
-            "synchronize_checkpoint_boundary": False
-        },
-        "aio": {
-            "block_size": 1048576,
-            "queue_depth": 8,
-            "single_submit": False,
-            "overlap_events": True,
-            "thread_count": 1
-        },
-        "gradient_clipping": 1.0,
-        "gradient_accumulation_steps": 1,
-    }
-
-    if NUM_GPUS == 1:
-        DEEPSPEED_CFG['zero_optimization'] = {
-            "stage": 1, # saves some VRAM
-            "contiguous_gradients": False,
-            "overlap_comm": False,
-            "allgather_partitions": False,
-            "reduce_scatter": False,
-            "allgather_bucket_size": 200000000,
-            "reduce_bucket_size": 200000000,
-            "sub_group_size": 1000000000000
-        }
-
-    if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
-        DEEPSPEED_CFG["fp16"] = {
-            "fp16": True,
-            "enabled": True,
-            "loss_scale": 0,
-            "initial_scale_power": 12,
-            "loss_scale_window": 1000,
-            "hysteresis": 2,
-            "min_loss_scale": 1
-        }
-        trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16)
-    elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
-        DEEPSPEED_CFG["bf16"] = {
-            "enabled": True
-        }
-        trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
-    elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-        trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)
-
-    print(trainer._strategy.config)
+    if os.environ['RWKV_DEEPSPEED'] == '0':
+        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16)
+        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16')
+        elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32)
+    else:
+        from pytorch_lightning.strategies import DeepSpeedStrategy
+
+        DEEPSPEED_CFG = {
+            "zero_allow_untested_optimizer": True,
+            "zero_optimization": {
+                "stage": 2,
+                "contiguous_gradients": True,
+                "overlap_comm": True,
+                "allgather_partitions": True,
+                "reduce_scatter": True,
+                "allgather_bucket_size": 200000000,
+                "reduce_bucket_size": 200000000,
+                "sub_group_size": 1000000000000
+            },
+            "activation_checkpointing": {
+                "partition_activations": False,
+                "cpu_checkpointing": False,
+                "contiguous_memory_optimization": False,
+                "synchronize_checkpoint_boundary": False
+            },
+            "aio": {
+                "block_size": 1048576,
+                "queue_depth": 8,
+                "single_submit": False,
+                "overlap_events": True,
+                "thread_count": 1
+            },
+            "gradient_clipping": 1.0,
+            "gradient_accumulation_steps": 1,
+        }
+
+        if NUM_GPUS == 1:
+            DEEPSPEED_CFG['zero_optimization'] = {
+                "stage": 1, # saves some VRAM
+                "contiguous_gradients": False,
+                "overlap_comm": False,
+                "allgather_partitions": False,
+                "reduce_scatter": False,
+                "allgather_bucket_size": 200000000,
+                "reduce_bucket_size": 200000000,
+                "sub_group_size": 1000000000000
+            }
+
+        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            DEEPSPEED_CFG["fp16"] = {
+                "fp16": True,
+                "enabled": True,
+                "loss_scale": 0,
+                "initial_scale_power": 12,
+                "loss_scale_window": 1000,
+                "hysteresis": 2,
+                "min_loss_scale": 1
+            }
+            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16)
+        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+            DEEPSPEED_CFG["bf16"] = {
+                "enabled": True
+            }
+            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
+        elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
+            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)
+
+        print(trainer._strategy.config)
 
     trainer.run(m_cfg, train_dataset, None, tconf)
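Two details of the DeepSpeed configuration above are worth spelling out: the single-GPU override drops from ZeRO stage 2 (optimizer state plus gradient partitioning) to stage 1 (optimizer state only, since there are no peer ranks to partition gradients across), and "loss_scale": 0 in the fp16 block selects dynamic loss scaling. A standalone sketch of those semantics, reusing the values from the patch:

# ZeRO stage choice mirroring the patch: one GPU -> stage 1, which the
# patch's comment notes saves some VRAM.
NUM_GPUS = 1
zero_stage = 1 if NUM_GPUS == 1 else 2

# fp16 settings from the patch. "loss_scale": 0 means *dynamic* loss scaling:
# the scale starts at 2**initial_scale_power (2**12 = 4096), is reduced when
# gradients overflow, is raised again after loss_scale_window overflow-free
# steps, and never drops below min_loss_scale.
fp16_cfg = {
    "enabled": True,
    "loss_scale": 0,
    "initial_scale_power": 12,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1,
}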