Summary of this patch (free text ahead of the first "diff --git" header).

run.py: the warm-up call model.forward([187], None) now unpacks its result
into "out, _"; a debug print of "out" is left commented out.

src/model_run.py: adds the module constant RWKV_RESCALE_LAYER = 6. In the
"for x in keys" loop over w, block_id is parsed from the second
dot-separated field of any key containing 'blocks.' (0 otherwise), and every
key containing 'att.output.weight' or 'ffn.value.weight' is divided by
2 ** (block_id // RWKV_RESCALE_LAYER). After the FF residual update for
block i, x is halved whenever (i+1) % RWKV_RESCALE_LAYER == 0. Both
rescalings are driven by the same constant, so the weight divisor of a block
equals the cumulative halving applied to x before that block.

NOTE(review): presumably this keeps activations within a limited numeric
range; the motivation is not stated in the patch. Confirm with the author.
NOTE(review): line breaks, indentation and blank context lines below were
reconstructed so that each hunk matches the line counts in its @@ header.
Verify against the target tree before applying, or apply with whitespace
tolerance.

diff --git a/RWKV-v4neo/run.py b/RWKV-v4neo/run.py
index 61eb3f8..fd262d5 100644
--- a/RWKV-v4neo/run.py
+++ b/RWKV-v4neo/run.py
@@ -124,7 +124,8 @@ from src.model_run import RWKV_RNN
 model = RWKV_RNN(args)
 
 print(f'\nOptimizing speed...')
-model.forward([187], None)
+out, _ = model.forward([187], None)
+# print(out)
 gc.collect()
 torch.cuda.empty_cache()
 
diff --git a/RWKV-v4neo/src/model_run.py b/RWKV-v4neo/src/model_run.py
index c12fee4..0e0291c 100644
--- a/RWKV-v4neo/src/model_run.py
+++ b/RWKV-v4neo/src/model_run.py
@@ -22,6 +22,8 @@ print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
 
 DEBUG_TIME = False # True False - show trained time-coeffs
 
+RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
+
 ############################################################################################################
 
 class RWKV_RNN(nn.Module):
@@ -41,6 +43,14 @@ class RWKV_RNN(nn.Module):
             keys = list(w.keys())
             print_need_newline = False
             for x in keys:
+                block_id = 0
+                if 'blocks.' in x:
+                    block_id = int(x.split('.')[1])
+                if 'att.output.weight' in x:
+                    w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
+                if 'ffn.value.weight' in x:
+                    w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
+
                 if '.time_' in x:
                     w[x] = w[x].squeeze()
                     if DEBUG_TIME:
@@ -208,6 +218,9 @@ class RWKV_RNN(nn.Module):
                 x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
                     ww.time_mix_k, ww.time_mix_r,
                     ww.key.weight, ww.value.weight, ww.receptance.weight)
+
+                if (i+1) % RWKV_RESCALE_LAYER == 0:
+                    x = x / 2
 
             if preprocess_only:
                 return state