@@ -233,6 +233,7 @@ c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
```
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
x = self.head(x) + c
```
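Note: multiplying by the one-hot matrix sums the scores of every context position that holds the same token, so a token that occurs multiple times in the context gets a larger copy score just for being repeated. In that case it might be better to use max(prob) over the occurrences instead of sum(prob).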
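To make the difference concrete, here is a minimal pure-Python sketch for a single query position. It is only an illustration, not the code from this repository: the helper `copy_scores`, its `reduce` argument, and the toy numbers are all hypothetical, and it assumes the per-position scores are non-negative (as the name "prob" suggests).

```python
# Toy illustration (hypothetical helper, not the repository's code):
# aggregate per-position copy scores into per-token scores, either by
# summing over occurrences or by keeping only the strongest occurrence.

def copy_scores(idx, prob, vocab_size, reduce="sum"):
    """Map per-position scores onto vocabulary entries.

    idx[s]  : token id at context position s
    prob[s] : score the current position assigns to context position s
              (assumed non-negative, so 0.0 is a safe starting value)
    reduce  : "sum" mirrors `c @ one_hot(idx)`; "max" keeps the best
              single occurrence of each token
    """
    out = [0.0] * vocab_size
    for token, p in zip(idx, prob):
        if reduce == "sum":
            out[token] += p
        elif reduce == "max":
            out[token] = max(out[token], p)
        else:
            raise ValueError(f"unknown reduce mode: {reduce}")
    return out


idx = [2, 0, 2, 2, 1]                   # token 2 appears three times
prob = [0.25, 0.5, 0.25, 0.25, 0.125]   # score per context position

summed = copy_scores(idx, prob, vocab_size=4, reduce="sum")
maxed = copy_scores(idx, prob, vocab_size=4, reduce="max")

print(summed)  # [0.5, 0.125, 0.75, 0.0] -> token 2 wins through repetition
print(maxed)   # [0.5, 0.125, 0.25, 0.0] -> token 0 has the strongest match
```

In tensor form, the "sum" case is what the one-hot matmul computes. A "max" variant would need a scatter-style reduction instead, for example `torch.Tensor.scatter_reduce` with `reduce="amax"` in recent PyTorch versions. Scores that can be negative would also need a different starting value than 0.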
# The top-a sampling method