diff --git a/llama/model.py b/llama/model.py index c1c6ddd..6e7c210 100755 --- a/llama/model.py +++ b/llama/model.py @@ -70,7 +70,6 @@ class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - # self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size() self.n_local_heads = args.n_heads self.head_dim = args.dim // args.n_heads