10% faster training

Branch: main
Author: BlinkDL, 3 years ago
Commit: c43a17cfb3
Parent: c84e8fd952

@@ -111,14 +111,14 @@ __global__ void kernel_backward(const int B, const int T, const int C,
 }
 
 void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
-    dim3 threadsPerBlock( min(C, 1024) );
+    dim3 threadsPerBlock( min(C, 256) );
     assert(B * C % threadsPerBlock.x == 0);
     dim3 numBlocks(B * C / threadsPerBlock.x);
     kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
 }
 
 void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
-    dim3 threadsPerBlock( min(C, 1024) );
+    dim3 threadsPerBlock( min(C, 256) );
     assert(B * C % threadsPerBlock.x == 0);
     dim3 numBlocks(B * C / threadsPerBlock.x);
     kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
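
The functional change in this file is the block size: min(C, 1024) becomes min(C, 256), so whenever C >= 256 each block now launches 256 threads instead of 1024. Combined with the --maxrregcount 60 cap added to the build flags in the same commit (see the Python hunk below), a 256-thread block needs at most 256 x 60 = 15,360 registers, so up to four such blocks fit in a 64K-register SM register file, which is presumably where the speedup comes from. What follows is a minimal sketch of the launch geometry only; the kernel name, the placeholder recurrence, and the channel-last indexing are illustrative assumptions, not the actual WKV kernel.

#include <cassert>

// Sketch only (hypothetical kernel_forward_sketch): one thread per (batch, channel)
// pair, scanning the whole sequence of length T, as the grid computation in the
// diff implies. The real recurrence over w, u, k, v is elided.
__global__ void kernel_forward_sketch(const int B, const int T, const int C,
                                      const float *k, const float *v, float *y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;  // global (b, c) index
    const int b = idx / C;                                  // batch index
    const int c = idx % C;                                  // channel index
    float acc = 0.0f;
    for (int t = 0; t < T; t++) {                           // serial scan over time
        const int i = b * T * C + t * C + c;                // (B, T, C) channel-last layout
        acc += k[i] * v[i];                                 // placeholder recurrence
        y[i] = acc;
    }
}

void launch_forward_sketch(int B, int T, int C, const float *k, const float *v, float *y) {
    dim3 threadsPerBlock( min(C, 256) );        // smaller blocks -> more resident blocks per SM
    assert(B * C % threadsPerBlock.x == 0);     // the grid division below must be exact
    dim3 numBlocks(B * C / threadsPerBlock.x);  // one thread per (b, c) pair overall
    kernel_forward_sketch<<<numBlocks, threadsPerBlock>>>(B, T, C, k, v, y);
}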

@@ -24,7 +24,7 @@ T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
 
 from torch.utils.cpp_extension import load
 wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
-                verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
+                verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
 
 class WKV(torch.autograd.Function):
     @staticmethod
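
For reference, the new nvcc options do the following: -res-usage prints each kernel's register and memory usage at build time, --maxrregcount 60 caps registers per thread for all kernels in the file, and -O3 plus -Xptxas -O3 enable host-compiler and PTX-assembler optimizations; --extra-device-vectorization is dropped. A per-kernel alternative to the global register cap is CUDA's __launch_bounds__ qualifier; the sketch below only illustrates that mechanism under assumed numbers (hypothetical kernel, 64K-register SM), and is not something this commit adds.

// Sketch only: __launch_bounds__(maxThreadsPerBlock, minBlocksPerSM) asks the
// compiler to keep register use low enough that at least minBlocksPerSM blocks
// of maxThreadsPerBlock threads can be resident. Here 4 x 256 threads on a
// 64K-register SM works out to at most 64 registers per thread, roughly the
// budget that --maxrregcount 60 enforces globally.
__global__ void __launch_bounds__(256, 4)
scaled_copy_sketch(const int n, const float *x, float *y) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = 2.0f * x[i];  // placeholder body; the WKV kernels themselves are untouched here
    }
}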
