From 038f06b99627049729e332bc66159e0a0092cdaa Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Fri, 3 Feb 2023 04:58:11 +0000
Subject: [PATCH] rwkv-4b

---
Review notes (placed after the `---` marker, so they are not part of the
commit message):

* What the patch does: when the RWKV_MY_TESTING environment string contains
  'a', RWKV_TimeMix gains a small softmax-attention branch (qq/kk/vv/oo
  projections of width n_embd // 16, each fed by its own time-mix
  parameter) whose output is added to the existing RUN_CUDA (WKV) output.
  Otherwise the previous jit_func/forward pair is kept. The patch also
  removes commented-out `my_testing` experiments from RWKV_TimeMix and
  RWKV_ChannelMix, and extends the init-weight name filters with '.mask.',
  '.oo.' and '.rr.'.

* Fix applied in review (RWKV_TimeMix.QKV): `att_mask` is registered as a
  (ctx_len, ctx_len) buffer, while the score tensor is (..., Tq, Tk) for the
  current input. The mask is now cropped to the query/key lengths before
  `masked_fill`, so inputs shorter than ctx_len no longer fail to
  broadcast. For Tq == Tk == ctx_len the result is unchanged. The edit is
  made in place on one added line, so hunk sizes and the diffstat are
  unchanged; the `index` post-image hash below predates the edit.

* Left unchanged, worth confirming:
  - `os.environ["RWKV_MY_TESTING"]` raises KeyError when the variable is
    unset; presumably the training script always sets it -- verify.
  - ".att.key." appears twice in the zero-init name list (harmless).
  - '.rr.' and '.mask.' match no module or buffer created in this patch;
    presumably reserved for code outside it -- verify.

* Line structure and indentation were restored by hand from a
  whitespace-collapsed copy of this patch. If context lines fail to match,
  apply with `git apply --ignore-whitespace`.

 RWKV-v4neo/src/model.py | 136 +++++++++++++++++++++++-----------------
 1 file changed, 80 insertions(+), 56 deletions(-)

diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 3857e03..9b23fbb 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -135,42 +135,89 @@ class RWKV_TimeMix(MyModule):
             self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
 
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-
         self.key = nn.Linear(args.n_embd, attn_sz, bias=False)
         self.value = nn.Linear(args.n_embd, attn_sz, bias=False)
         self.receptance = nn.Linear(args.n_embd, attn_sz, bias=False)
-
         self.output = nn.Linear(attn_sz, args.n_embd, bias=False)
 
-        # if self.my_testing > 0:
-        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))
-
-    @MyFunction
-    def jit_func(self, x):
-
-        # Mix x with the previous timestep to produce xk, xv, xr
-        xx = self.time_shift(x)
-        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
-        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
-        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
-
-        # Use xk, xv, xr to produce k, v, r
-        k = self.key(xk)
-        v = self.value(xv)
-        r = self.receptance(xr)
-        sr = torch.sigmoid(r)
-
-        return sr, k, v
-
-    def forward(self, x):
-        B, T, C = x.size()  # x = (Batch,Time,Channel)
-
-        sr, k, v = self.jit_func(x)
-
-        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
-        rwkv = self.output(rwkv)
-        return rwkv
+        if 'a' in os.environ["RWKV_MY_TESTING"]:
+            self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+            d_qkv = args.n_embd // 16
+            self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
+            self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
+            self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
+            self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
+            with torch.no_grad():
+                x = torch.ones(1, 1, args.n_embd)
+                for i in range(args.n_embd):
+                    x[0, 0, i] = i / args.n_embd
+                self.time_mix_qq = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+                self.time_mix_kk = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+                self.time_mix_vv = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+
+    if 'a' not in os.environ["RWKV_MY_TESTING"]:
+        @MyFunction
+        def jit_func(self, x):
+
+            # Mix x with the previous timestep to produce xk, xv, xr
+            xx = self.time_shift(x)
+            xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+            xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
+            xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+
+            # Use xk, xv, xr to produce k, v, r
+            k = self.key(xk)
+            v = self.value(xv)
+            r = self.receptance(xr)
+            sr = torch.sigmoid(r)
+            return sr, k, v
+
+        def forward(self, x):
+            B, T, C = x.size()  # x = (Batch,Time,Channel)
+            sr, k, v = self.jit_func(x)
+            rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
+            return self.output(rwkv)
+
+    if 'a' in os.environ["RWKV_MY_TESTING"]:
+        @MyFunction
+        def QKV(self, q, k, v):
+            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+            att = att.masked_fill(self.att_mask[:q.size(-2), :k.size(-2)] == 0, float('-inf'))  # crop causal mask to current lengths
+            att = F.softmax(att, dim = -1)
+            x = att @ v
+            return x
+
+        @MyFunction
+        def jit_funcQKV(self, x):
+            # Mix x with the previous timestep to produce xk, xv, xr
+            xx = self.time_shift(x)
+            xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+            xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
+            xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+            xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
+            xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
+            xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
+
+            # Use xk, xv, xr to produce k, v, r
+            k = self.key(xk)
+            v = self.value(xv)
+            r = self.receptance(xr)
+            sr = torch.sigmoid(r)
+
+            qq = self.qq(xqq)
+            kk = self.kk(xkk)
+            vv = self.vv(xvv)
+
+            return sr, k, v, qq, kk, vv
+
+        def forward(self, x):
+            B, T, C = x.size()  # x = (Batch,Time,Channel)
+            sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
+            rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
+            rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
+            return rwkv
 
+########################################################################################################
 
 class RWKV_ChannelMix(MyModule):
     def __init__(self, args, layer_id):
@@ -195,38 +242,15 @@ class RWKV_ChannelMix(MyModule):
         self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
         self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)
 
-        # if self.my_testing in [1]:
-        #     self.aaa = nn.Parameter(torch.zeros(1, 1, hidden_sz))
-        # elif self.my_testing in [2]:
-        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))
-
-
     @MyFunction
     def forward(self, x):
         xx = self.time_shift(x)
         xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
         xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
-
         k = self.key(xk)
         k = torch.square(torch.relu(k))
         kv = self.value(k)
-
-        rkv = torch.sigmoid(self.receptance(xr)) * kv
-        return rkv
-
-        # k = self.key(xk)
-        # # if self.my_testing in [0, 2]:
-        # k = torch.square(torch.relu(k))
-        # # elif self.my_testing == 1:
-        # #     k = torch.square(torch.relu(k)) + k * self.aaa
-        # kv = self.value(k)
-        # r = self.receptance(xr)
-        # # if self.my_testing == 0:
-        # r = torch.sigmoid(r)
-        # # elif self.my_testing == 2:
-        # #     r = torch.sigmoid(r) + r * self.aaa
-        # rkv = r * kv
-        # return rkv
+        return torch.sigmoid(self.receptance(xr)) * kv
 
 ########################################################################################################
 # The RWKV Model with our blocks
@@ -479,7 +503,7 @@ class RWKV(pl.LightningModule):
 
             gain = 1.0
             scale = 1.0
-            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n:
+            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
                 m[n] = p
             else:
                 if n == "emb.weight":
@@ -487,7 +511,7 @@ class RWKV(pl.LightningModule):
                 else:
                     if shape[0] > shape[1]:
                         gain = math.sqrt(shape[0] / shape[1])
-                    for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]:
+                    for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
                         if kk in n:
                             scale = 0
                     if n == "head.weight":