|
|
|
|
@ -104,52 +104,61 @@ def RUN_CUDA(B, T, C, w, u, k, v):
|
|
|
|
|
# RWKV: RWKV Time-mix + RWKV Channel-mix
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module
|
|
|
|
|
print('\n[--> first run, init model params (very slow for large models) <--]')
|
|
|
|
|
print('[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n')
|
|
|
|
|
for m in module.modules():
|
|
|
|
|
if not isinstance(m, (nn.Linear, nn.Embedding)):
|
|
|
|
|
continue
|
|
|
|
|
def RWKV_Init(model, args): # fancy initialization of all lin & emb layer in the model
|
|
|
|
|
print("\n[--> first run, init model params (very slow for large models) <--]")
|
|
|
|
|
print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n")
|
|
|
|
|
|
|
|
|
|
for mm in model.modules():
|
|
|
|
|
if "RecursiveScriptModule" in str(type(mm)):
|
|
|
|
|
if mm.original_name not in ["Linear"]:
|
|
|
|
|
continue
|
|
|
|
|
ww = None
|
|
|
|
|
for name, param in mm.named_parameters():
|
|
|
|
|
if name == "weight":
|
|
|
|
|
ww = param
|
|
|
|
|
else:
|
|
|
|
|
m = mm
|
|
|
|
|
if not isinstance(m, (nn.Linear, nn.Embedding)):
|
|
|
|
|
continue
|
|
|
|
|
ww = m.weight
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
name = '[unknown weight]'
|
|
|
|
|
for name, parameter in module.named_parameters(): # find the name of the weight
|
|
|
|
|
if id(m.weight) == id(parameter):
|
|
|
|
|
name = "[unknown weight]"
|
|
|
|
|
for name, parameter in model.named_parameters(): # find the name of the weight
|
|
|
|
|
if id(ww) == id(parameter):
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
shape = m.weight.data.shape
|
|
|
|
|
shape = ww.shape
|
|
|
|
|
gain = 1.0
|
|
|
|
|
scale = 1.0 # extra scale for gain
|
|
|
|
|
|
|
|
|
|
if isinstance(m, nn.Embedding):
|
|
|
|
|
gain = math.sqrt(max(shape[0], shape[1]))
|
|
|
|
|
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb?
|
|
|
|
|
if shape[0] == args.vocab_size and shape[1] == args.n_embd: # token emb?
|
|
|
|
|
scale = 1e-4
|
|
|
|
|
else:
|
|
|
|
|
scale = 0
|
|
|
|
|
|
|
|
|
|
if isinstance(m, nn.Linear):
|
|
|
|
|
if m.bias is not None:
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
|
if shape[0] > shape[1]:
|
|
|
|
|
gain = math.sqrt(shape[0] / shape[1])
|
|
|
|
|
if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection?
|
|
|
|
|
if shape[0] == args.vocab_size and shape[1] == args.n_embd: # final projection?
|
|
|
|
|
scale = 0.5
|
|
|
|
|
|
|
|
|
|
if hasattr(m, 'scale_init'):
|
|
|
|
|
if hasattr(m, "scale_init"):
|
|
|
|
|
scale = m.scale_init
|
|
|
|
|
|
|
|
|
|
# print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
|
|
|
|
|
# print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {name}")
|
|
|
|
|
|
|
|
|
|
gain *= scale
|
|
|
|
|
if scale == -999:
|
|
|
|
|
nn.init.eye_(m.weight)
|
|
|
|
|
nn.init.eye_(ww)
|
|
|
|
|
elif gain == 0:
|
|
|
|
|
# zero init is great for some RWKV matrices
|
|
|
|
|
nn.init.zeros_(m.weight)
|
|
|
|
|
nn.init.zeros_(ww)
|
|
|
|
|
elif gain > 0:
|
|
|
|
|
nn.init.orthogonal_(m.weight, gain=gain)
|
|
|
|
|
nn.init.orthogonal_(ww, gain=gain)
|
|
|
|
|
else:
|
|
|
|
|
nn.init.normal_(m.weight, mean=0.0, std=-scale)
|
|
|
|
|
nn.init.normal_(ww, mean=0.0, std=-scale)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RWKV_TimeMix(torch.jit.ScriptModule):
|
|
|
|
|
|