You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
1.5 KiB
Python
51 lines
1.5 KiB
Python
import random
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
def top_k_logits(logits, k):
|
|
v, ix = torch.topk(logits, k)
|
|
out = logits.clone()
|
|
out[out < v[:, [-1]]] = -float('Inf')
|
|
return out
|
|
|
|
def top_p_probs(probs, p):
|
|
out = probs.clone()
|
|
|
|
sorted_probs, sorted_indices = torch.sort(out, descending=True)
|
|
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
|
sorted_indices_to_remove = cumulative_probs > p
|
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
|
sorted_indices_to_remove[..., 0] = 0
|
|
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
|
out[indices_to_remove] = 0
|
|
|
|
return out
|
|
|
|
# top-p + top-k + pow&ratio sampling
|
|
def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None):
|
|
logits = logits[:, pos, :] / temperature
|
|
probs = F.softmax(logits, dim=-1)
|
|
|
|
if min_p_ratio is not None:
|
|
limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
|
|
logits[probs < limit] = -float('Inf')
|
|
|
|
if top_k is not None:
|
|
logits = top_k_logits(logits, top_k)
|
|
|
|
probs = F.softmax(logits, dim=-1)
|
|
|
|
if top_p is not None:
|
|
probs[0] = top_p_probs(probs[0], top_p)
|
|
|
|
ix = torch.multinomial(probs, num_samples=1)
|
|
return ix[0][0].cpu()
|
|
|
|
def set_seed(seed):
|
|
random.seed(seed)
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
torch.cuda.manual_seed_all(seed)
|