From 61b7c429df4e64d1c970c25f7d19e058c8d633fb Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Sat, 30 Jul 2022 22:59:41 +0800
Subject: [PATCH] no message

---
 RWKV-v4/run.py           |   3 +-
 RWKV-v4/src/model_run.py | 157 ++++++++++++++++++++-------------------
 RWKV-v4/verify.py        |   6 +-
 3 files changed, 84 insertions(+), 82 deletions(-)

diff --git a/RWKV-v4/run.py b/RWKV-v4/run.py
index cd557c7..b7e4792 100644
--- a/RWKV-v4/run.py
+++ b/RWKV-v4/run.py
@@ -18,6 +18,8 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
 ### Step 1: set model ##################################################################################
 
 os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' or 'fp16'
+os.environ['RWKV_RUN_DEVICE'] = 'cpu' # 'cpu' (already very fast) or 'cuda'
+RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
 
 ctx_len = 1024
 n_layer = 6
@@ -45,7 +47,6 @@ else:
 
 ### Step 3: other config ###############################################################################
 
-RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
 DEBUG_DEBUG = False # True False - show softmax output
 
 NUM_TRIALS = 999
diff --git a/RWKV-v4/src/model_run.py b/RWKV-v4/src/model_run.py
index 83b6b5d..1e68a79 100644
--- a/RWKV-v4/src/model_run.py
+++ b/RWKV-v4/src/model_run.py
@@ -18,84 +18,85 @@ DEBUG_TIME = False # True False - show trained time-coeffs
 # CUDA Kernel
 ########################################################################################################
 
-T_MAX = 4096 # increase this if your ctx_len is long
-# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
-
-from torch.utils.cpp_extension import load
-wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
-                verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
-
-if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
-    class WKV(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, B, T, C, w, u, k, v):
-            ctx.B = B
-            ctx.T = T
-            ctx.C = C
-            assert T <= T_MAX
-            assert B * C % min(C, 1024) == 0
-            w = -torch.exp(w.float().contiguous())
-            u = u.float().contiguous()
-            k = k.float().contiguous()
-            v = v.float().contiguous()
-            ctx.save_for_backward(w, u, k, v)
-            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
-            wkv_cuda.forward(B, T, C, w, u, k, v, y)
-            return y.half()
-
-        @staticmethod
-        def backward(ctx, gy):
-            B = ctx.B
-            T = ctx.T
-            C = ctx.C
-            assert T <= T_MAX
-            assert B * C % min(C, 1024) == 0
-            w, u, k, v = ctx.saved_tensors
-            gw = torch.zeros((B, C), device='cuda')
-            gu = torch.zeros((B, C), device='cuda')
-            gk = torch.zeros((B, T, C), device='cuda')
-            gv = torch.zeros((B, T, C), device='cuda')
-            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
-            gw = torch.sum(gw, dim=0)
-            gu = torch.sum(gu, dim=0)
-            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
-elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
-    class WKV(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, B, T, C, w, u, k, v):
-            ctx.B = B
-            ctx.T = T
-            ctx.C = C
-            assert T <= T_MAX
-            assert B * C % min(C, 1024) == 0
-            w = -torch.exp(w.float().contiguous())
-            u = u.float().contiguous()
-            k = k.float().contiguous()
-            v = v.float().contiguous()
-            ctx.save_for_backward(w, u, k, v)
-            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
-            wkv_cuda.forward(B, T, C, w, u, k, v, y)
-            return y.bfloat16()
-
-        @staticmethod
-        def backward(ctx, gy):
-            B = ctx.B
-            T = ctx.T
-            C = ctx.C
-            assert T <= T_MAX
-            assert B * C % min(C, 1024) == 0
-            w, u, k, v = ctx.saved_tensors
-            gw = torch.zeros((B, C), device='cuda')
-            gu = torch.zeros((B, C), device='cuda')
-            gk = torch.zeros((B, T, C), device='cuda')
-            gv = torch.zeros((B, T, C), device='cuda')
-            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
-            gw = torch.sum(gw, dim=0)
-            gu = torch.sum(gu, dim=0)
-            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
-
-def RUN_CUDA(B, T, C, w, u, k, v):
-    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
+if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
+    T_MAX = 4096 # increase this if your ctx_len is long
+    # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
+
+    from torch.utils.cpp_extension import load
+    wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
+                    verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
+
+    if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+        class WKV(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, B, T, C, w, u, k, v):
+                ctx.B = B
+                ctx.T = T
+                ctx.C = C
+                assert T <= T_MAX
+                assert B * C % min(C, 1024) == 0
+                w = -torch.exp(w.float().contiguous())
+                u = u.float().contiguous()
+                k = k.float().contiguous()
+                v = v.float().contiguous()
+                ctx.save_for_backward(w, u, k, v)
+                y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
+                wkv_cuda.forward(B, T, C, w, u, k, v, y)
+                return y.half()
+
+            @staticmethod
+            def backward(ctx, gy):
+                B = ctx.B
+                T = ctx.T
+                C = ctx.C
+                assert T <= T_MAX
+                assert B * C % min(C, 1024) == 0
+                w, u, k, v = ctx.saved_tensors
+                gw = torch.zeros((B, C), device='cuda')
+                gu = torch.zeros((B, C), device='cuda')
+                gk = torch.zeros((B, T, C), device='cuda')
+                gv = torch.zeros((B, T, C), device='cuda')
+                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
+                gw = torch.sum(gw, dim=0)
+                gu = torch.sum(gu, dim=0)
+                return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
+    elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+        class WKV(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, B, T, C, w, u, k, v):
+                ctx.B = B
+                ctx.T = T
+                ctx.C = C
+                assert T <= T_MAX
+                assert B * C % min(C, 1024) == 0
+                w = -torch.exp(w.float().contiguous())
+                u = u.float().contiguous()
+                k = k.float().contiguous()
+                v = v.float().contiguous()
+                ctx.save_for_backward(w, u, k, v)
+                y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
+                wkv_cuda.forward(B, T, C, w, u, k, v, y)
+                return y.bfloat16()
+
+            @staticmethod
+            def backward(ctx, gy):
+                B = ctx.B
+                T = ctx.T
+                C = ctx.C
+                assert T <= T_MAX
+                assert B * C % min(C, 1024) == 0
+                w, u, k, v = ctx.saved_tensors
+                gw = torch.zeros((B, C), device='cuda')
+                gu = torch.zeros((B, C), device='cuda')
+                gk = torch.zeros((B, T, C), device='cuda')
+                gv = torch.zeros((B, T, C), device='cuda')
+                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
+                gw = torch.sum(gw, dim=0)
+                gu = torch.sum(gu, dim=0)
+                return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
+
+    def RUN_CUDA(B, T, C, w, u, k, v):
+        return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
 
 ############################################################################################################
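Note: the hunk above only gates the existing CUDA kernel behind RWKV_RUN_DEVICE; it adds no CPU kernel, since the CPU path uses the recurrent (RNN-mode) formulation elsewhere in model_run.py. For reference, the following is a minimal pure-PyTorch sketch of the WKV computation that wkv_cuda.forward performs, using that numerically stabilised recurrence. The name wkv_reference and the argument layout are illustrative only and not part of this patch; w is assumed to already be the negated decay -exp(time_decay), exactly as WKV.forward passes it to the kernel.

import torch

def wkv_reference(B, T, C, w, u, k, v):
    # Illustrative CPU reference for the WKV op (not part of this patch).
    # w: (C,) already equals -exp(time_decay); u: (C,) is the time_first bonus.
    # k, v: (B, T, C). A running maximum `pp` of the log-weights keeps the
    # exp() calls in range, mirroring the RNN-mode recurrence.
    y = torch.empty((B, T, C), dtype=torch.float32)
    aa = torch.zeros((B, C))          # weighted sum of values
    bb = torch.zeros((B, C))          # weighted sum of ones (normaliser)
    pp = torch.full((B, C), -1e38)    # running max of the log-weights
    for t in range(T):
        kk = k[:, t].float()
        vv = v[:, t].float()
        ww = u + kk                   # current token gets the `u` bonus
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        y[:, t] = (e1 * aa + e2 * vv) / (e1 * bb + e2)
        ww = pp + w                   # decay the accumulated state
        p = torch.maximum(ww, kk)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kk - p)
        aa = e1 * aa + e2 * vv
        bb = e1 * bb + e2
        pp = p
    return y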
diff --git a/RWKV-v4/verify.py b/RWKV-v4/verify.py
index 5198c52..513ef37 100644
--- a/RWKV-v4/verify.py
+++ b/RWKV-v4/verify.py
@@ -9,14 +9,14 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
 
 import os
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-RUN_DEVICE = 'cuda'
+os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
+os.environ['RWKV_RUN_DEVICE'] = 'cuda'
+RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
 
 import torch
 from src.model_run import RWKV_RNN, RWKV_GPT
 from src.model import GPT, GPTConfig
 
-os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
-
 ctx_len = 1024
 n_layer = 6
 n_embd = 512
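Usage note: after this patch, src/model_run.py decides at import time whether to compile the WKV CUDA extension, so RWKV_RUN_DEVICE and RWKV_FLOAT_MODE must be set before src.model_run is imported. That is why verify.py now sets the env vars above the imports, and why run.py exports RWKV_RUN_DEVICE next to RWKV_FLOAT_MODE. A minimal sketch of the expected ordering (values are just examples):

import os

# Both variables are read when src.model_run is imported, so set them first.
os.environ['RWKV_FLOAT_MODE'] = 'bf16'  # 'bf16' or 'fp16'
os.environ['RWKV_RUN_DEVICE'] = 'cpu'   # 'cpu' skips the CUDA kernel build entirely

from src.model_run import RWKV_RNN, RWKV_GPT  # safe on a CUDA-less machine when device is 'cpu'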