c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
x = self.head(x) + c
```
Note: when a token occurs multiple times in the context, it might be better to take max(prob) over its occurrences instead of sum(prob).
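To illustrate the difference, here is a small self-contained sketch (toy sizes and weights are made up for illustration): the `one_hot` matmul from the snippet above sums the copy weights of repeated tokens, while a `scatter_reduce` with `amax` keeps only the largest weight per token.

```python
import torch
import torch.nn.functional as F

# Toy setup (hypothetical): T = 5 context tokens, vocab size V = 8.
T, V = 5, 8
idx = torch.tensor([3, 1, 3, 2, 1])            # context token ids; tokens 3 and 1 repeat
c = torch.tensor([0.1, 0.2, 0.3, 0.15, 0.25])  # copy weight for each context position

# sum(prob): the one_hot matmul adds weights of repeated tokens
p_sum = c @ F.one_hot(idx, num_classes=V).float()

# max(prob): keep only the largest weight among a token's occurrences
p_max = torch.zeros(V).scatter_reduce(0, idx, c, reduce="amax", include_self=False)

print(p_sum[3])  # 0.1 + 0.3 = 0.4
print(p_max[3])  # max(0.1, 0.3) = 0.3
```

With sum, a token repeated many times can dominate the copy distribution even if each occurrence has a small weight; max avoids that over-counting.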
# The top-a sampling method