diff --git a/RWKV-v4neo/cuda/wkv_cuda_bf16.cu b/RWKV-v4neo/cuda/wkv_cuda_bf16.cu
new file mode 100644
index 0000000..5b4e4e8
--- /dev/null
+++ b/RWKV-v4neo/cuda/wkv_cuda_bf16.cu
@@ -0,0 +1,132 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+#define MIN_VALUE (-1e38)
+typedef at::BFloat16 bf16;
+
+__global__ void kernel_forward(const int B, const int T, const int C,
+                               const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
+                               bf16 *__restrict__ const _y) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    float u = float(_u[_c]);
+    float w = _w[_c];
+    const bf16 *__restrict__ const k = _k + _offset;
+    const bf16 *__restrict__ const v = _v + _offset;
+    bf16 *__restrict__ const y = _y + _offset;
+
+    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+    float aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const float kk = float(k[ii]);
+        const float vv = float(v[ii]);
+
+        float ww = u + kk;
+        float p = max(pp, ww);
+        float e1 = exp(pp - p);
+        float e2 = exp(ww - p);
+        y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+}
+
+__global__ void kernel_backward(const int B, const int T, const int C,
+                                const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
+                                const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
+                                bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    float u = float(_u[_c]);
+    float w = _w[_c];
+    const bf16 *__restrict__ const k = _k + _offset;
+    const bf16 *__restrict__ const v = _v + _offset;
+    const bf16 *__restrict__ const y = _y + _offset;
+    const bf16 *__restrict__ const gy = _gy + _offset;
+    bf16 *__restrict__ const gk = _gk + _offset;
+    bf16 *__restrict__ const gv = _gv + _offset;
+
+    float q[Tmax], r[Tmax];
+
+    float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const float kk = float(k[ii]);
+        const float vv = float(v[ii]);
+        const float yy = float(y[ii]);
+
+        float ww = u + kk;
+        float p = max(pp, ww);
+        float e1 = exp(pp - p);
+        float e2 = exp(ww - p);
+        const float qq = float(gy[ii]) / (e1 * bb + e2);
+        gw += (ga - gb * yy) * e1 * qq;
+        gu += (vv - yy) * e2 * qq;
+        q[i] = qq;
+        r[i] = ww - p;
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        ga = e1 * (aa + ga);
+        gb = e1 * (bb + gb);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+    const int _offsetBC = _b * C + _c;
+    _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
+    _gu[_offsetBC] = bf16(gu);
+
+    aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int i = T - 1; i >= 0; i--) {
+        const int ii = i * C;
+        const float kk = float(k[ii]);
+        const float vv = float(v[ii]);
+        const float yy = float(y[ii]);
+        const float qq = q[i];
+        const float rr = r[i];
+
+        float e1 = qq * exp(rr);
+        float e2 = exp(kk + pp);
+        gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
+        gv[ii] = bf16(e1 + e2 * aa);
+
+        const float ww = w + pp;
+        const float www = rr - u - kk;
+        const float p = max(ww, www);
+        e1 = exp(ww - p);
+        e2 = qq * exp(www - p);
+        aa = e1 * aa + e2;
+        bb = e1 * bb - e2 * yy;
+        pp = p;
+    }
+}
+
+void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+}
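[Review note, not part of the patch] The forward kernel keeps aa/bb as the running numerator and denominator of the WKV ratio, rescaled by exp(pp) where pp tracks the largest exponent seen so far, so no exp() can overflow. A pure-PyTorch float32 transliteration of the same loop can serve as a reference when testing the CUDA output; wkv_forward_reference is a hypothetical helper name, and w here is the already-negated decay (-exp(w_param)), exactly as WKV.forward passes it to the kernel:

    import torch

    def wkv_forward_reference(w, u, k, v):
        # w: (C,) float, already -exp()'d; u: (C,); k, v: (B, T, C)
        B, T, C = k.shape
        w, u, k, v = w.float(), u.float(), k.float(), v.float()
        aa = torch.zeros(B, C)          # running numerator, divided by exp(pp)
        bb = torch.zeros(B, C)          # running denominator, divided by exp(pp)
        pp = torch.full((B, C), -1e38)  # running max exponent (MIN_VALUE)
        ys = []
        for t in range(T):
            ww = u + k[:, t]
            p = torch.maximum(pp, ww)
            e1, e2 = torch.exp(pp - p), torch.exp(ww - p)
            ys.append((e1 * aa + e2 * v[:, t]) / (e1 * bb + e2))
            ww = w + pp
            p = torch.maximum(ww, k[:, t])
            e1, e2 = torch.exp(ww - p), torch.exp(k[:, t] - p)
            aa = e1 * aa + e2 * v[:, t]
            bb = e1 * bb + e2
            pp = p
        return torch.stack(ys, dim=1)   # (B, T, C)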
diff --git a/RWKV-v4neo/cuda/wkv_op_bf16.cpp b/RWKV-v4neo/cuda/wkv_op_bf16.cpp
new file mode 100644
index 0000000..5783416
--- /dev/null
+++ b/RWKV-v4neo/cuda/wkv_op_bf16.cpp
@@ -0,0 +1,25 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+
+void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
+void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
+
+void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
+              torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
+                  gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &forward, "wkv forward");
+    m.def("backward", &backward, "wkv backward");
+}
+
+TORCH_LIBRARY(wkv, m) {
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
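[Review note, not part of the patch] The extension can be exercised in isolation before wiring it into the model; this sketch mirrors how src/model.py loads it below, with illustrative T_MAX and shapes (paths relative to RWKV-v4neo/, a CUDA device assumed):

    import torch
    from torch.utils.cpp_extension import load

    T_MAX = 256  # compile-time Tmax; every call must keep T <= T_MAX
    wkv = load(name=f"wkv_{T_MAX}_bf16",
               sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"],
               extra_cuda_cflags=["-O3", "--use_fast_math", f"-DTmax={T_MAX}"])

    B, T, C = 2, 32, 64
    w = -torch.exp(torch.randn(C, device="cuda"))  # float32, pre-negated decay
    u = torch.randn(C, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, T, C, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, T, C, device="cuda", dtype=torch.bfloat16)
    y = torch.empty_like(v)
    wkv.forward(B, T, C, w, u, k, v, y)  # kernel writes the result into y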
diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 71ddc34..6637992 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -41,66 +41,93 @@ T_MAX = int(os.environ["RWKV_T_MAX"])  # TAKES LOTS OF VRAM!
 
 from torch.utils.cpp_extension import load
 
-wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
-
-
-class WKV(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, B, T, C, w, u, k, v):
-        ctx.B = B
-        ctx.T = T
-        ctx.C = C
-        assert T <= T_MAX
-        assert B * C % min(C, 32) == 0
-        if "32" in os.environ["RWKV_FLOAT_MODE"]:
-            w = -torch.exp(w.contiguous())
+if os.environ["RWKV_FLOAT_MODE"] == "bf16":
+    wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-lineinfo", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
+    class WKV(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, B, T, C, w, u, k, v):
+            ctx.B = B
+            ctx.T = T
+            ctx.C = C
+            assert T <= T_MAX
+            assert B * C % min(C, 32) == 0
+            w = -torch.exp(w.float().contiguous())
             u = u.contiguous()
             k = k.contiguous()
             v = v.contiguous()
-        else:
-            w = -torch.exp(w.float().contiguous())
-            u = u.float().contiguous()
-            k = k.float().contiguous()
-            v = v.float().contiguous()
-        y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
-        wkv_cuda.forward(B, T, C, w, u, k, v, y)
-        ctx.save_for_backward(w, u, k, v, y)
-        if "32" in os.environ["RWKV_FLOAT_MODE"]:
+            y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
+            wkv_cuda.forward(B, T, C, w, u, k, v, y)
+            ctx.save_for_backward(w, u, k, v, y)
             return y
-        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
-            return y.half()
-        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
-            return y.bfloat16()
-
-    @staticmethod
-    def backward(ctx, gy):
-        B = ctx.B
-        T = ctx.T
-        C = ctx.C
-        assert T <= T_MAX
-        assert B * C % min(C, 32) == 0
-        w, u, k, v, y = ctx.saved_tensors
-        gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
-        gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
-        gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
-        gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
-        if "32" in os.environ["RWKV_FLOAT_MODE"]:
+        @staticmethod
+        def backward(ctx, gy):
+            B = ctx.B
+            T = ctx.T
+            C = ctx.C
+            assert T <= T_MAX
+            assert B * C % min(C, 32) == 0
+            w, u, k, v, y = ctx.saved_tensors
+            gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
+            gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
+            gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
+            gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
             wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
-        else:
-            wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
-        del w
-        del u
-        del k
-        del v
-        del y
-        gw = torch.sum(gw, dim=0)
-        gu = torch.sum(gu, dim=0)
-        if "32" in os.environ["RWKV_FLOAT_MODE"]:
+            gw = torch.sum(gw, dim=0)
+            gu = torch.sum(gu, dim=0)
             return (None, None, None, gw, gu, gk, gv)
-        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
-            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
-        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
-            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
+else:
+    wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
+    class WKV(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, B, T, C, w, u, k, v):
+            ctx.B = B
+            ctx.T = T
+            ctx.C = C
+            assert T <= T_MAX
+            assert B * C % min(C, 32) == 0
+            if "32" in os.environ["RWKV_FLOAT_MODE"]:
+                w = -torch.exp(w.contiguous())
+                u = u.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
+            else:
+                w = -torch.exp(w.float().contiguous())
+                u = u.float().contiguous()
+                k = k.float().contiguous()
+                v = v.float().contiguous()
+            y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
+            wkv_cuda.forward(B, T, C, w, u, k, v, y)
+            ctx.save_for_backward(w, u, k, v, y)
+            if "32" in os.environ["RWKV_FLOAT_MODE"]:
+                return y
+            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+                return y.half()
+            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+                return y.bfloat16()
+        @staticmethod
+        def backward(ctx, gy):
+            B = ctx.B
+            T = ctx.T
+            C = ctx.C
+            assert T <= T_MAX
+            assert B * C % min(C, 32) == 0
+            w, u, k, v, y = ctx.saved_tensors
+            gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
+            gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
+            gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
+            gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
+            if "32" in os.environ["RWKV_FLOAT_MODE"]:
+                wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
+            else:
+                wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
+            gw = torch.sum(gw, dim=0)
+            gu = torch.sum(gu, dim=0)
+            if "32" in os.environ["RWKV_FLOAT_MODE"]:
+                return (None, None, None, gw, gu, gk, gv)
+            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+                return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
+            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+                return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
 
 
 def RUN_CUDA(B, T, C, w, u, k, v):
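[Review note, not part of the patch] End to end, the bf16 branch can be smoke-tested through RUN_CUDA, assuming the environment is set before src.model is imported and that RUN_CUDA forwards to WKV.apply as in the existing file; all values below are illustrative:

    import os
    os.environ["RWKV_T_MAX"] = "1024"
    os.environ["RWKV_FLOAT_MODE"] = "bf16"

    import torch
    from src.model import RUN_CUDA

    B, T, C = 2, 64, 512  # B * C must be a multiple of min(C, 32)
    w = torch.randn(C, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    u = torch.randn(C, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn(B, T, C, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    v = torch.randn(B, T, C, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    y = RUN_CUDA(B, T, C, w, u, k, v)  # bf16 output, shape (B, T, C)
    y.sum().backward()                 # exercises kernel_backward
    assert k.grad.dtype == torch.bfloat16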