From c43a17cfb3968c4d508bd0427b80c666ead1c6e4 Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Fri, 2 Sep 2022 12:48:20 +0800 Subject: [PATCH] 10% faster training --- RWKV-v4/cuda/wkv_cuda.cu | 4 ++-- RWKV-v4/src/model.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/RWKV-v4/cuda/wkv_cuda.cu b/RWKV-v4/cuda/wkv_cuda.cu index 720317c..8851c89 100644 --- a/RWKV-v4/cuda/wkv_cuda.cu +++ b/RWKV-v4/cuda/wkv_cuda.cu @@ -111,14 +111,14 @@ __global__ void kernel_backward(const int B, const int T, const int C, } void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { - dim3 threadsPerBlock( min(C, 1024) ); + dim3 threadsPerBlock( min(C, 256) ); assert(B * C % threadsPerBlock.x == 0); dim3 numBlocks(B * C / threadsPerBlock.x); kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y); } void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) { - dim3 threadsPerBlock( min(C, 1024) ); + dim3 threadsPerBlock( min(C, 256) ); assert(B * C % threadsPerBlock.x == 0); dim3 numBlocks(B * C / threadsPerBlock.x); kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); diff --git a/RWKV-v4/src/model.py b/RWKV-v4/src/model.py index 0664ccc..47c201f 100644 --- a/RWKV-v4/src/model.py +++ b/RWKV-v4/src/model.py @@ -24,7 +24,7 @@ T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!] from torch.utils.cpp_extension import load wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], - verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}']) + verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}']) class WKV(torch.autograd.Function): @staticmethod