RWKV-LM/src/utils.py


import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

def top_k_logits(logits, k):
    # Keep only the k largest logits per row; everything below the k-th
    # largest value is masked to -inf, so a later softmax assigns it
    # zero probability.
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out
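
# Illustration (added, not in the original file): with k=2,
#   top_k_logits(torch.tensor([[1.0, 3.0, 2.0, 0.5]]), 2)
# returns [[-inf, 3.0, 2.0, -inf]], so a subsequent softmax puts all
# probability mass on the two largest logits.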

def top_p_probs(probs, p):
    # Nucleus (top-p) filtering on a 1-D probability vector: zero out the
    # sorted tail once cumulative mass exceeds p, always keeping at least
    # the single most likely token.
    out = probs.clone()
    sorted_probs, sorted_indices = torch.sort(out, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask right so the token that crosses the threshold survives.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    out[indices_to_remove] = 0
    return out
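
# Illustration (added, not in the original file): with p=0.8,
#   top_p_probs(torch.tensor([0.5, 0.3, 0.15, 0.05]), 0.8)
# returns [0.5, 0.3, 0.15, 0.0]: the cumulative mass [0.5, 0.8, 0.95, 1.0]
# crosses 0.8 at the third token, and the shifted mask keeps that token
# while zeroing the tail. The function expects a 1-D vector, which is why
# sample_logits calls it on probs[0].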

# top-p + top-k + pow & ratio sampling
def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None):
    # Slice out the logits at position `pos` and apply temperature scaling.
    logits = logits[:, pos, :] / temperature
    probs = F.softmax(logits, dim=-1)

    # pow & ratio filter: drop tokens whose probability falls below
    # max(probs) ** min_p_pow * min_p_ratio (the two must be given together).
    if min_p_ratio is not None:
        limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
        logits[probs < limit] = -float('Inf')

    if top_k is not None:
        logits = top_k_logits(logits, top_k)

    probs = F.softmax(logits, dim=-1)

    if top_p is not None:
        probs[0] = top_p_probs(probs[0], top_p)

    # torch.multinomial renormalizes, so zeroed entries are never drawn.
    ix = torch.multinomial(probs, num_samples=1)
    return ix[0][0].cpu()
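
# Usage sketch (added; `model` and `x` are hypothetical stand-ins): given
# logits of shape (batch, seq_len, vocab_size) from a forward pass, sample
# the token at the final position:
#   logits = model(x)
#   token = sample_logits(logits, pos=-1, temperature=1.0, top_k=40, top_p=0.9)
# Note that min_p_pow and min_p_ratio must be supplied together; passing
# only min_p_ratio would make torch.pow fail on a None exponent.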

def set_seed(seed):
    # Seed every RNG in play (Python, NumPy, Torch CPU, and all CUDA
    # devices) so runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
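
# Usage sketch (added, not in the original file): call once at startup,
# before model construction, so initialization and sampling are repeatable:
#   set_seed(42)
# Full CUDA determinism may additionally require settings such as
# torch.backends.cudnn.deterministic = True.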