Compare commits
124 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 99a5933f54 | 3 years ago |
| | 5b08ee1718 | 3 years ago |
| | 4e962eb850 | 3 years ago |
| | 8f428408a3 | 3 years ago |
| | a9007581d0 | 3 years ago |
| | 79915b3696 | 3 years ago |
| | 0c7cd08255 | 3 years ago |
| | 3d43eaa1c8 | 3 years ago |
| | 04a564d7d0 | 3 years ago |
| | 5713df51ec | 3 years ago |
| | 14d21f5a00 | 3 years ago |
| | 4ca274aad7 | 3 years ago |
| | 1945cb58ed | 3 years ago |
| | 3d2b04ba0c | 3 years ago |
| | 107f167b4a | 3 years ago |
| | ac4ba411e6 | 3 years ago |
| | 099919058b | 3 years ago |
| | 87fab90435 | 3 years ago |
| | 13b8784502 | 3 years ago |
| | 513d3eb552 | 3 years ago |
| | 8b615ccc74 | 3 years ago |
| | 1e0dba0421 | 3 years ago |
| | f8134fb96e | 3 years ago |
| | 62fba64244 | 3 years ago |
| | 123536b2a7 | 3 years ago |
| | decd8e29f5 | 3 years ago |
| | 6d4dec7288 | 3 years ago |
| | 8e99ac1138 | 3 years ago |
| | 4056bfeba7 | 3 years ago |
| | d16e25661c | 3 years ago |
| | 0ff2170277 | 3 years ago |
| | e615f1c718 | 3 years ago |
| | 1430d4edcf | 3 years ago |
| | 4378fe6b4f | 3 years ago |
| | f38b7e3574 | 3 years ago |
| | 6739df885e | 3 years ago |
| | 9f557219c4 | 3 years ago |
| | 58e9d8d972 | 3 years ago |
| | 1d72a48db0 | 3 years ago |
| | 52ac194d54 | 3 years ago |
| | 93d671c287 | 3 years ago |
| | 760db55fa6 | 3 years ago |
| | 6b59d8fee1 | 3 years ago |
| | fc047a20b1 | 3 years ago |
| | 0c77cfbbee | 3 years ago |
| | 904de99a14 | 3 years ago |
| | 5f6ffc987a | 3 years ago |
| | 02178d79c9 | 3 years ago |
| | ad1836b27f | 3 years ago |
| | 404c593213 | 3 years ago |
| | 9917078f93 | 3 years ago |
| | e0dc08a2ce | 3 years ago |
| | a3e6156136 | 3 years ago |
| | 55f63f0aeb | 3 years ago |
| | 114d677bc8 | 3 years ago |
| | d008cc6d8e | 3 years ago |
| | c13879ab97 | 3 years ago |
| | 78579a00d2 | 3 years ago |
| | e6d9e4979a | 3 years ago |
| | 8d72d882e4 | 3 years ago |
| | 3f5ac97f77 | 3 years ago |
| | 81aa6dda7b | 3 years ago |
| | 366b000ee6 | 3 years ago |
| | 11acd5e5b5 | 3 years ago |
| | 374b086911 | 3 years ago |
| | 7476c69f32 | 3 years ago |
| | 6ed3a3db09 | 3 years ago |
| | e2ec7ae023 | 3 years ago |
| | 71a46ca0f3 | 3 years ago |
| | b97c25b9e7 | 3 years ago |
| | 3d15f41a16 | 3 years ago |
| | bbacb62b89 | 3 years ago |
| | c7b1900270 | 3 years ago |
| | 038f06b996 | 3 years ago |
| | f03efd0218 | 3 years ago |
| | 9721b8f9c5 | 3 years ago |
| | 79aa59ff2b | 3 years ago |
| | 8e63b75f2c | 3 years ago |
| | 5c8eda8595 | 3 years ago |
| | 13c6149205 | 3 years ago |
| | f6cb1a1947 | 3 years ago |
| | aeae6c8aac | 3 years ago |
| | b562097da1 | 3 years ago |
| | f79d082053 | 3 years ago |
| | 8bf7061705 | 3 years ago |
| | dd3845752a | 3 years ago |
| | b4925900e7 | 3 years ago |
| | fcb0b9819d | 3 years ago |
| | be18c53fec | 3 years ago |
| | eac471da29 | 3 years ago |
| | d1bb270fb3 | 3 years ago |
| | 5837ee32c4 | 3 years ago |
| | 8e1130e12a | 3 years ago |
| | b2a240d73d | 3 years ago |
| | bc47cb9f1a | 3 years ago |
| | 3461b2f6fb | 3 years ago |
| | 295af9a517 | 3 years ago |
| | dc26998708 | 3 years ago |
| | 75929cbbba | 3 years ago |
| | 819f2730b2 | 3 years ago |
| | 66c1dabb94 | 3 years ago |
| | 379c97890b | 3 years ago |
| | 83a4512b74 | 3 years ago |
| | 59e6deeb58 | 3 years ago |
| | cf340264dc | 3 years ago |
| | 935d8d3e87 | 3 years ago |
| | 0d0cedfcd9 | 3 years ago |
| | 511b7adb4f | 3 years ago |
| | aaf1341af7 | 3 years ago |
| | e64ce9b0ff | 3 years ago |
| | 3e0f8054c6 | 3 years ago |
| | 0131543e48 | 3 years ago |
| | 315ce82e38 | 3 years ago |
| | 9eb1a7b3d3 | 3 years ago |
| | 7a7c06aed3 | 3 years ago |
| | 2e5704097d | 3 years ago |
| | 529be15c67 | 3 years ago |
| | 5c73bccd5a | 3 years ago |
| | 7bdfd1cb64 | 3 years ago |
| | eaac3f7e66 | 3 years ago |
| | 14544aea94 | 3 years ago |
| | 3fc16a86ed | 3 years ago |
| | eb3ca86f0b | 3 years ago |
| | 23f64aeebc | 3 years ago |
*Binary image changes (filenames not rendered in this extract): five images added (161 KiB, 67 KiB, 69 KiB, 90 KiB, 66 KiB) and two replaced (1.1 MiB → 410 KiB, 153 KiB → 359 KiB).*
@@ -0,0 +1,132 @@

```cuda
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#define MIN_VALUE (-1e38)
typedef at::BFloat16 bf16;

__global__ void kernel_forward(const int B, const int T, const int C,
                               const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
                               const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
                               bf16 *__restrict__ const _y) {
    // one thread per (batch, channel) pair; each thread scans sequentially over T
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    bf16 *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    float aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);

        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}

__global__ void kernel_backward(const int B, const int T, const int C,
                                const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
                                const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
                                const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
                                bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu,
                                bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    const bf16 *__restrict__ const y = _y + _offset;
    const bf16 *__restrict__ const gy = _gy + _offset;
    bf16 *__restrict__ const gk = _gk + _offset;
    bf16 *__restrict__ const gv = _gv + _offset;

    // Tmax must be defined at compile time (e.g. -DTmax=...) and must be >= T
    float q[Tmax], r[Tmax];

    // first pass (forward in time): accumulate gw and gu, saving q and r for the reverse pass
    float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        const float yy = float(y[ii]);

        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        const float qq = float(gy[ii]) / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;
        r[i] = ww - p;

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = bf16(gu);

    // second pass (backward in time): accumulate gk and gv
    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        const float yy = float(y[ii]);
        const float qq = q[i];
        const float rr = r[i];

        float e1 = qq * exp(rr);
        float e2 = exp(kk + pp);
        gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
        gv[ii] = bf16(e1 + e2 * aa);

        const float ww = w + pp;
        const float www = rr - u - kk;
        const float p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}

void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
```
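For orientation (my summary, not part of the diff): each (batch, channel) thread of the forward kernel evaluates the RWKV WKV mixing in a numerically stable way, tracking the running maximum exponent `pp` so that `aa` and `bb` are the numerator and denominator sums rescaled by `exp(-pp)`:

```latex
% What kernel_forward computes per channel (notation mine): w is the decay
% (already mapped w -> -exp(w) on the Python side, per the gw comment in
% the backward kernel), and u is the bonus applied only to the current token.
\[
  y_t \;=\;
  \frac{\sum_{i=1}^{t-1} e^{(t-1-i)\,w + k_i}\, v_i \;+\; e^{u + k_t}\, v_t}
       {\sum_{i=1}^{t-1} e^{(t-1-i)\,w + k_i} \;+\; e^{u + k_t}}
\]
% Evaluating these sums directly would overflow for long T; the kernel
% instead keeps aa (numerator) and bb (denominator) divided by e^{pp},
% updating pp to the running max so every exp() argument stays non-positive.
```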
@@ -0,0 +1,25 @@

```cpp
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
              torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
                  gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("backward", &backward, "wkv backward");
}

TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
```
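The compare view does not show the file names or the build step for this extension. As a minimal sketch of how it could be JIT-compiled and called with `torch.utils.cpp_extension` (the file names and most flags below are my assumptions; only `-DTmax` and `--maxrregcount=60` are implied by the kernel source):

```python
# Hypothetical build-and-call sketch; source file names are assumptions,
# since the diff extract does not show them.
import torch
from torch.utils.cpp_extension import load

T_MAX = 1024  # compile-time Tmax; must be >= any T passed at runtime
wkv = load(
    name="wkv_bf16",
    sources=["wkv_op_bf16.cpp", "wkv_cuda_bf16.cu"],
    extra_cuda_cflags=["-O3", "--use_fast_math",
                       "--maxrregcount=60", f"-DTmax={T_MAX}"],
)

B, T, C = 2, 128, 64  # B*C must be divisible by min(C, 32), per the assert
dev, bf = "cuda", torch.bfloat16
w = -torch.exp(torch.rand(C, device=dev))       # fp32 decay: w -> -exp(w), per the gw comment
u = torch.rand(C, device=dev, dtype=bf)
k = torch.rand(B, T, C, device=dev, dtype=bf)
v = torch.rand(B, T, C, device=dev, dtype=bf)
y = torch.empty(B, T, C, device=dev, dtype=bf)  # forward() writes its output into y
wkv.forward(B, T, C, w, u, k, v, y)
```

The `TORCH_LIBRARY(wkv, m)` block additionally registers the same two functions as TorchScript ops under `torch.ops.wkv`, so a JIT-scripted model can call them as well.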
@@ -0,0 +1,104 @@

```python
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

# This verifies the results of different models and makes sure they agree with each other.

import os, sys, types
import numpy as np
import torch

np.set_printoptions(precision=4, suppress=True, linewidth=200)
try:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
    pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

os.environ['RWKV_FLOAT_MODE'] = 'bf16'  # bf16 or fp32
os.environ['RWKV_RUN_DEVICE'] = 'cuda'  # currently model_train requires CUDA
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']

TOKEN_MODE = 'pile'

if TOKEN_MODE == 'pile':
    WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
    MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
    n_layer = 32
    n_embd = 2560
    ctx_len = 1024
    UNKNOWN_CHAR = None

from src.utils import TOKENIZER
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == 'pile':
    tokenizer.vocab_size = 50277

########################################################################################################

os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_T_MAX"] = str(ctx_len)

from src.model_run import RWKV_RNN
from src.model import RWKV

args = types.SimpleNamespace()
args.vocab_size = tokenizer.vocab_size
args.ctx_len = ctx_len
args.n_embd = n_embd
args.n_layer = n_layer
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
model_train = RWKV(args).to(RUN_DEVICE)

if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
    model_train = model_train.bfloat16()

print('loading ' + MODEL_NAME)
m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu')
model_train.load_state_dict(m2)

if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
    model_train = model_train.bfloat16()

args.MODEL_NAME = MODEL_NAME
args.RUN_DEVICE = RUN_DEVICE
args.FLOAT_MODE = os.environ['RWKV_FLOAT_MODE']
model_rnn = RWKV_RNN(args)

########################################################################################################

print(f"\nVerifying {os.environ['RWKV_RUN_DEVICE']} {os.environ['RWKV_FLOAT_MODE']}")

# context = '\nIn a'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'

if TOKEN_MODE == 'pile':
    ctx = tokenizer.tokenizer.encode(context)
print(f'input len {len(ctx)} data {ctx}')

########################################################################################################

with torch.no_grad():
    print('\nRWKV-train output')
    out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].detach().cpu().float().numpy()
    print(out, '\n')

    print('\nRWKV-RNN output')
    state = None
    out = None
    src_len = len(ctx)
    for i in range(src_len):
        x = ctx[:i+1]
        out, state = model_rnn.forward(x, state)
        if i < 3 or i >= src_len - 3:
            print(out.detach().cpu().numpy())
            if i == 2:
                print('...')
```
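The script leaves the comparison to visual inspection of the printed logits. A hypothetical tightening (the variable names and the bf16 tolerance below are mine, not from the diff) would assert that the last-token logits from the parallel (GPT-mode) and sequential (RNN-mode) paths agree:

```python
# Hypothetical automated check, reusing model_train, model_rnn, ctx and
# RUN_DEVICE from the script above.
with torch.no_grad():
    # last-token logits from the parallel (training) model
    logits_train = model_train.forward(
        torch.tensor([ctx]).to(RUN_DEVICE))[0][-1].detach().cpu().float().numpy()
    # last-token logits from the sequential (RNN) model
    state, logits_rnn = None, None
    for i in range(len(ctx)):
        logits_rnn, state = model_rnn.forward(ctx[:i+1], state)
    logits_rnn = logits_rnn.detach().cpu().float().numpy()

# atol is a guess for bf16 accumulation error at this model size
assert np.allclose(logits_train, logits_rnn, atol=1e-2), \
    f'max abs diff {np.abs(logits_train - logits_rnn).max():.4f}'
```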