fix for jit modules

main
BlinkDL 3 years ago
parent 2b4539cd08
commit 50587bd65f

@ -104,52 +104,61 @@ def RUN_CUDA(B, T, C, w, u, k, v):
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################
def RWKV_Init(model, args):  # fancy initialization of all lin & emb layer in the model
    """Re-initialize, in place, the weight of every Linear / Embedding layer in ``model``.

    Layers are discovered through ``model.modules()``. Plain ``nn.Linear`` and
    ``nn.Embedding`` instances are handled directly; TorchScript-compiled
    sub-modules (whose type name contains ``RecursiveScriptModule``) are
    handled only when their ``original_name`` is ``"Linear"``, and are then
    treated exactly like a plain ``nn.Linear``.

    For each weight of shape ``(shape[0], shape[1])`` a ``gain`` and a
    ``scale`` are derived:

    * Embedding: ``gain = sqrt(max(shape))``; ``scale = 1e-4`` when the shape
      is ``(args.vocab_size, args.n_embd)``, otherwise ``0``.
    * Linear: bias (if any) is zeroed; ``gain = sqrt(shape[0] / shape[1])``
      when ``shape[0] > shape[1]``; ``scale = 0.5`` when the shape is
      ``(args.vocab_size, args.n_embd)``.
    * A ``scale_init`` attribute on the layer overrides ``scale``.

    The final ``gain * scale`` then selects the initializer: identity when
    ``scale == -999``, zeros when the product is ``0``, orthogonal when it is
    positive, and a normal distribution with ``std = -scale`` when negative.

    Args:
        model: module whose sub-modules are initialized in place.
        args: object exposing ``vocab_size`` and ``n_embd``.

    Returns:
        None. Parameters are modified in place.
    """
    print("\n[--> first run, init model params (very slow for large models) <--]")
    print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n")

    for mm in model.modules():
        if "RecursiveScriptModule" in str(type(mm)):
            # Scripted modules no longer pass isinstance checks against
            # nn.Linear, so they are recognised by their recorded class name.
            if mm.original_name != "Linear":
                continue
            params = dict(mm.named_parameters())
            ww = params.get("weight")
            bias = params.get("bias")
            is_linear = True
            is_embedding = False
        elif isinstance(mm, (nn.Linear, nn.Embedding)):
            ww = mm.weight
            is_linear = isinstance(mm, nn.Linear)
            is_embedding = isinstance(mm, nn.Embedding)
            bias = mm.bias if is_linear else None
        else:
            continue

        if ww is None:
            # Nothing to initialize (e.g. a scripted layer without a weight).
            continue

        with torch.no_grad():
            shape = ww.shape
            gain = 1.0
            scale = 1.0  # extra scale for gain
            matches_vocab_proj = shape[0] == args.vocab_size and shape[1] == args.n_embd

            if is_embedding:
                gain = math.sqrt(max(shape[0], shape[1]))
                scale = 1e-4 if matches_vocab_proj else 0  # token emb?

            if is_linear:
                if bias is not None:
                    bias.zero_()
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if matches_vocab_proj:  # final projection?
                    scale = 0.5

            # Per-layer override; read from the module being initialized
            # (not from whichever module a previous iteration left behind).
            if hasattr(mm, "scale_init"):
                scale = mm.scale_init

            gain *= scale
            if scale == -999:
                nn.init.eye_(ww)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(ww)
            elif gain > 0:
                nn.init.orthogonal_(ww, gain=gain)
            else:
                nn.init.normal_(ww, mean=0.0, std=-scale)
class RWKV_TimeMix(torch.jit.ScriptModule):

Loading…
Cancel
Save