From dc7e0802d0e32441d7fab6e892a64d01a35c471f Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Fri, 2 Sep 2022 13:57:29 +0800 Subject: [PATCH] faster --- RWKV-v4/cuda/wkv_cuda.cu | 4 ++-- RWKV-v4/src/model_run.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/RWKV-v4/cuda/wkv_cuda.cu b/RWKV-v4/cuda/wkv_cuda.cu index 8851c89..6acd0f3 100644 --- a/RWKV-v4/cuda/wkv_cuda.cu +++ b/RWKV-v4/cuda/wkv_cuda.cu @@ -111,14 +111,14 @@ __global__ void kernel_backward(const int B, const int T, const int C, } void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { - dim3 threadsPerBlock( min(C, 256) ); + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance assert(B * C % threadsPerBlock.x == 0); dim3 numBlocks(B * C / threadsPerBlock.x); kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y); } void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) { - dim3 threadsPerBlock( min(C, 256) ); + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance assert(B * C % threadsPerBlock.x == 0); dim3 numBlocks(B * C / threadsPerBlock.x); kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); diff --git a/RWKV-v4/src/model_run.py b/RWKV-v4/src/model_run.py index ae45c53..16c7e5d 100644 --- a/RWKV-v4/src/model_run.py +++ b/RWKV-v4/src/model_run.py @@ -24,7 +24,7 @@ if os.environ['RWKV_RUN_DEVICE'] == 'cuda': from torch.utils.cpp_extension import load wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], - verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}']) + verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}']) class WKV(torch.autograd.Function): @staticmethod