You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
RWKV-LM/RWKV-v3/cuda/timex_cuda.cu

173 lines
5.5 KiB
Plaintext

#include <stdio.h>
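// Requires T <= Tmax, T % 4 == 0, B % BF == 0 and B % BB == 0 (Tmax, BF and BB
// are supplied as compile-time macros). T / 4 is also used as the block size, so
// it must not exceed the device's threads-per-block limit.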
// F4(A, B): access A[B..B+3] as a single float4 lvalue.
// B is an element index (not a float4 index); its low 2 bits are discarded by
// `>> 2`, so callers must pass a multiple of 4. A must hold float data, and
// A + B must be 16-byte aligned for the vector access.
// Note: the macro parameter B is unrelated to the batch-size argument B used
// by the kernels below.
#define F4(A, B) ((float4 *)(A))[(B) >> 2]
// Forward kernel: causal 1-D convolution of each row of k with a time-reversed
// weight row of w. For row r, using weight row c = r % C (valid given B % BF == 0):
//   x[r, t'] = eps + sum_{u=0}^{t'} w[c, T-1-(t'-u)] * k[r, u]
// Launch layout (as used by cuda_forward): grid (1, B*C/BF), block of T/4
// threads. Block i handles the BF rows i, i+ij, ..., i+(BF-1)*ij with
// ij = B*C/BF; thread n produces outputs t' = 4n..4n+3 of each of those rows.
// All BF rows are loaded with weight row i % C, which matches (i + ij*j) % C
// only when ij is a multiple of C, i.e. when B % BF == 0.
// NOTE(review): rows appear to be (b, c) pairs with c varying fastest --
// confirm against the caller's tensor layout.
// F must be float: accumulators are float4 and F4 reinterprets F* as float4*.
// Static shared memory: Tmax + Tmax*BF elements of F.
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
const F eps, const int B, const int C, const int T) {
// i: row group of this block; t: first of this thread's 4 time steps.
const int i = blockIdx.y;
const int ij = (B * C) / BF;
const int t = threadIdx.x << 2;
__shared__ F ww[Tmax];
__shared__ F kk[Tmax * BF];
// Stage weight row (i % C) and the BF rows of k in shared memory, 4 elements
// per thread; only the first T entries of each Tmax-sized slot are written.
F4(ww, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BF; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
}
__syncthreads();
// One accumulator per row: lanes .x/.y/.z/.w are outputs t, t+1, t+2, t+3,
// each seeded with eps.
float4 s[BF];
#pragma unroll
for (int j = 0; j < BF; j++) {
s[j] = {eps, eps, eps, eps};
}
// Shifted view of ww: w[u + 3 - d] == ww[T-1-(t+d-u)] is the weight pairing
// k[u] with output t + d (d = 0..3).
const F *__restrict__ const w = ww + T - t - 4;
// Inputs u = 0..t are causal for all four outputs of this thread.
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BF; j++) {
// NOTE: this local `x` shadows the output pointer parameter `x`; the
// pointer is used again in the store loop below.
const F x = kk[u + Tmax * j];
s[j].x += w[u + 3] * x;
s[j].y += w[u + 2] * x;
s[j].z += w[u + 1] * x;
s[j].w += w[u + 0] * x;
}
}
// Remaining inputs u = t+1..t+3 only reach outputs with index >= u; then
// store the 4 results of each row to global memory.
#pragma unroll
for (int j = 0; j < BF; j++) {
const F *__restrict__ const k = kk + Tmax * j;
s[j].y += w[t + 3] * k[t + 1];
s[j].z += w[t + 2] * k[t + 1];
s[j].z += w[t + 3] * k[t + 2];
s[j].w += w[t + 1] * k[t + 1];
s[j].w += w[t + 2] * k[t + 2];
s[j].w += w[t + 3] * k[t + 3];
F4(x, t + T * (i + ij * j)) = s[j];
}
}
// Weight-gradient-only backward kernel, one row per block.
// For row r = blockIdx.y it computes
//   gw[r, t] = sum_{u=0}^{t} gwk[r, T-1-t+u] * k[r, u]
// i.e. the gw half of kernel_backward with a single row per block.
// Expected launch: grid (1, number of rows), block of T/4 threads; thread n
// produces time steps 4n..4n+3. The parameters __w, gk, B and C are accepted
// but never read or written. F must be float (float4 accumulator, F4 casts).
template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
                                  F *__restrict__ const gw, F *__restrict__ const gk,
                                  const int B, const int C, const int T) {
    const int row = blockIdx.y;
    const int t0 = 4 * threadIdx.x;  // first of this thread's 4 time steps
    const int rowBase = T * row;     // offset of this row in global memory

    // Stage this row of gwk and k in shared memory (4 elements per thread).
    __shared__ F gRow[Tmax];
    __shared__ F kRow[Tmax];
    F4(gRow, t0) = F4(__gwk, t0 + rowBase);
    F4(kRow, t0) = F4(__k, t0 + rowBase);
    __syncthreads();

    // gRev[u + 3 - d] pairs k[u] with output t0 + d (d = 0..3).
    const F *__restrict__ const gRev = gRow + (T - t0 - 4);
    float4 acc = {0, 0, 0, 0};

    // Inputs u = 0..t0 contribute to all four outputs.
    for (int u = 0; u <= t0; ++u) {
        const F kv = kRow[u];
        acc.x += gRev[u + 3] * kv;
        acc.y += gRev[u + 2] * kv;
        acc.z += gRev[u + 1] * kv;
        acc.w += gRev[u + 0] * kv;
    }

    // Inputs u = t0+1..t0+3 only reach the later lanes; summation order is
    // kept identical so results are bit-for-bit unchanged.
    const F k1 = kRow[t0 + 1];
    const F k2 = kRow[t0 + 2];
    const F k3 = kRow[t0 + 3];
    acc.y += gRev[t0 + 3] * k1;
    acc.z += gRev[t0 + 2] * k1;
    acc.z += gRev[t0 + 3] * k2;
    acc.w += gRev[t0 + 1] * k1;
    acc.w += gRev[t0 + 2] * k2;
    acc.w += gRev[t0 + 3] * k3;

    F4(gw, t0 + rowBase) = acc;
}
// Host launcher for kernel_forward (float data).
// Launch layout: grid (1, B*C/BF), one block per group of BF rows; block of
// T/4 threads, each producing 4 consecutive time steps.
// Preconditions (not checked here): T % 4 == 0, T <= Tmax, B % BF == 0,
// T/4 within the device's threads-per-block limit, and B*C/BF <= 65535
// (the gridDim.y limit).
// The kernel runs on the default stream; launch errors are not checked here.
void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
    // Empty tensor: nothing to compute. Launching anyway would use a zero grid
    // or block dimension and leave a pending cudaErrorInvalidConfiguration
    // that a later, unrelated launch check would report.
    if (B == 0 || C == 0 || T == 0) {
        return;
    }
    // Named grid/block (not gridDim/blockDim) to avoid shadowing CUDA built-ins.
    const dim3 grid(1, B * C / BF);
    const dim3 block(T >> 2);
    kernel_forward<<<grid, block>>>(w, k, x, eps, B, C, T);
}
// Backward kernel for kernel_forward. Taking gwk as the upstream gradient of
// the forward output (NOTE(review): inferred from the name -- confirm with the
// caller), it computes for each row r with weight row c = r % C (valid given
// B % BB == 0):
//   gw[r, t] = sum_{u=0}^{t}   gwk[r, T-1-t+u] * k[r, u]
//   gk[r, t] = sum_{m=t}^{T-1} gwk[r, T-1-m+t] * w[c, m]
// gw is written per row (B*C rows of length T) and is not reduced over rows
// sharing a channel here; any such reduction must happen in the caller.
// Launch layout (as used by cuda_backward): grid (1, B*C/BB), block of T/4
// threads. Block i handles rows i, i+ij, ..., i+(BB-1)*ij with ij = B*C/BB;
// thread n owns time steps 4n..4n+3 of each of those rows.
// F must be float (float4 accumulators, F4 casts).
// Static shared memory: Tmax + 2*Tmax*BB elements of F.
template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
F *__restrict__ const gw, F *__restrict__ const gk,
const int B, const int C, const int T) {
// i: row group of this block; t: first of this thread's 4 time steps.
const int i = blockIdx.y;
const int ij = (B * C) / BB;
const int t = threadIdx.x << 2;
__shared__ F w[Tmax];
__shared__ F kk[Tmax * BB];
__shared__ F gg[Tmax * BB];
// Stage weight row (i % C), plus the BB rows of k and gwk, in shared memory
// (4 elements per thread).
F4(w, t) = F4(__w, t + T * (i % C));
#pragma unroll
for (int j = 0; j < BB; j++) {
F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
}
__syncthreads();
// ---- gw: same scheme as the forward pass, with gwk in place of w. ----
// With g = gg + Tmax*j + T - t - 4, g[u + 3 - d] pairs k[u] with gw output
// t + d (d = 0..3); inputs u = 0..t contribute to all four outputs.
float4 s[BB];
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
for (int u = 0; u <= t; u++) {
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
F x = kk[u + Tmax * j];
s[j].x += g[u + 3] * x;
s[j].y += g[u + 2] * x;
s[j].z += g[u + 1] * x;
s[j].w += g[u + 0] * x;
}
}
// Remaining inputs u = t+1..t+3 only reach gw outputs with index >= u;
// then store gw.
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const k = kk + Tmax * j;
const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
s[j].y += g[t + 3] * k[t + 1];
s[j].z += g[t + 2] * k[t + 1];
s[j].z += g[t + 3] * k[t + 2];
s[j].w += g[t + 1] * k[t + 1];
s[j].w += g[t + 2] * k[t + 2];
s[j].w += g[t + 3] * k[t + 3];
F4(gw, t + T * (i + ij * j)) = s[j];
}
// ---- gk: reset the accumulators and correlate gwk with w. ----
#pragma unroll
for (int j = 0; j < BB; j++) {
s[j] = {0, 0, 0, 0};
}
// With g = gg + Tmax*j + T + t - 3, g[d + 2 - m] == gwk[r, T-1-m+t+d] pairs
// w[m] with gk output t + d; weights m = t+3..T-1 contribute to all four.
for (int u = t + 3; u < T; u++) {
F x = w[u];
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - u] * x;
s[j].y += g[3 - u] * x;
s[j].z += g[4 - u] * x;
s[j].w += g[5 - u] * x;
}
}
// Remaining weights m = t..t+2 only reach gk outputs with index <= m;
// then store gk.
#pragma unroll
for (int j = 0; j < BB; j++) {
const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
s[j].x += g[2 - t] * w[t + 0];
s[j].x += g[1 - t] * w[t + 1];
s[j].x += g[0 - t] * w[t + 2];
s[j].y += g[2 - t] * w[t + 1];
s[j].y += g[1 - t] * w[t + 2];
s[j].z += g[2 - t] * w[t + 2];
F4(gk, t + T * (i + ij * j)) = s[j];
}
}
// Host launcher for kernel_backward (float data).
// Launch layout: grid (1, B*C/BB), one block per group of BB rows; block of
// T/4 threads, each handling 4 consecutive time steps.
// Preconditions (not checked here): T % 4 == 0, T <= Tmax, B % BB == 0,
// T/4 within the device's threads-per-block limit, and B*C/BB <= 65535
// (the gridDim.y limit).
// The kernel runs on the default stream; launch errors are not checked here.
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
    // Empty tensor: nothing to compute. Launching anyway would use a zero grid
    // or block dimension and leave a pending cudaErrorInvalidConfiguration
    // that a later, unrelated launch check would report.
    if (B == 0 || C == 0 || T == 0) {
        return;
    }
    // Named grid/block (not gridDim/blockDim) to avoid shadowing CUDA built-ins.
    const dim3 grid(1, B * C / BB);
    const dim3 block(T >> 2);
    kernel_backward<<<grid, block>>>(w, k, gwk, gw, gk, B, C, T);
}