From 083f9504c619a2d2891d8a91c697d682e156ef06 Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Thu, 28 Jul 2022 15:48:05 +0800 Subject: [PATCH] + bf16 mode (more stable) --- RWKV-v4/run.py | 4 +- RWKV-v4/src/model.py | 108 +++++++++++++++++++++++++++------------ RWKV-v4/src/model_run.py | 103 +++++++++++++++++++++++++------------ RWKV-v4/train.py | 55 ++++++++++++++++++-- RWKV-v4/verify.py | 10 +++- 5 files changed, 205 insertions(+), 75 deletions(-) diff --git a/RWKV-v4/run.py b/RWKV-v4/run.py index c6862a5..ef2ea47 100644 --- a/RWKV-v4/run.py +++ b/RWKV-v4/run.py @@ -3,7 +3,7 @@ ######################################################################################################## import numpy as np -import math +import math, os import time import types import copy @@ -18,6 +18,8 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200) ### Step 1: set model ################################################################################## +os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future) + ctx_len = 1024 n_layer = 6 n_embd = 512 diff --git a/RWKV-v4/src/model.py b/RWKV-v4/src/model.py index 40c279c..125b57e 100644 --- a/RWKV-v4/src/model.py +++ b/RWKV-v4/src/model.py @@ -26,39 +26,74 @@ from torch.utils.cpp_extension import load wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}']) -class WKV(torch.autograd.Function): - @staticmethod - def forward(ctx, B, T, C, w, u, k, v): - ctx.B = B - ctx.T = T - ctx.C = C - assert T <= T_MAX - assert B * C % min(C, 1024) == 0 - w = -torch.exp(w.float().contiguous()) - u = u.float().contiguous() - k = k.float().contiguous() - v = v.float().contiguous() - ctx.save_for_backward(w, u, k, v) - y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) - wkv_cuda.forward(B, T, C, w, u, k, v, y) - return y.half() - - @staticmethod - def backward(ctx, gy): - B = ctx.B - T = ctx.T - C = ctx.C - assert T <= T_MAX - assert B * C % min(C, 1024) == 0 - w, u, k, v = ctx.saved_tensors - gw = torch.zeros((B, C), device='cuda') - gu = torch.zeros((B, C), device='cuda') - gk = torch.zeros((B, T, C), device='cuda') - gv = torch.zeros((B, T, C), device='cuda') - wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) - gw = torch.sum(gw, dim=0) - gu = torch.sum(gu, dim=0) - return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) +if os.environ['RWKV_FLOAT_MODE'] == 'fp16': + class WKV(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, w, u, k, v): + ctx.B = B + ctx.T = T + ctx.C = C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w = -torch.exp(w.float().contiguous()) + u = u.float().contiguous() + k = k.float().contiguous() + v = v.float().contiguous() + ctx.save_for_backward(w, u, k, v) + y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) + wkv_cuda.forward(B, T, C, w, u, k, v, y) + return y.half() + + @staticmethod + def backward(ctx, gy): + B = ctx.B + T = ctx.T + C = ctx.C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w, u, k, v = ctx.saved_tensors + gw = torch.zeros((B, C), device='cuda') + gu = torch.zeros((B, C), device='cuda') + gk = torch.zeros((B, T, C), device='cuda') + gv = torch.zeros((B, T, C), device='cuda') + wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) + gw 
= torch.sum(gw, dim=0) + gu = torch.sum(gu, dim=0) + return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) +elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + class WKV(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, w, u, k, v): + ctx.B = B + ctx.T = T + ctx.C = C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w = -torch.exp(w.float().contiguous()) + u = u.float().contiguous() + k = k.float().contiguous() + v = v.float().contiguous() + ctx.save_for_backward(w, u, k, v) + y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) + wkv_cuda.forward(B, T, C, w, u, k, v, y) + return y.bfloat16() + + @staticmethod + def backward(ctx, gy): + B = ctx.B + T = ctx.T + C = ctx.C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w, u, k, v = ctx.saved_tensors + gw = torch.zeros((B, C), device='cuda') + gu = torch.zeros((B, C), device='cuda') + gk = torch.zeros((B, T, C), device='cuda') + gv = torch.zeros((B, T, C), device='cuda') + wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) + gw = torch.sum(gw, dim=0) + gu = torch.sum(gu, dim=0) + return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) def RUN_CUDA(B, T, C, w, u, k, v): return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda()) @@ -336,7 +371,12 @@ class GPT(nn.Module): k = self.head_k(x)[:, :T, :] c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM) c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) - c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half() + + if os.environ['RWKV_FLOAT_MODE'] == 'fp16': + c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half() + elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16() + x = self.head(x) + c else: x = self.head(x) diff --git a/RWKV-v4/src/model_run.py b/RWKV-v4/src/model_run.py index 7eb3809..b586035 100644 --- a/RWKV-v4/src/model_run.py +++ b/RWKV-v4/src/model_run.py @@ -5,7 +5,7 @@ import types import copy import torch -import math +import math, os from torch.nn import functional as F import torch.nn as nn @@ -25,39 +25,74 @@ from torch.utils.cpp_extension import load wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}']) -class WKV(torch.autograd.Function): - @staticmethod - def forward(ctx, B, T, C, w, u, k, v): - ctx.B = B - ctx.T = T - ctx.C = C - assert T <= T_MAX - assert B * C % min(C, 1024) == 0 - w = -torch.exp(w.float().contiguous()) - u = u.float().contiguous() - k = k.float().contiguous() - v = v.float().contiguous() - ctx.save_for_backward(w, u, k, v) - y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) - wkv_cuda.forward(B, T, C, w, u, k, v, y) - return y.half() - - @staticmethod - def backward(ctx, gy): - B = ctx.B - T = ctx.T - C = ctx.C - assert T <= T_MAX - assert B * C % min(C, 1024) == 0 - w, u, k, v = ctx.saved_tensors - gw = torch.zeros((B, C), device='cuda') - gu = torch.zeros((B, C), device='cuda') - gk = torch.zeros((B, T, C), device='cuda') - gv = torch.zeros((B, T, C), device='cuda') - wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) - gw = torch.sum(gw, dim=0) - gu = torch.sum(gu, dim=0) - return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) +if os.environ['RWKV_FLOAT_MODE'] == 'fp16': + class WKV(torch.autograd.Function): + @staticmethod + 
def forward(ctx, B, T, C, w, u, k, v): + ctx.B = B + ctx.T = T + ctx.C = C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w = -torch.exp(w.float().contiguous()) + u = u.float().contiguous() + k = k.float().contiguous() + v = v.float().contiguous() + ctx.save_for_backward(w, u, k, v) + y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) + wkv_cuda.forward(B, T, C, w, u, k, v, y) + return y.half() + + @staticmethod + def backward(ctx, gy): + B = ctx.B + T = ctx.T + C = ctx.C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w, u, k, v = ctx.saved_tensors + gw = torch.zeros((B, C), device='cuda') + gu = torch.zeros((B, C), device='cuda') + gk = torch.zeros((B, T, C), device='cuda') + gv = torch.zeros((B, T, C), device='cuda') + wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) + gw = torch.sum(gw, dim=0) + gu = torch.sum(gu, dim=0) + return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) +elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + class WKV(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, w, u, k, v): + ctx.B = B + ctx.T = T + ctx.C = C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w = -torch.exp(w.float().contiguous()) + u = u.float().contiguous() + k = k.float().contiguous() + v = v.float().contiguous() + ctx.save_for_backward(w, u, k, v) + y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) + wkv_cuda.forward(B, T, C, w, u, k, v, y) + return y.bfloat16() + + @staticmethod + def backward(ctx, gy): + B = ctx.B + T = ctx.T + C = ctx.C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w, u, k, v = ctx.saved_tensors + gw = torch.zeros((B, C), device='cuda') + gu = torch.zeros((B, C), device='cuda') + gk = torch.zeros((B, T, C), device='cuda') + gv = torch.zeros((B, T, C), device='cuda') + wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) + gw = torch.sum(gw, dim=0) + gu = torch.sum(gu, dim=0) + return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) def RUN_CUDA(B, T, C, w, u, k, v): return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda()) diff --git a/RWKV-v4/train.py b/RWKV-v4/train.py index 4c670d3..1b5ee67 100644 --- a/RWKV-v4/train.py +++ b/RWKV-v4/train.py @@ -6,6 +6,8 @@ import os os.environ['USE_WANDB'] = '0' # 0 = False, 1 = True +os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future) + ### This is using DeepSpeed stage2 + FP16 ############################################################## # # Currently it's slow to initialize a new model. 
Hence I suggest this procedure for multi-GPU training: @@ -116,8 +118,8 @@ train_dataset = Dataset(open( if __name__ == '__main__': from src.trainer import Trainer, TrainerConfig - print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', - betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, ) + print('\nmodel', model_type, os.environ['RWKV_FLOAT_MODE'], 'epoch', n_epoch, 'batchsz', batch_size, 'betas', + betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, '\n') tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, @@ -132,9 +134,52 @@ if __name__ == '__main__': from pytorch_lightning.strategies import DeepSpeedStrategy - # you can set grad_norm_clip in deepspeed.json - - trainer = Trainer(strategy=DeepSpeedStrategy(config='deepspeed.json'), devices=NUM_GPUS, accelerator="gpu", precision=16) + DEEPSPEED_CFG = { + "zero_allow_untested_optimizer":True, + "zero_optimization":{ + "stage":2, + "contiguous_gradients":True, + "overlap_comm":True, + "allgather_partitions":True, + "reduce_scatter":True, + "allgather_bucket_size":200000000, + "reduce_bucket_size":200000000, + "sub_group_size":1000000000000 + }, + "activation_checkpointing":{ + "partition_activations":False, + "cpu_checkpointing":False, + "contiguous_memory_optimization":False, + "synchronize_checkpoint_boundary":False + }, + "aio":{ + "block_size":1048576, + "queue_depth":8, + "single_submit":False, + "overlap_events":True, + "thread_count":1 + }, + "gradient_clipping": 1.0, + "gradient_accumulation_steps": 1, + } + + if os.environ['RWKV_FLOAT_MODE'] == 'fp16': + DEEPSPEED_CFG["fp16"] = { + "fp16": True, + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 12, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + } + trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16) + elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + DEEPSPEED_CFG["bf16"] = { + "enabled": True + } + trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16') + print(trainer._strategy.config) trainer.run(m_cfg, train_dataset, None, tconf) diff --git a/RWKV-v4/verify.py b/RWKV-v4/verify.py index 75e8e55..5198c52 100644 --- a/RWKV-v4/verify.py +++ b/RWKV-v4/verify.py @@ -15,6 +15,8 @@ import torch from src.model_run import RWKV_RNN, RWKV_GPT from src.model import GPT, GPTConfig +os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future) + ctx_len = 1024 n_layer = 6 n_embd = 512 @@ -27,7 +29,13 @@ tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ') ######################################################################################################## -model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda().half() +model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda() + +if os.environ['RWKV_FLOAT_MODE'] == 'fp16': + model_train = model_train.half() +elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + model_train = model_train.bfloat16() + print('loading ' + model_name) m2 = torch.load(model_name + '.pth', map_location=RUN_DEVICE) model_train.load_state_dict(m2)
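Note on the pattern above: the fp16 and bf16 branches of the WKV autograd Function differ only in the final cast of the kernel outputs, and the two DeepSpeed configurations differ only in their precision block. Below is a minimal sketch of how that selection could be factored into one place, assuming the same RWKV_FLOAT_MODE environment variable; the helper names (RWKV_DTYPE, cast_wkv_output, make_deepspeed_cfg) are illustrative and not part of this patch.

import os
import torch

# Resolve the training dtype once from the env var used throughout the patch.
RWKV_FLOAT_MODE = os.environ.get('RWKV_FLOAT_MODE', 'bf16')   # 'bf16' or 'fp16'
RWKV_DTYPE = torch.bfloat16 if RWKV_FLOAT_MODE == 'bf16' else torch.float16

def cast_wkv_output(y: torch.Tensor) -> torch.Tensor:
    # The CUDA kernel computes in fp32; cast its output back to the training
    # dtype, replacing the per-mode y.half() / y.bfloat16() branches.
    return y.to(RWKV_DTYPE)

def make_deepspeed_cfg(base_cfg: dict):
    # Return (deepspeed_config, lightning_precision) mirroring the two
    # branches in train.py: fp16 keeps DeepSpeed's dynamic loss scaler,
    # bf16 needs no loss scaling at all.
    cfg = dict(base_cfg)
    if RWKV_FLOAT_MODE == 'fp16':
        cfg["fp16"] = {
            "enabled": True,
            "loss_scale": 0,              # 0 = dynamic loss scaling
            "initial_scale_power": 12,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1,
        }
        precision = 16
    else:
        cfg["bf16"] = {"enabled": True}
        precision = 'bf16'
    return cfg, precision

The bf16 branch needs no loss-scaling fields because bfloat16 has the same exponent range as fp32 (only a shorter mantissa), which is why train.py sets only "enabled": True and passes precision='bf16' to the Lightning Trainer, while the fp16 branch keeps dynamic loss scaling to work around the overflow behaviour noted in the comments.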