better cuda kernel

main
BlinkDL 3 years ago
parent 760db55fa6
commit 93d671c287

@@ -18,28 +18,33 @@ __global__ void kernel_forward(const int B, const int T, const int C,
     const F *__restrict__ const v = _v + _offset;
     F *__restrict__ const y = _y + _offset;
 
-    F p = 0, q = 0, o = MIN_VALUE;
-    // p and q are running sums divided by exp(o) (to avoid overflows)
+    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+    F aa = 0, bb = 0, pp = MIN_VALUE;
     for (int i = 0; i < T; i++) {
         const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
 
-        F no = max(o, u + k[ii]);
-        F A = exp(o - no);
-        F B = exp(u + k[ii] - no);
-        y[ii] = (A * p + B * v[ii]) / (A * q + B);
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
 
-        no = max(w + o, k[ii]);
-        A = exp(w + o - no);
-        B = exp(k[ii] - no);
-        p = A * p + B * v[ii];
-        q = A * q + B;
-        o = no;
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
     }
 }
 
 template <typename F>
 __global__ void kernel_backward(const int B, const int T, const int C,
-                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
+                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                                const F *__restrict__ const _y, const F *__restrict__ const _gy,
                                 F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
     const int _b = idx / C;
@@ -50,64 +55,67 @@ __global__ void kernel_backward(const int B, const int T, const int C,
     F w = _w[_c];
     const F *__restrict__ const k = _k + _offset;
     const F *__restrict__ const v = _v + _offset;
+    const F *__restrict__ const y = _y + _offset;
     const F *__restrict__ const gy = _gy + _offset;
     F *__restrict__ const gk = _gk + _offset;
     F *__restrict__ const gv = _gv + _offset;
 
-    F y[Tmax], z[Tmax], zexp[Tmax];
-    F gw = 0, gu = 0;
-    F p = 0, q = 0;
-    F dpdw = 0, dqdw = 0;
-    F o = MIN_VALUE;
+    F q[Tmax], r[Tmax];
+
+    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
     for (int i = 0; i < T; i++) {
         const int ii = i * C;
-        F no = max(o, k[ii] + u);
-        F A = exp(o - no);
-        F B = exp(k[ii] + u - no);
-        F num = A * p + B * v[ii];
-        F iden = 1 / (A * q + B);
-
-        y[i] = num * iden;
-        z[i] = iden;
-        zexp[i] = k[ii] + u - no;
-
-        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
-        gu += gy[ii] * (v[ii] - y[i]) * B * iden;
-
-        no = max(w + o, k[ii]);
-        A = exp(w + o - no);
-        B = exp(k[ii] - no);
-        dpdw = A * (p + dpdw);
-        dqdw = A * (q + dqdw);
-        p = A * p + B * v[ii];
-        q = A * q + B;
-        o = no;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        const F qq = gy[ii] / (e1 * bb + e2);
+        gw += (ga - gb * yy) * e1 * qq;
+        gu += (vv - yy) * e2 * qq;
+        q[i] = qq;
+        r[i] = ww - p;
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        ga = e1 * (aa + ga);
+        gb = e1 * (bb + gb);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
     }
+    const int _offsetBC = _b * C + _c;
+    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
+    _gu[_offsetBC] = gu;
 
-    F gp = 0, gq = 0;
-    o = MIN_VALUE;
+    aa = 0, bb = 0, pp = MIN_VALUE;
     for (int i = T - 1; i >= 0; i--) {
         const int ii = i * C;
-        F A = gy[ii] * z[i] * exp(zexp[i]);
-        F B = exp(k[ii] + o);
-        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
-        gv[ii] = A + B * gp;
-
-        F no = max(w + o, zexp[i] - k[ii] - u);
-        A = exp(w + o - no);
-        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
-        gp = A * gp + B;
-        gq = A * gq - B * y[i];
-        o = no;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+        const F qq = q[i];
+        const F rr = r[i];
+
+        F e1 = qq * exp(rr);
+        F e2 = exp(kk + pp);
+        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
+        gv[ii] = e1 + e2 * aa;
+
+        const F ww = w + pp;
+        const F www = rr - u - kk;
+        const F p = max(ww, www);
+        e1 = exp(ww - p);
+        e2 = qq * exp(www - p);
+        aa = e1 * aa + e2;
+        bb = e1 * bb - e2 * yy;
+        pp = p;
     }
-
-    // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
-    const int _offsetBC = _b * C + _c;
-    _gw[_offsetBC] += gw * _w[_c];
-    _gu[_offsetBC] += gu;
 }
 
 void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
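The `_gw[_offsetBC] = gw * _w[_c]` assignment in the kernel above folds one chain-rule step into the backward pass: the Python side stores the decay as `w -> -exp(w)` (per the kernel's own comment), and since d(-exp(w))/dw = -exp(w), the gradient with respect to the raw parameter is the kernel's accumulated gradient times the transformed value itself. A minimal, standalone PyTorch check of just that step (a sketch, not repo code; the names `w_raw`, `w_eff`, `g_eff` are illustrative):

import torch

# d(-exp(w_raw))/d(w_raw) = -exp(w_raw) = w_eff, so the gradient wrt the raw
# parameter equals the gradient wrt the transformed decay multiplied by the
# transformed decay -- the same "* _w[_c]" factor applied to gw in kernel_backward.
w_raw = torch.randn(4, requires_grad=True)
w_eff = -torch.exp(w_raw)                       # the "w -> -exp(w)" preprocessing the comment refers to
g_eff = torch.tensor([1.0, 2.0, 3.0, 4.0])      # stand-in for the gradient reaching w_eff
w_eff.backward(g_eff)
print(torch.allclose(w_raw.grad, g_eff * w_eff.detach()))  # True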
@@ -117,9 +125,9 @@ void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
     kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
 }
 
-void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
     dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
     assert(B * C % threadsPerBlock.x == 0);
     dim3 numBlocks(B * C / threadsPerBlock.x);
-    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
+    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
 }
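The new forward kernel keeps its running weighted sum and normalizer divided by exp(pp), where pp tracks the running maximum exponent, so no single exp can overflow. A loop-level PyTorch transcription of the same recurrence may make the update easier to follow; this is a sketch rather than code from the repo (the name `wkv_reference`, the (B, T, C) shapes, and `-1e38` standing in for `MIN_VALUE` are assumptions):

import torch

def wkv_reference(w, u, k, v):
    # Per-channel WKV recurrence, mirroring the rewritten kernel_forward.
    # w, u: (C,), with w already in its "-exp(w)" (negative decay) form;
    # k, v: (B, T, C). aa/bb are the running numerator/denominator divided
    # by exp(pp), where pp is the running maximum exponent.
    B, T, C = k.shape
    y = torch.empty_like(v)
    aa = torch.zeros(B, C, dtype=k.dtype, device=k.device)
    bb = torch.zeros(B, C, dtype=k.dtype, device=k.device)
    pp = torch.full((B, C), -1e38, dtype=k.dtype, device=k.device)
    for t in range(T):
        kk, vv = k[:, t], v[:, t]
        ww = u + kk                            # current token gets the extra "u" bonus
        p = torch.maximum(pp, ww)
        e1, e2 = torch.exp(pp - p), torch.exp(ww - p)
        y[:, t] = (e1 * aa + e2 * vv) / (e1 * bb + e2)
        ww = w + pp                            # decay the stored state by exp(w)
        p = torch.maximum(ww, kk)
        e1, e2 = torch.exp(ww - p), torch.exp(kk - p)
        aa = e1 * aa + e2 * vv
        bb = e1 * bb + e2
        pp = p
    return y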

@@ -1,13 +1,13 @@
 #include <torch/extension.h>
 
 void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
-void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
 
 void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
     cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
 }
-void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
-    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

@@ -41,7 +41,7 @@ T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
 from torch.utils.cpp_extension import load
-wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"])
+wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
 
 class WKV(torch.autograd.Function):
@@ -62,9 +62,9 @@ class WKV(torch.autograd.Function):
             u = u.float().contiguous()
             k = k.float().contiguous()
             v = v.float().contiguous()
-        ctx.save_for_backward(w, u, k, v)
         y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
         wkv_cuda.forward(B, T, C, w, u, k, v, y)
+        ctx.save_for_backward(w, u, k, v, y)
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
             return y
         elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
@@ -79,15 +79,20 @@ class WKV(torch.autograd.Function):
         C = ctx.C
         assert T <= T_MAX
         assert B * C % min(C, 32) == 0
-        w, u, k, v = ctx.saved_tensors
-        gw = torch.zeros((B, C), device=gy.device).contiguous()
-        gu = torch.zeros((B, C), device=gy.device).contiguous()
-        gk = torch.zeros((B, T, C), device=gy.device).contiguous()
-        gv = torch.zeros((B, T, C), device=gy.device).contiguous()
+        w, u, k, v, y = ctx.saved_tensors
+        gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
+        gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
+        gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
+        gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
-            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
+            wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
         else:
-            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
+            wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
+        del w
+        del u
+        del k
+        del v
+        del y
         gw = torch.sum(gw, dim=0)
         gu = torch.sum(gu, dim=0)
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
