fixed VRAM consumption

main
BlinkDL 3 years ago
parent cb520e0f15
commit 6299c087a4

@@ -19,7 +19,7 @@ print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
 # CUDA Kernel
 ########################################################################################################
-T_MAX = 4096 # increase this if your ctx_len is long
+T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
 # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
 from torch.utils.cpp_extension import load
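
Note: T_MAX drives VRAM use because it is baked into the CUDA build as a compile-time constant, so the kernel's buffers are sized for the maximum context rather than the actual sequence. A minimal sketch of that pattern, assuming this repo's -DTmax flag and cuda/wkv_* source names:

# Sketch (assumption: the flag name -DTmax and source paths match the repo's wkv kernel build).
# Because T_MAX is fixed at compile time, per-call scratch space scales with
# T_MAX even when the actual sequence is much shorter.
from torch.utils.cpp_extension import load
T_MAX = 1024
wkv_cuda = load(
    name="wkv",
    sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
    verbose=True,
    extra_cuda_cflags=["--use_fast_math", "-O3", f"-DTmax={T_MAX}"],
)
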
@@ -62,10 +62,10 @@ class WKV(torch.autograd.Function):
         assert T <= T_MAX
         assert B * C % min(C, 1024) == 0
         w, u, k, v = ctx.saved_tensors
-        gw = torch.zeros((B, C), device='cuda')
-        gu = torch.zeros((B, C), device='cuda')
-        gk = torch.zeros((B, T, C), device='cuda')
-        gv = torch.zeros((B, T, C), device='cuda')
+        gw = torch.zeros((B, C), device='cuda').contiguous()
+        gu = torch.zeros((B, C), device='cuda').contiguous()
+        gk = torch.zeros((B, T, C), device='cuda').contiguous()
+        gv = torch.zeros((B, T, C), device='cuda').contiguous()
         if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
             wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
         else:
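
Note: torch.zeros already returns a contiguous tensor, and .contiguous() returns the tensor itself when no copy is needed, so these calls are defensive no-ops that document the kernel's layout requirement rather than change behavior. A quick check:

import torch
g = torch.zeros((4, 8))
print(g.is_contiguous())    # True: fresh allocations are contiguous
print(g.contiguous() is g)  # True: .contiguous() is a no-op on an already-contiguous tensor
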

@@ -19,7 +19,7 @@ DEBUG_TIME = False # True False - show trained time-coeffs
 ########################################################################################################
 if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
-    T_MAX = 4096 # increase this if your ctx_len is long
+    T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
     # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
     from torch.utils.cpp_extension import load

@@ -13,6 +13,7 @@ import logging
 import datetime
 import math
 from pytorch_lightning.lite import LightningLite
+import gc
 logger = logging.getLogger(__name__)
 torch.backends.cudnn.benchmark = True
@@ -100,6 +101,8 @@ class Trainer(LightningLite):
         raw_model = model.module if hasattr(self.model, "module") else model
         optimizer = raw_model.configure_optimizers(config)
         model, optimizer = self.setup(model, optimizer)
+        gc.collect()
+        torch.cuda.empty_cache()
         print('[3]')
         def run_epoch(split):
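
Note: this cleanup pair frees memory that would otherwise stay reserved after setup: gc.collect() drops dead Python references, and torch.cuda.empty_cache() returns the allocator's unused cached blocks to the driver (it never frees live tensors). A sketch for observing the effect:

import gc
import torch

x = torch.zeros((1024, 1024), device='cuda')  # reserve some blocks
del x                                          # tensor is dead, but its blocks stay cached
gc.collect()                                   # drop any lingering Python references
print(torch.cuda.memory_reserved())            # the caching allocator still holds the blocks
torch.cuda.empty_cache()                       # hand unused blocks back to the driver
print(torch.cuda.memory_reserved())            # typically lower now
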
@@ -127,6 +130,9 @@ class Trainer(LightningLite):
                     yyy, loss = model(x, y) # forward the model
                     lossL2 = L2Wrap.apply(loss, yyy)
-                all_loss = [loss.clone() for _ in range(NUM_GPUS)]
-                torch.distributed.all_gather(all_loss, loss)
+                if os.environ['RWKV_DEEPSPEED'] == '0':
+                    all_loss = [loss.clone()]
+                else:
+                    all_loss = [loss.clone() for _ in range(NUM_GPUS)]
+                    torch.distributed.all_gather(all_loss, loss)
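
Note: torch.distributed.all_gather requires an initialized process group, which the single-process (DeepSpeed-off) path does not have, hence the new branch. A self-contained sketch of the same guard:

import torch

loss = torch.tensor(0.0)  # stand-in for the training loss above
if torch.distributed.is_available() and torch.distributed.is_initialized():
    # multi-process path: gather every rank's loss for logging
    all_loss = [torch.zeros_like(loss) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(all_loss, loss)
else:
    # single-process path: all_gather would raise without a process group
    all_loss = [loss.clone()]
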

@@ -82,6 +82,7 @@ class TOKENIZER():
             else:
                 from transformers import GPT2TokenizerFast
                 self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
+            self.vocab_size = len(self.tokenizer)
         else:
             self.charMode = True
             with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
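
Note: for HuggingFace tokenizers, len(tokenizer) counts the base vocabulary plus any added tokens, while the .vocab_size attribute counts only the base vocabulary, so len() is the safer value for sizing an embedding matrix. A quick illustration, assuming the standard pretrained GPT-2 files rather than the repo's WORD_NAME pair:

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
print(tokenizer.vocab_size)  # 50257: base vocabulary only
print(len(tokenizer))        # base + added tokens (50257 here; larger if tokens were added)
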

@@ -63,6 +63,11 @@ os.environ['RWKV_NUM_GPUS'] = '1' # num of GPUs to use
 os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future) or 'fp32'
+os.environ['RWKV_DEEPSPEED'] = '1' # Use DeepSpeed? 0 = False, 1 = True
+if int(os.environ['RWKV_NUM_GPUS']) == 1 and os.environ['RWKV_FLOAT_MODE'] == 'fp32': # the only case where DeepSpeed is worse
+    os.environ['RWKV_DEEPSPEED'] = '0'
 os.environ['USE_WANDB'] = '0' # wandb logging. 0 = False, 1 = True
 ########################################################################################################
@@ -74,7 +79,7 @@ LOAD_MODEL = False # shall we load the #EPOCH_BEGIN model and continue the train
 n_layer = 6
 n_embd = 512
-ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is very long
+ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is longer
 model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' (sometimes better)
@@ -187,6 +192,14 @@ if __name__ == '__main__':
     m_cfg.LOAD_MODEL = LOAD_MODEL
     m_cfg.MODEL_NAME = MODEL_NAME
+    if os.environ['RWKV_DEEPSPEED'] == '0':
+        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16)
+        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16')
+        elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
+            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32)
+    else:
         from pytorch_lightning.strategies import DeepSpeedStrategy
         DEEPSPEED_CFG = {
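
Note: in the pytorch_lightning 1.6-era LightningLite API that this Trainer subclasses, precision accepts 16, 32, 64, or 'bf16', so the new branches map RWKV_FLOAT_MODE directly onto Lite's supported modes. The same mapping written as a dict lookup, a sketch assuming that API:

import os
from pytorch_lightning.lite import LightningLite

# 'fp16'/'bf16'/'fp32' -> the precision values LightningLite accepts
PRECISION = {'fp16': 16, 'bf16': 'bf16', 'fp32': 32}

class Trainer(LightningLite):
    def run(self):
        pass  # the training loop lives here in the real Trainer

trainer = Trainer(devices=1, accelerator="gpu",
                  precision=PRECISION[os.environ.get('RWKV_FLOAT_MODE', 'bf16')])
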
