@@ -68,7 +68,7 @@ We also propose a new sampling method called top-a (as in src/utils.py):
(2) Remove all entries whose probability is lower than 0.02 * pow(p_max, 2). So it's adaptive, hence "top-a".
-(3) Feel free to tune the 0.02 and 2 factor.
+(3) Feel free to tune the 0.02 and 2 factor. Tune 0.02 first.
The idea of top-a:
1. If max_prob=0.9, then remove all tokens with prob < 0.0162 (so, removing most alternatives)
@@ -78,7 +78,7 @@ The idea of top-a:
```
probs = F.softmax(logits, dim=-1)
-limit = torch.pow(torch.max(probs), 2.0) * 0.02
+limit = torch.pow(torch.max(probs), 2) * 0.02
logits[probs < limit] = -float('Inf')
```
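
For readers who want to run the snippet above on its own, here is a minimal, self-contained sketch of the top-a rule. It assumes a 1-D logits tensor for a single next-token distribution; the function name `top_a_filter` and the `ratio`/`power` argument names are illustrative and not taken from src/utils.py.

```
# Minimal sketch of top-a filtering (assumed standalone form, not the exact code in src/utils.py).
import torch
import torch.nn.functional as F

def top_a_filter(logits: torch.Tensor, ratio: float = 0.02, power: float = 2.0) -> torch.Tensor:
    """Set logits to -inf for tokens whose probability is below ratio * p_max ** power."""
    probs = F.softmax(logits, dim=-1)        # probability of each token
    p_max = torch.max(probs)                 # highest single-token probability
    limit = ratio * torch.pow(p_max, power)  # adaptive cutoff, e.g. p_max=0.9 -> 0.02 * 0.81 = 0.0162
    out = logits.clone()
    out[probs < limit] = -float('Inf')       # removed entries get zero probability after the next softmax
    return out
```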
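
A hypothetical usage example, with random logits standing in for a model's output over a small vocabulary; the token with p_max always survives the cutoff, so sampling is always possible:

```
# Hypothetical usage: filter the distribution, then sample the next token.
vocab_size = 100
logits = torch.randn(vocab_size)   # stand-in for model output
filtered = top_a_filter(logits)    # apply the adaptive cutoff
next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
print(next_token.item())
```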