fixed VRAM consumption

main
BlinkDL 3 years ago
parent cb520e0f15
commit 6299c087a4

@@ -19,7 +19,7 @@ print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
 # CUDA Kernel
 ########################################################################################################

-T_MAX = 4096 # increase this if your ctx_len is long
+T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
 # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

 from torch.utils.cpp_extension import load
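
Note on the "slice the ctx" comment above: T_MAX caps the sequence length the kernel accepts because its workspace grows with T, which is exactly the VRAM this commit trims by dropping the default from 4096 to 1024. The slicing idea would look roughly like the sketch below. This is hypothetical — it assumes a model whose forward accepts and returns a recurrent state, which the WKV autograd Function in this file does not expose directly:

    import torch

    def forward_in_slices(model, tokens, t_max=1024):
        # Process an arbitrarily long token sequence in T_MAX-sized
        # chunks, carrying the hidden state across chunk boundaries,
        # so VRAM stays fixed regardless of total context length.
        state, logits = None, None
        for i in range(0, tokens.shape[1], t_max):
            logits, state = model(tokens[:, i:i + t_max], state)
        return logits, state
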
@@ -62,10 +62,10 @@ class WKV(torch.autograd.Function):
         assert T <= T_MAX
         assert B * C % min(C, 1024) == 0
         w, u, k, v = ctx.saved_tensors
-        gw = torch.zeros((B, C), device='cuda')
-        gu = torch.zeros((B, C), device='cuda')
-        gk = torch.zeros((B, T, C), device='cuda')
-        gv = torch.zeros((B, T, C), device='cuda')
+        gw = torch.zeros((B, C), device='cuda').contiguous()
+        gu = torch.zeros((B, C), device='cuda').contiguous()
+        gk = torch.zeros((B, T, C), device='cuda').contiguous()
+        gv = torch.zeros((B, T, C), device='cuda').contiguous()
         if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
             wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
         else:
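
A note on the .contiguous() calls: the hand-written wkv_cuda kernel reads flat row-major memory through raw pointers, so every argument must be contiguous. A tensor fresh from torch.zeros already is, so these calls are cheap no-ops that make the invariant explicit rather than a behavior change. A minimal check of the same invariant (a sketch, not code from this commit):

    import torch

    def assert_contiguous(*tensors):
        # A strided or transposed view passed to a raw-pointer CUDA
        # kernel would silently read the wrong elements; fail loudly.
        for t in tensors:
            assert t.is_contiguous(), f"non-contiguous tensor, shape {tuple(t.shape)}"

    assert_contiguous(torch.zeros((4, 8), device='cuda'))        # passes
    # assert_contiguous(torch.zeros((8, 4), device='cuda').t())  # would raise
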

@@ -19,7 +19,7 @@ DEBUG_TIME = False # True False - show trained time-coeffs
 ########################################################################################################

 if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
-    T_MAX = 4096 # increase this if your ctx_len is long
+    T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
     # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

     from torch.utils.cpp_extension import load

@@ -13,6 +13,7 @@ import logging
 import datetime
 import math
 from pytorch_lightning.lite import LightningLite
+import gc

 logger = logging.getLogger(__name__)
 torch.backends.cudnn.benchmark = True
@@ -99,7 +100,9 @@ class Trainer(LightningLite):
         model, config = self.model, self.config
         raw_model = model.module if hasattr(self.model, "module") else model
         optimizer = raw_model.configure_optimizers(config)
         model, optimizer = self.setup(model, optimizer)
+        gc.collect()
+        torch.cuda.empty_cache()
         print('[3]')

         def run_epoch(split):
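
Why the gc.collect() / empty_cache() pair helps here: self.setup() rebuilds the model and optimizer wrappers, and PyTorch's caching allocator can keep holding the blocks freed in the process. empty_cache() returns those cached-but-unused blocks to the driver, lowering the peak during the first training step. The effect can be observed with the standard CUDA memory counters (a sketch, assuming a CUDA device is present):

    import gc, torch

    def report_vram(tag):
        # allocated = memory held by live tensors; reserved = what the
        # caching allocator keeps around (what empty_cache() releases).
        alloc = torch.cuda.memory_allocated() / 2**20
        reserv = torch.cuda.memory_reserved() / 2**20
        print(f'{tag}: allocated {alloc:.0f} MiB, reserved {reserv:.0f} MiB')

    report_vram('before')
    gc.collect()
    torch.cuda.empty_cache()
    report_vram('after')
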
@@ -127,8 +130,11 @@
                 yyy, loss = model(x, y) # forward the model
                 lossL2 = L2Wrap.apply(loss, yyy)
-                all_loss = [loss.clone() for _ in range(NUM_GPUS)]
-                torch.distributed.all_gather(all_loss, loss)
+                if os.environ['RWKV_DEEPSPEED'] == '0':
+                    all_loss = [loss.clone()]
+                else:
+                    all_loss = [loss.clone() for _ in range(NUM_GPUS)]
+                    torch.distributed.all_gather(all_loss, loss)

                 if is_train: # backprop and update the parameters
                     model.zero_grad()
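
The branch matters because with RWKV_DEEPSPEED=0 the script runs as a single process: torch.distributed is never initialized, so calling all_gather would raise. In the distributed case, all_gather is what lets every rank log the mean loss across all GPUs. A self-contained sketch of that pattern (hypothetical helper, not from this repo):

    import torch
    import torch.distributed as dist

    def mean_loss_for_logging(loss, world_size):
        # Single-process run: nothing to gather.
        if not (dist.is_available() and dist.is_initialized()):
            return loss.item()
        # Each rank contributes its local loss; afterwards every rank
        # holds all of them and can log the global mean.
        bucket = [torch.zeros_like(loss) for _ in range(world_size)]
        dist.all_gather(bucket, loss)
        return torch.stack(bucket).mean().item()
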

@@ -82,6 +82,7 @@ class TOKENIZER():
             else:
                 from transformers import GPT2TokenizerFast
                 self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
+            self.vocab_size = len(self.tokenizer)
         else:
             self.charMode = True
             with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
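
On self.vocab_size = len(self.tokenizer): for Hugging Face tokenizers, len() counts the full vocabulary including any added special tokens, whereas .vocab_size reports only the base vocabulary, so len() is the safer number for sizing embedding and output layers. Illustration (file names are placeholders):

    from transformers import GPT2TokenizerFast

    tok = GPT2TokenizerFast('vocab.json', 'merges.txt')  # hypothetical files
    tok.add_special_tokens({'pad_token': '[PAD]'})
    print(len(tok), tok.vocab_size)  # len(tok) is now vocab_size + 1
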

@@ -63,6 +63,11 @@ os.environ['RWKV_NUM_GPUS'] = '1' # num of GPUs to use
 os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future) or 'fp32'
+
+os.environ['RWKV_DEEPSPEED'] = '1' # Use DeepSpeed? 0 = False, 1 = True
+if int(os.environ['RWKV_NUM_GPUS']) == 1 and os.environ['RWKV_FLOAT_MODE'] == 'fp32': # the only case where DeepSpeed is worse
+    os.environ['RWKV_DEEPSPEED'] = '0'
+
 os.environ['USE_WANDB'] = '0' # wandb logging. 0 = False, 1 = True

 ########################################################################################################
@@ -74,7 +79,7 @@ LOAD_MODEL = False # shall we load the #EPOCH_BEGIN model and continue the train
 n_layer = 6
 n_embd = 512
-ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is very long
+ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is longer
 model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' (sometimes better)
@@ -187,69 +192,77 @@ if __name__ == '__main__':
     m_cfg.LOAD_MODEL = LOAD_MODEL
     m_cfg.MODEL_NAME = MODEL_NAME

-    from pytorch_lightning.strategies import DeepSpeedStrategy
-
-    DEEPSPEED_CFG = {
-        "zero_allow_untested_optimizer":True,
-        "zero_optimization":{
-            "stage":2,
-            "contiguous_gradients":True,
-            "overlap_comm":True,
-            "allgather_partitions":True,
-            "reduce_scatter":True,
-            "allgather_bucket_size":200000000,
-            "reduce_bucket_size":200000000,
-            "sub_group_size":1000000000000
-        },
-        "activation_checkpointing":{
-            "partition_activations":False,
-            "cpu_checkpointing":False,
-            "contiguous_memory_optimization":False,
-            "synchronize_checkpoint_boundary":False
-        },
-        "aio":{
-            "block_size":1048576,
-            "queue_depth":8,
-            "single_submit":False,
-            "overlap_events":True,
-            "thread_count":1
-        },
-        "gradient_clipping": 1.0,
-        "gradient_accumulation_steps": 1,
-    }
-
-    if NUM_GPUS == 1:
-        DEEPSPEED_CFG['zero_optimization'] = {
-            "stage":1, # saves some VRAM
-            "contiguous_gradients":False,
-            "overlap_comm":False,
-            "allgather_partitions":False,
-            "reduce_scatter":False,
-            "allgather_bucket_size":200000000,
-            "reduce_bucket_size":200000000,
-            "sub_group_size":1000000000000
-        }
-
-    if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
-        DEEPSPEED_CFG["fp16"] = {
-            "fp16": True,
-            "enabled": True,
-            "loss_scale": 0,
-            "initial_scale_power": 12,
-            "loss_scale_window": 1000,
-            "hysteresis": 2,
-            "min_loss_scale": 1
-        }
-        trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16)
-    elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
-        DEEPSPEED_CFG["bf16"] = {
-            "enabled": True
-        }
-        trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
-    elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-        trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)
-
-    print(trainer._strategy.config)
+    if os.environ['RWKV_DEEPSPEED'] == '0':
+        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16)
+        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16')
+        elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32)
+    else:
+        from pytorch_lightning.strategies import DeepSpeedStrategy
+
+        DEEPSPEED_CFG = {
+            "zero_allow_untested_optimizer":True,
+            "zero_optimization":{
+                "stage":2,
+                "contiguous_gradients":True,
+                "overlap_comm":True,
+                "allgather_partitions":True,
+                "reduce_scatter":True,
+                "allgather_bucket_size":200000000,
+                "reduce_bucket_size":200000000,
+                "sub_group_size":1000000000000
+            },
+            "activation_checkpointing":{
+                "partition_activations":False,
+                "cpu_checkpointing":False,
+                "contiguous_memory_optimization":False,
+                "synchronize_checkpoint_boundary":False
+            },
+            "aio":{
+                "block_size":1048576,
+                "queue_depth":8,
+                "single_submit":False,
+                "overlap_events":True,
+                "thread_count":1
+            },
+            "gradient_clipping": 1.0,
+            "gradient_accumulation_steps": 1,
+        }
+
+        if NUM_GPUS == 1:
+            DEEPSPEED_CFG['zero_optimization'] = {
+                "stage":1, # saves some VRAM
+                "contiguous_gradients":False,
+                "overlap_comm":False,
+                "allgather_partitions":False,
+                "reduce_scatter":False,
+                "allgather_bucket_size":200000000,
+                "reduce_bucket_size":200000000,
+                "sub_group_size":1000000000000
+            }
+
+        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            DEEPSPEED_CFG["fp16"] = {
+                "fp16": True,
+                "enabled": True,
+                "loss_scale": 0,
+                "initial_scale_power": 12,
+                "loss_scale_window": 1000,
+                "hysteresis": 2,
+                "min_loss_scale": 1
+            }
+            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16)
+        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+            DEEPSPEED_CFG["bf16"] = {
+                "enabled": True
+            }
+            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
+        elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
+            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)
+
+        print(trainer._strategy.config)

     trainer.run(m_cfg, train_dataset, None, tconf)
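
The restructure above makes the two construction paths explicit, and print(trainer._strategy.config) moves inside the DeepSpeed branch, presumably because a plain Lite strategy has no .config attribute. Distilled, the dispatch is (a compact restatement with a hypothetical helper, not part of the commit):

    # Maps RWKV_FLOAT_MODE to Lightning's precision argument.
    PRECISION = {'fp16': 16, 'bf16': 'bf16', 'fp32': 32}

    def make_trainer(float_mode, use_deepspeed, num_gpus, deepspeed_cfg=None):
        if not use_deepspeed:
            return Trainer(devices=num_gpus, accelerator="gpu",
                           precision=PRECISION[float_mode])
        from pytorch_lightning.strategies import DeepSpeedStrategy
        return Trainer(strategy=DeepSpeedStrategy(config=deepspeed_cfg),
                       devices=num_gpus, accelerator="gpu",
                       precision=PRECISION[float_mode])
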
