From 73b96705d7ac7824c7b443c2220aa3beeafde8f5 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Wed, 31 Aug 2022 09:19:27 +0800
Subject: [PATCH] + fp32 mode (slow but good for verification)

---
 RWKV-v4/src/model.py     | 36 +++++++++++++++----------------
 RWKV-v4/src/model_run.py | 46 +++++++++++++++++++++++-----------------
 RWKV-v4/src/trainer.py   |  9 ++++++--
 RWKV-v4/train.py         | 18 ++++++++++++----
 4 files changed, 65 insertions(+), 44 deletions(-)

diff --git a/RWKV-v4/src/model.py b/RWKV-v4/src/model.py
index 2323407..0664ccc 100644
--- a/RWKV-v4/src/model.py
+++ b/RWKV-v4/src/model.py
@@ -34,25 +34,25 @@ class WKV(torch.autograd.Function):
         ctx.C = C
         assert T <= T_MAX
         assert B * C % min(C, 1024) == 0
-        if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
-            w = -torch.exp(w.float().contiguous())
-            u = u.float().contiguous()
-            k = k.float().contiguous()
-            v = v.float().contiguous()
-        else:
+        if '32' in os.environ['RWKV_FLOAT_MODE']:
             w = -torch.exp(w.contiguous())
             u = u.contiguous()
             k = k.contiguous()
             v = v.contiguous()
+        else:
+            w = -torch.exp(w.float().contiguous())
+            u = u.float().contiguous()
+            k = k.float().contiguous()
+            v = v.float().contiguous()
         ctx.save_for_backward(w, u, k, v)
         y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
         wkv_cuda.forward(B, T, C, w, u, k, v, y)
-        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+        if '32' in os.environ['RWKV_FLOAT_MODE']:
+            return y
+        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
             return y.half()
         elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
             return y.bfloat16()
-        elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-            return y

     @staticmethod
     def backward(ctx, gy):
@@ -66,18 +66,18 @@ class WKV(torch.autograd.Function):
         gu = torch.zeros((B, C), device='cuda').contiguous()
         gk = torch.zeros((B, T, C), device='cuda').contiguous()
         gv = torch.zeros((B, T, C), device='cuda').contiguous()
-        if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
-            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
-        else:
+        if '32' in os.environ['RWKV_FLOAT_MODE']:
             wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
+        else:
+            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
         gw = torch.sum(gw, dim=0)
         gu = torch.sum(gu, dim=0)
-        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+        if '32' in os.environ['RWKV_FLOAT_MODE']:
+            return (None, None, None, gw, gu, gk, gv)
+        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
             return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
         elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
             return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
-        elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-            return (None, None, None, gw, gu, gk, gv)

 def RUN_CUDA(B, T, C, w, u, k, v):
     return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
@@ -356,12 +356,12 @@ class GPT(nn.Module):
             c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

-            if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
+                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size)
+            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                 c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
             elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                 c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16()
-            elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size)

             x = self.head(x) + c
         else:
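Note on the new branch condition above: '32' in os.environ['RWKV_FLOAT_MODE'] matches both 'fp32' and the new 'tf32' mode, so either one skips the dtype casts around the CUDA kernel, while 'fp16' and 'bf16' still upcast the inputs to float32 for the kernel and cast its output back down. A minimal standalone sketch of that dispatch (hypothetical helper names, not part of the patch):

import torch

def cast_wkv_inputs(mode, w, u, k, v):
    # 'fp32' and 'tf32' both contain '32': tensors are already float32.
    if '32' in mode:
        return (-torch.exp(w.contiguous()), u.contiguous(),
                k.contiguous(), v.contiguous())
    # 'fp16' / 'bf16': the kernel itself computes in float32, so upcast.
    return (-torch.exp(w.float().contiguous()), u.float().contiguous(),
            k.float().contiguous(), v.float().contiguous())

def cast_wkv_output(mode, y):
    # Cast the kernel's float32 result back to the training dtype.
    if '32' in mode:
        return y
    return y.half() if mode == 'fp16' else y.bfloat16()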
diff --git a/RWKV-v4/src/model_run.py b/RWKV-v4/src/model_run.py
index f68e8d2..ae45c53 100644
--- a/RWKV-v4/src/model_run.py
+++ b/RWKV-v4/src/model_run.py
@@ -34,25 +34,25 @@ if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
             ctx.C = C
             assert T <= T_MAX
             assert B * C % min(C, 1024) == 0
-            if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
-                w = -torch.exp(w.float().contiguous())
-                u = u.float().contiguous()
-                k = k.float().contiguous()
-                v = v.float().contiguous()
-            else:
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
                 w = -torch.exp(w.contiguous())
                 u = u.contiguous()
                 k = k.contiguous()
                 v = v.contiguous()
+            else:
+                w = -torch.exp(w.float().contiguous())
+                u = u.float().contiguous()
+                k = k.float().contiguous()
+                v = v.float().contiguous()
             ctx.save_for_backward(w, u, k, v)
             y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
             wkv_cuda.forward(B, T, C, w, u, k, v, y)
-            if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
+                return y
+            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                 return y.half()
             elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                 return y.bfloat16()
-            elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-                return y

         @staticmethod
         def backward(ctx, gy):
@@ -62,22 +62,22 @@ if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
             assert T <= T_MAX
             assert B * C % min(C, 1024) == 0
             w, u, k, v = ctx.saved_tensors
-            gw = torch.zeros((B, C), device='cuda')
-            gu = torch.zeros((B, C), device='cuda')
-            gk = torch.zeros((B, T, C), device='cuda')
-            gv = torch.zeros((B, T, C), device='cuda')
-            if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
-                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
-            else:
+            gw = torch.zeros((B, C), device='cuda').contiguous()
+            gu = torch.zeros((B, C), device='cuda').contiguous()
+            gk = torch.zeros((B, T, C), device='cuda').contiguous()
+            gv = torch.zeros((B, T, C), device='cuda').contiguous()
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
                 wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
+            else:
+                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
             gw = torch.sum(gw, dim=0)
             gu = torch.sum(gu, dim=0)
-            if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
+                return (None, None, None, gw, gu, gk, gv)
+            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                 return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
             elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                 return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
-            elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-                return (None, None, None, gw, gu, gk, gv)

     def RUN_CUDA(B, T, C, w, u, k, v):
         return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
@@ -222,7 +222,13 @@ class RWKV_GPT(nn.Module):
             c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

-            c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float()
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
+                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size)
+            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half()
+            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16()
+
             x = self.head(x) + c
         else:
             x = self.head(x)
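For the head-QK shortcut, RWKV_GPT in model_run.py previously hard-coded .float() on the one-hot term; it now follows RWKV_FLOAT_MODE like model.py does. The cast matters because F.one_hot returns int64 and torch.matmul requires both operands in the same dtype. A rough standalone sketch (made-up shapes; the last branch casts with .float() so the snippet runs on its own):

import torch
import torch.nn.functional as F

def head_qk_term(mode, c, idx, vocab):
    one_hot = F.one_hot(idx, num_classes=vocab)    # int64 by construction
    if mode == 'fp16':                             # assumes a CUDA device in practice
        return c.half() @ one_hot.half()
    if mode == 'bf16':
        return c.bfloat16() @ one_hot.bfloat16()
    return c @ one_hot.float()                     # fp32 / tf32 path

B, T, vocab = 2, 8, 16                             # made-up sizes
idx = torch.randint(0, vocab, (B, T))
c = torch.rand(B, T, T)
print(head_qk_term('bf16', c, idx, vocab).shape)   # torch.Size([2, 8, 16])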
diff --git a/RWKV-v4/src/trainer.py b/RWKV-v4/src/trainer.py
index 74bbd99..645cc6c 100644
--- a/RWKV-v4/src/trainer.py
+++ b/RWKV-v4/src/trainer.py
@@ -16,9 +16,14 @@ from pytorch_lightning.lite import LightningLite
 import gc

 logger = logging.getLogger(__name__)
+
 torch.backends.cudnn.benchmark = True
-torch.backends.cudnn.allow_tf32 = True
-torch.backends.cuda.matmul.allow_tf32 = True
+if os.environ['RWKV_FLOAT_MODE'] == 'fp32':
+    torch.backends.cudnn.allow_tf32 = False
+    torch.backends.cuda.matmul.allow_tf32 = False
+else:
+    torch.backends.cudnn.allow_tf32 = True
+    torch.backends.cuda.matmul.allow_tf32 = True

 class L2Wrap(torch.autograd.Function):
     @staticmethod
diff --git a/RWKV-v4/train.py b/RWKV-v4/train.py
index 24b5db7..fb0ff1b 100644
--- a/RWKV-v4/train.py
+++ b/RWKV-v4/train.py
@@ -12,9 +12,6 @@ from src.binidx import MMapIndexedDataset

 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
-torch.backends.cudnn.benchmark = True
-torch.backends.cudnn.allow_tf32 = True
-torch.backends.cuda.matmul.allow_tf32 = True

 # if False: # True False ---> Set to False if you don't understand it
 #     print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
@@ -61,7 +58,12 @@ if EXPRESS_PILE_MODE:

 # os.environ['RWKV_NUM_GPUS'] = '1' # num of GPUs to use

-os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future) or 'fp32'
+#
+# 'bf16' (fast & stable)
+# 'fp16' (fast & will overflow after training a large model for very long. can be solved in the future)
+# 'tf32' (decent speed & stable)
+# 'fp32' (!!!very slow!!! only for verification)
+os.environ['RWKV_FLOAT_MODE'] = 'bf16'

 os.environ['RWKV_DEEPSPEED'] = '1' # Use DeepSpeed? 0 = False, 1 = True

@@ -159,6 +161,14 @@ if EXPRESS_PILE_MODE:
         betas = (0.9, 0.999)
     MODEL_NAME = EXPRESS_PILE_MODEL_NAME

+torch.backends.cudnn.benchmark = True
+if os.environ['RWKV_FLOAT_MODE'] == 'fp32':
+    torch.backends.cudnn.allow_tf32 = False
+    torch.backends.cuda.matmul.allow_tf32 = False
+else:
+    torch.backends.cudnn.allow_tf32 = True
+    torch.backends.cuda.matmul.allow_tf32 = True

 ########################################################################################################
 # Load data
 ########################################################################################################
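The TF32 gating in trainer.py and train.py is the point of the whole "fp32 mode (slow but good for verification)": on Ampere-class GPUs, PyTorch's TF32 backends round float32 matmuls to TF32 precision, so a verification run must switch them off to get true fp32 numerics. The same toggle, condensed into a standalone sketch (the env-var default here is illustrative):

import os
import torch

os.environ.setdefault('RWKV_FLOAT_MODE', 'bf16')   # illustrative default

torch.backends.cudnn.benchmark = True
# Only a genuine fp32 run disables TF32; bf16/fp16/tf32 modes leave it on.
allow_tf32 = os.environ['RWKV_FLOAT_MODE'] != 'fp32'
torch.backends.cudnn.allow_tf32 = allow_tf32
torch.backends.cuda.matmul.allow_tf32 = allow_tf32

Setting the flags after RWKV_FLOAT_MODE is chosen, as train.py now does, keeps the two consistent; the old unconditional True/True at import time would have silently degraded fp32 runs.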