@@ -70,7 +70,6 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()

-        # self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
         self.n_local_heads = args.n_heads
         self.head_dim = args.dim // args.n_heads

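For reference, here is a minimal sketch of what the head bookkeeping works out to once the fairscale model-parallel division is dropped. The hyperparameter values (dim=4096, n_heads=32) are illustrative LLaMA-7B-style assumptions, not taken from the diff; the point is only that every head stays local to the single process while head_dim is unchanged.

```python
# Minimal sketch with assumed, illustrative hyperparameters: once the
# model-parallel division is removed, n_local_heads is the full head count.
from dataclasses import dataclass

@dataclass
class ModelArgs:
    dim: int = 4096      # assumed LLaMA-7B-style hidden size
    n_heads: int = 32    # assumed LLaMA-7B-style head count

args = ModelArgs()
n_local_heads = args.n_heads          # previously: args.n_heads // model_parallel_world_size
head_dim = args.dim // args.n_heads   # 4096 // 32 = 128

print(n_local_heads, head_dim)        # -> 32 128
```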