diff --git a/README.md b/README.md
index 0079e50..08c7bed 100644
--- a/README.md
+++ b/README.md
@@ -233,6 +233,7 @@
 c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
 c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
 x = self.head(x) + c
 ```
+Note: when a token occurs multiple times in the context, it might be better to use max(prob) instead of sum(prob).
 
 # The top-a sampling method