# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################


def RWKV_Init(model, args, *, verbose=False):  # fancy initialization of all lin & emb layer in the model
    """Re-initialize, in place, the weight of every Linear / Embedding layer of ``model``.

    Plain ``nn.Linear`` and ``nn.Embedding`` sub-modules are handled directly.
    Sub-modules that were compiled by TorchScript appear as ``RecursiveScriptModule``;
    among those, only the ones whose ``original_name`` is ``"Linear"`` are initialized
    (as Linear layers), and their weight is fetched through ``named_parameters()``.

    Rule applied to each weight (``gain`` is multiplied by ``scale`` before use):
      * Embedding: ``gain = sqrt(max(rows, cols))``; ``scale = 1e-4`` when the shape is
        ``(args.vocab_size, args.n_embd)``, otherwise ``scale = 0``.
      * Linear: ``gain = sqrt(rows / cols)`` when ``rows > cols`` (else 1);
        ``scale = 0.5`` when the shape is ``(args.vocab_size, args.n_embd)``.
      * A ``scale_init`` attribute on the layer overrides ``scale``.
      * ``scale == -999`` -> identity, ``gain == 0`` -> zeros,
        ``gain > 0`` -> orthogonal with that gain, ``gain < 0`` -> normal with ``std = -scale``.

    Biases are not modified.

    Args:
        model: module whose sub-modules are walked with ``model.modules()``.
        args: object exposing ``vocab_size`` and ``n_embd``.
        verbose: when true, print one line per initialized weight
            (rows, cols, scale, parameter name).

    Returns:
        None; the weights are modified in place.
    """
    print("\n[--> first run, init model params (very slow for large models) <--]")
    print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n")

    # Only needed to show a readable parameter name in verbose mode; snapshot it once
    # instead of re-walking every parameter for every layer.
    named_params = list(model.named_parameters()) if verbose else []

    for mm in model.modules():
        if "RecursiveScriptModule" in str(type(mm)):
            # A scripted module no longer is an nn.Linear instance, so identify it
            # by the name of the class it was compiled from.
            if mm.original_name != "Linear":
                continue
            is_linear, is_embedding = True, False
            ww = next((p for n, p in mm.named_parameters() if n == "weight"), None)
            if ww is None:  # nothing to initialize
                continue
        else:
            is_linear = isinstance(mm, nn.Linear)
            is_embedding = isinstance(mm, nn.Embedding)
            if not (is_linear or is_embedding):
                continue
            ww = mm.weight

        with torch.no_grad():
            shape = ww.shape
            gain = 1.0
            scale = 1.0  # extra scale for gain

            if is_embedding:
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == args.vocab_size and shape[1] == args.n_embd:  # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if is_linear:
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == args.vocab_size and shape[1] == args.n_embd:  # final projection?
                    scale = 0.5

            # Read the override from the layer being initialized (not from a module
            # seen in an earlier iteration).
            if hasattr(mm, "scale_init"):
                scale = mm.scale_init

            if verbose:
                name = next((n for n, p in named_params if p is ww), "[unknown weight]")
                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {name}")

            gain *= scale
            if scale == -999:
                nn.init.eye_(ww)
            elif gain == 0:  # zero init is great for some RWKV matrices
                nn.init.zeros_(ww)
            elif gain > 0:
                nn.init.orthogonal_(ww, gain=gain)
            else:
                nn.init.normal_(ww, mean=0.0, std=-scale)