|
|
|
@ -22,6 +22,8 @@ print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
|
|
|
|
|
|
|
|
|
|
|
|
DEBUG_TIME = False # True False - show trained time-coeffs
|
|
|
|
DEBUG_TIME = False # True False - show trained time-coeffs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
|
|
|
|
|
|
|
|
|
|
|
|
############################################################################################################
|
|
|
|
############################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
class RWKV_RNN(nn.Module):
|
|
|
|
class RWKV_RNN(nn.Module):
|
|
|
|
@ -41,6 +43,14 @@ class RWKV_RNN(nn.Module):
|
|
|
|
keys = list(w.keys())
|
|
|
|
keys = list(w.keys())
|
|
|
|
print_need_newline = False
|
|
|
|
print_need_newline = False
|
|
|
|
for x in keys:
|
|
|
|
for x in keys:
|
|
|
|
|
|
|
|
block_id = 0
|
|
|
|
|
|
|
|
if 'blocks.' in x:
|
|
|
|
|
|
|
|
block_id = int(x.split('.')[1])
|
|
|
|
|
|
|
|
if 'att.output.weight' in x:
|
|
|
|
|
|
|
|
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
|
|
|
|
|
|
|
|
if 'ffn.value.weight' in x:
|
|
|
|
|
|
|
|
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
|
|
|
|
|
|
|
|
|
|
|
|
if '.time_' in x:
|
|
|
|
if '.time_' in x:
|
|
|
|
w[x] = w[x].squeeze()
|
|
|
|
w[x] = w[x].squeeze()
|
|
|
|
if DEBUG_TIME:
|
|
|
|
if DEBUG_TIME:
|
|
|
|
@ -209,6 +219,9 @@ class RWKV_RNN(nn.Module):
|
|
|
|
ww.time_mix_k, ww.time_mix_r,
|
|
|
|
ww.time_mix_k, ww.time_mix_r,
|
|
|
|
ww.key.weight, ww.value.weight, ww.receptance.weight)
|
|
|
|
ww.key.weight, ww.value.weight, ww.receptance.weight)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (i+1) % RWKV_RESCALE_LAYER == 0:
|
|
|
|
|
|
|
|
x = x / 2
|
|
|
|
|
|
|
|
|
|
|
|
if preprocess_only:
|
|
|
|
if preprocess_only:
|
|
|
|
return state
|
|
|
|
return state
|
|
|
|
|
|
|
|
|
|
|
|
|