From 6299c087a4bec5d0baf9f50b95021660743b0e43 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Tue, 30 Aug 2022 18:23:34 +0800
Subject: [PATCH] fixed VRAM consumption

---
 RWKV-v4/src/model.py     |  10 +--
 RWKV-v4/src/model_run.py |   2 +-
 RWKV-v4/src/trainer.py   |  12 +++-
 RWKV-v4/src/utils.py     |   1 +
 RWKV-v4/train.py         | 141 +++++++++++++++++++++------------
 5 files changed, 93 insertions(+), 73 deletions(-)

diff --git a/RWKV-v4/src/model.py b/RWKV-v4/src/model.py
index bfc7ae2..2323407 100644
--- a/RWKV-v4/src/model.py
+++ b/RWKV-v4/src/model.py
@@ -19,7 +19,7 @@ print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
 # CUDA Kernel
 ########################################################################################################
 
-T_MAX = 4096 # increase this if your ctx_len is long
+T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
 # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
 
 from torch.utils.cpp_extension import load
@@ -62,10 +62,10 @@ class WKV(torch.autograd.Function):
         assert T <= T_MAX
         assert B * C % min(C, 1024) == 0
         w, u, k, v = ctx.saved_tensors
-        gw = torch.zeros((B, C), device='cuda')
-        gu = torch.zeros((B, C), device='cuda')
-        gk = torch.zeros((B, T, C), device='cuda')
-        gv = torch.zeros((B, T, C), device='cuda')
+        gw = torch.zeros((B, C), device='cuda').contiguous()
+        gu = torch.zeros((B, C), device='cuda').contiguous()
+        gk = torch.zeros((B, T, C), device='cuda').contiguous()
+        gv = torch.zeros((B, T, C), device='cuda').contiguous()
         if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
             wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
         else:
diff --git a/RWKV-v4/src/model_run.py b/RWKV-v4/src/model_run.py
index aaac1b1..f68e8d2 100644
--- a/RWKV-v4/src/model_run.py
+++ b/RWKV-v4/src/model_run.py
@@ -19,7 +19,7 @@ DEBUG_TIME = False # True False - show trained time-coeffs
 ########################################################################################################
 
 if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
-    T_MAX = 4096 # increase this if your ctx_len is long
+    T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
     # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
     from torch.utils.cpp_extension import load
 
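The two T_MAX changes above shrink the default ceiling on the sequence length the fused WKV CUDA kernel will accept (the backward pass asserts T <= T_MAX), and that backward pass allocates four gradient buffers sized by (B, C) and (B, T, C), so long contexts are where the VRAM goes. Below is a rough estimate of just those buffers, as a minimal sketch; the batch size, context length and channel count are illustrative numbers, not values taken from this patch.

import torch

def wkv_backward_buffer_bytes(B, T, C, dtype=torch.float32):
    # gw and gu are (B, C); gk and gv are (B, T, C), mirroring the
    # torch.zeros(...) calls in WKV.backward above (float32 on CUDA).
    per_elem = torch.tensor([], dtype=dtype).element_size()
    n_elems = 2 * B * C + 2 * B * T * C
    return n_elems * per_elem

# Hypothetical shapes: batch 12, ctx_len 1024, n_embd 512 -> roughly 48 MiB
# for these gradient buffers alone, before activations and optimizer state.
print(wkv_backward_buffer_bytes(12, 1024, 512) / 2**20, "MiB")
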
diff --git a/RWKV-v4/src/trainer.py b/RWKV-v4/src/trainer.py
index 73b2422..74bbd99 100644
--- a/RWKV-v4/src/trainer.py
+++ b/RWKV-v4/src/trainer.py
@@ -13,6 +13,7 @@ import logging
 import datetime
 import math
 from pytorch_lightning.lite import LightningLite
+import gc
 
 logger = logging.getLogger(__name__)
 torch.backends.cudnn.benchmark = True
@@ -99,7 +100,9 @@ class Trainer(LightningLite):
         model, config = self.model, self.config
         raw_model = model.module if hasattr(self.model, "module") else model
         optimizer = raw_model.configure_optimizers(config)
-        model, optimizer = self.setup(model, optimizer)
+        model, optimizer = self.setup(model, optimizer)
+        gc.collect()
+        torch.cuda.empty_cache()
         print('[3]')
 
         def run_epoch(split):
@@ -127,8 +130,11 @@ class Trainer(LightningLite):
                     yyy, loss = model(x, y) # forward the model
                     lossL2 = L2Wrap.apply(loss, yyy)
 
-                all_loss = [loss.clone() for _ in range(NUM_GPUS)]
-                torch.distributed.all_gather(all_loss, loss)
+                if os.environ['RWKV_DEEPSPEED'] == '0':
+                    all_loss = [loss.clone()]
+                else:
+                    all_loss = [loss.clone() for _ in range(NUM_GPUS)]
+                    torch.distributed.all_gather(all_loss, loss)
 
                 if is_train: # backprop and update the parameters
                     model.zero_grad()
diff --git a/RWKV-v4/src/utils.py b/RWKV-v4/src/utils.py
index e624041..a73792c 100644
--- a/RWKV-v4/src/utils.py
+++ b/RWKV-v4/src/utils.py
@@ -82,6 +82,7 @@ class TOKENIZER():
             else:
                 from transformers import GPT2TokenizerFast
                 self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
+            self.vocab_size = len(self.tokenizer)
         else:
             self.charMode = True
             with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
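The trainer.py change makes the per-step loss gathering respect the new RWKV_DEEPSPEED switch: torch.distributed.all_gather requires an initialized process group, so when DeepSpeed is disabled the loss is simply kept local. A minimal sketch of that pattern follows, assuming loss is a CUDA tensor and num_gpus matches the world size; this illustrates the idea and is not the exact trainer code.

import os
import torch
import torch.distributed as dist

def gather_losses(loss: torch.Tensor, num_gpus: int):
    # Single-process run (RWKV_DEEPSPEED=0): no process group exists,
    # so just keep the local loss.
    if os.environ.get('RWKV_DEEPSPEED', '1') == '0':
        return [loss.clone()]
    # Distributed/DeepSpeed run: collect one loss tensor per rank.
    all_loss = [loss.clone() for _ in range(num_gpus)]
    dist.all_gather(all_loss, loss)
    return all_loss
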
diff --git a/RWKV-v4/train.py b/RWKV-v4/train.py
index e84c5c1..d8c6eba 100644
--- a/RWKV-v4/train.py
+++ b/RWKV-v4/train.py
@@ -63,6 +63,11 @@
 os.environ['RWKV_NUM_GPUS'] = '1' # num of GPUs to use
 os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future) or 'fp32'
 
+os.environ['RWKV_DEEPSPEED'] = '1' # Use DeepSpeed? 0 = False, 1 = True
+
+if int(os.environ['RWKV_NUM_GPUS']) == 1 and os.environ['RWKV_FLOAT_MODE'] == 'fp32': # the only case where DeepSpeed is worse
+    os.environ['RWKV_DEEPSPEED'] = '0'
+
 os.environ['USE_WANDB'] = '0' # wandb logging. 0 = False, 1 = True
 
 ########################################################################################################
@@ -74,7 +79,7 @@
 LOAD_MODEL = False # shall we load the #EPOCH_BEGIN model and continue the train
 
 n_layer = 6
 n_embd = 512
-ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is very long
+ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is longer
 
 model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' (sometimes better)
@@ -187,69 +192,77 @@ if __name__ == '__main__':
     m_cfg.LOAD_MODEL = LOAD_MODEL
     m_cfg.MODEL_NAME = MODEL_NAME
 
-    from pytorch_lightning.strategies import DeepSpeedStrategy
-
-    DEEPSPEED_CFG = {
-        "zero_allow_untested_optimizer":True,
-        "zero_optimization":{
-            "stage":2,
-            "contiguous_gradients":True,
-            "overlap_comm":True,
-            "allgather_partitions":True,
-            "reduce_scatter":True,
-            "allgather_bucket_size":200000000,
-            "reduce_bucket_size":200000000,
-            "sub_group_size":1000000000000
-        },
-        "activation_checkpointing":{
-            "partition_activations":False,
-            "cpu_checkpointing":False,
-            "contiguous_memory_optimization":False,
-            "synchronize_checkpoint_boundary":False
-        },
-        "aio":{
-            "block_size":1048576,
-            "queue_depth":8,
-            "single_submit":False,
-            "overlap_events":True,
-            "thread_count":1
-        },
-        "gradient_clipping": 1.0,
-        "gradient_accumulation_steps": 1,
-    }
-    if NUM_GPUS == 1:
-        DEEPSPEED_CFG['zero_optimization'] = {
-            "stage":1, # saves some VRAM
-            "contiguous_gradients":False,
-            "overlap_comm":False,
-            "allgather_partitions":False,
-            "reduce_scatter":False,
-            "allgather_bucket_size":200000000,
-            "reduce_bucket_size":200000000,
-            "sub_group_size":1000000000000
-        }
-
-    if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
-        DEEPSPEED_CFG["fp16"] = {
-            "fp16": True,
-            "enabled": True,
-            "loss_scale": 0,
-            "initial_scale_power": 12,
-            "loss_scale_window": 1000,
-            "hysteresis": 2,
-            "min_loss_scale": 1
-        }
-        trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16)
+    if os.environ['RWKV_DEEPSPEED'] == '0':
+        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16)
+        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16')
+        elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32)
+    else:
+        from pytorch_lightning.strategies import DeepSpeedStrategy
 
-    elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
-        DEEPSPEED_CFG["bf16"] = {
-            "enabled": True
+        DEEPSPEED_CFG = {
+            "zero_allow_untested_optimizer":True,
+            "zero_optimization":{
+                "stage":2,
+                "contiguous_gradients":True,
+                "overlap_comm":True,
+                "allgather_partitions":True,
+                "reduce_scatter":True,
+                "allgather_bucket_size":200000000,
+                "reduce_bucket_size":200000000,
+                "sub_group_size":1000000000000
+            },
+            "activation_checkpointing":{
+                "partition_activations":False,
+                "cpu_checkpointing":False,
+                "contiguous_memory_optimization":False,
+                "synchronize_checkpoint_boundary":False
+            },
+            "aio":{
+                "block_size":1048576,
+                "queue_depth":8,
+                "single_submit":False,
+                "overlap_events":True,
+                "thread_count":1
+            },
+            "gradient_clipping": 1.0,
+            "gradient_accumulation_steps": 1,
         }
-        trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
-
-    elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-        trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)
-
-    print(trainer._strategy.config)
-
+        if NUM_GPUS == 1:
+            DEEPSPEED_CFG['zero_optimization'] = {
+                "stage":1, # saves some VRAM
+                "contiguous_gradients":False,
+                "overlap_comm":False,
+                "allgather_partitions":False,
+                "reduce_scatter":False,
+                "allgather_bucket_size":200000000,
+                "reduce_bucket_size":200000000,
+                "sub_group_size":1000000000000
+            }
+
+        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            DEEPSPEED_CFG["fp16"] = {
+                "fp16": True,
+                "enabled": True,
+                "loss_scale": 0,
+                "initial_scale_power": 12,
+                "loss_scale_window": 1000,
+                "hysteresis": 2,
+                "min_loss_scale": 1
+            }
+            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16)
+
+        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+            DEEPSPEED_CFG["bf16"] = {
+                "enabled": True
+            }
+            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
+
+        elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
+            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)
+
+        print(trainer._strategy.config)
+
     trainer.run(m_cfg, train_dataset, None, tconf)
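Taken together, the train.py hunk makes DeepSpeed opt-in through RWKV_DEEPSPEED and switches it off automatically in the one configuration the patch flags as worse with DeepSpeed (a single GPU in fp32). A condensed sketch of the resulting trainer selection follows, assuming Trainer is the LightningLite subclass from src/trainer.py and DEEPSPEED_CFG is the dict built in the hunk above; the ZeRO stage-1 tweak for NUM_GPUS == 1 is omitted here.

import os
from pytorch_lightning.strategies import DeepSpeedStrategy

def make_trainer(Trainer, NUM_GPUS, DEEPSPEED_CFG):
    # Map RWKV_FLOAT_MODE to the precision argument used in the hunk above.
    precision = {'fp16': 16, 'bf16': 'bf16', 'fp32': 32}[os.environ['RWKV_FLOAT_MODE']]
    if os.environ['RWKV_DEEPSPEED'] == '0':
        # Plain Lightning Lite run, no DeepSpeed engine.
        return Trainer(devices=NUM_GPUS, accelerator="gpu", precision=precision)
    if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
        DEEPSPEED_CFG["fp16"] = {"fp16": True, "enabled": True, "loss_scale": 0,
                                 "initial_scale_power": 12, "loss_scale_window": 1000,
                                 "hysteresis": 2, "min_loss_scale": 1}
    elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
        DEEPSPEED_CFG["bf16"] = {"enabled": True}
    return Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG),
                   devices=NUM_GPUS, accelerator="gpu", precision=precision)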