rescale to avoid FP16 overflow

main
BlinkDL 3 years ago
parent aef9f6f7ef
commit 2567c8c904

@ -124,7 +124,8 @@ from src.model_run import RWKV_RNN
# Instantiate RWKV_RNN from the run arguments.
model = RWKV_RNN(args)
print(f'\nOptimizing speed...')
# Run forward on the one-element input [187] with None as the state argument
# before the cleanup below; presumably a warm-up pass -- TODO confirm.
# NOTE(review): this text comes from a diff view with the +/- markers
# stripped; the bare call and the `out, _ =` call look like the old and new
# versions of the same line -- confirm before treating this as two passes.
model.forward([187], None)
out, _ = model.forward([187], None)
# print(out)
# Run a garbage-collection pass, then release PyTorch's cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

@ -22,6 +22,8 @@ print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
# Debug switch, checked while iterating over the weight keys (see the
# `if DEBUG_TIME:` branch in RWKV_RNN).
DEBUG_TIME = False # True False - show trained time-coeffs
# Rescale period, in layers. The forward pass halves x after every
# RWKV_RESCALE_LAYER-th layer ((i+1) % RWKV_RESCALE_LAYER == 0), and the
# 'att.output.weight' / 'ffn.value.weight' tensors of block b are divided by
# 2 ** (b // RWKV_RESCALE_LAYER) when the weights are processed.
# Per the commit message, this rescaling is meant to avoid FP16 overflow.
RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
############################################################################################################
class RWKV_RNN(nn.Module):
@ -41,6 +43,14 @@ class RWKV_RNN(nn.Module):
keys = list(w.keys())
print_need_newline = False
for x in keys:
block_id = 0
if 'blocks.' in x:
block_id = int(x.split('.')[1])
if 'att.output.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if 'ffn.value.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if '.time_' in x:
w[x] = w[x].squeeze()
if DEBUG_TIME:
@ -208,6 +218,9 @@ class RWKV_RNN(nn.Module):
x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
ww.time_mix_k, ww.time_mix_r,
ww.key.weight, ww.value.weight, ww.receptance.weight)
if (i+1) % RWKV_RESCALE_LAYER == 0:
x = x / 2
if preprocess_only:
return state

Loading…
Cancel
Save