@ -70,7 +70,6 @@ class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
# self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads
self.head_dim = args.dim // args.n_heads