Update merge-weights.py

main
randaller 3 years ago committed by GitHub
parent 485c8dbba3
commit 0238fb92ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -60,7 +60,6 @@ def write_model(input_base_path, model_size):
for layer_i in range(n_layers):
if model_size == "7B":
# Unsharded
state_dict |= {
f"layers.{layer_i}.attention.wq.weight": loaded[
f"layers.{layer_i}.attention.wq.weight"
@@ -89,7 +88,6 @@ def write_model(input_base_path, model_size):
f"layers.{layer_i}.ffn_norm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
}
else:
# Sharded
state_dict |= {
f"layers.{layer_i}.attention_norm.weight": loaded[0][
f"layers.{layer_i}.attention_norm.weight"
@@ -131,7 +129,6 @@ def write_model(input_base_path, model_size):
)
if model_size == "7B":
# Unsharded
state_dict |= {
"tok_embeddings.weight": loaded["tok_embeddings.weight"],
"norm.weight": loaded["norm.weight"],

Loading…
Cancel
Save