rllama/src/tokenizer.rs

use crate::protomodels::sentencepiece_model::model_proto::sentence_piece;
use crate::protomodels::sentencepiece_model::ModelProto;
use protobuf::Message;
use std::collections::BTreeMap;
use std::io::Read;
use std::path::Path;
use thiserror::Error;

pub type TokenId = i32;

#[derive(Clone, Debug)]
pub struct Tokenizer {
    pieces: BTreeMap<String, Piece>,
}

#[derive(Clone, Debug, Copy, Eq, Ord, PartialEq, PartialOrd)]
pub enum PieceType {
    Normal,
    Unknown,
    Control,
    UserDefined,
    Byte,
    Unused,
}

#[derive(Clone, Debug)]
pub struct Piece {
    _tp: PieceType,
    // piece: String is omitted here; the piece text itself is the key of the
    // BTreeMap that holds these entries.
    _score: f32,
    idx: usize,
}

#[derive(Error, Debug)]
pub enum TokenizerError {
    #[error("IO error")]
    IoError(#[from] std::io::Error),
    #[error("Protobuf error")]
    ProtobufError(#[from] protobuf::Error),
    #[error("Unknown piece type")]
    UnknownPieceType(String),
}

impl Tokenizer {
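    /// Loads a tokenizer from a SentencePiece `tokenizer.model` protobuf file.
    ///
    /// A hedged usage sketch; the path below is only a placeholder:
    ///
    /// ```ignore
    /// let tok = Tokenizer::load("path/to/tokenizer.model")?;
    /// let ids = tok.tokenize_to_ids("Hello world");
    /// let pieces: Vec<&str> = ids.iter().map(|&id| tok.id_to_str(id)).collect();
    /// ```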
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Tokenizer, TokenizerError> {
        let mut fs = std::fs::File::open(path)?;
        let mut buffer = Vec::new();
        fs.read_to_end(&mut buffer)?;
        std::mem::drop(fs);
        let model = ModelProto::parse_from_bytes(&buffer)?;
        let mut pieces = BTreeMap::new();
        // A piece's token id is simply its index in the model's piece list.
        for (idx, piece) in model.pieces.iter().enumerate() {
            let piece_str = piece.piece.clone();
            if piece_str.is_none() {
                continue;
            }
            let piece_str = piece_str.unwrap();
            let piece_type = match piece.type_ {
                None => sentence_piece::Type::NORMAL,
                Some(v) => match v.enum_value() {
                    Err(_) => return Err(TokenizerError::UnknownPieceType(piece_str)),
                    Ok(v) => v,
                },
            };
            let score = piece.score.unwrap_or(0.0);
            let tp = if piece_type == sentence_piece::Type::NORMAL {
                PieceType::Normal
            } else if piece_type == sentence_piece::Type::UNKNOWN {
                PieceType::Unknown
            } else if piece_type == sentence_piece::Type::CONTROL {
                PieceType::Control
            } else if piece_type == sentence_piece::Type::USER_DEFINED {
                PieceType::UserDefined
            } else if piece_type == sentence_piece::Type::BYTE {
                PieceType::Byte
            } else if piece_type == sentence_piece::Type::UNUSED {
                PieceType::Unused
            } else {
                return Err(TokenizerError::UnknownPieceType(piece_str));
            };
            pieces.insert(
                piece_str,
                Piece {
                    _tp: tp,
                    _score: score,
                    idx,
                },
            );
        }
        Ok(Tokenizer { pieces })
    }

    // Gives a string for a token id.
    // Panics if the id is out of range.
    pub fn id_to_str(&self, id: i32) -> &str {
        let id = id as usize;
        for (piece_str, piece_info) in self.pieces.iter() {
            if piece_info.idx == id {
                return piece_str;
            }
        }
        panic!("id out of range");
    }

    // Converts a string to a Vec<&str> of pieces.
    // You may want to use tokenize_to_ids instead.
    //
    // This will not add start or end tokens; only the string itself is processed.
    //
    // The LLaMA code adds an extra space character at the beginning of the string;
    // this function does not do that either.
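    //
    // Illustration (the actual split depends entirely on the loaded vocabulary):
    // with pieces "▁", "▁he" and "llo" present, "▁hello" would come out as
    // ["▁he", "llo"], because matching always prefers the longest prefix.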
    pub fn tokenize_to_pieces<S: AsRef<str>>(&self, s: S) -> Vec<&str> {
        let mut s: &str = s.as_ref();
        let mut result: Vec<&str> = Vec::new();
        // Very naive matching: scan the whole piece table and keep the longest
        // piece that is a prefix of the remaining input.
        while !s.is_empty() {
            let mut best_candidate: &str = "";
            let mut best_candidate_len: usize = 0;
            let mut skip_s: &str = "";
            for (piece_str, _piece_info) in self.pieces.iter() {
                if s.starts_with(piece_str) && best_candidate_len < piece_str.len() {
                    best_candidate = piece_str;
                    best_candidate_len = piece_str.len();
                    skip_s = &s[piece_str.len()..];
                }
            }
            if best_candidate_len == 0 {
                // No piece matched; skip one full character (not one byte) so we
                // never cut inside a UTF-8 sequence and drop the rest of the input.
                let mut chars = s.chars();
                chars.next();
                s = chars.as_str();
            } else {
                result.push(best_candidate);
                s = skip_s;
            }
        }
        result
    }

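    /// Tokenizes a string to token ids: prefixes the SentencePiece whitespace
    /// marker "▁", replaces spaces with it, and prepends the start token (id 1).
    ///
    /// A hedged usage sketch; the path is only a placeholder and the resulting
    /// ids depend entirely on the model file that was loaded:
    ///
    /// ```ignore
    /// let tok = Tokenizer::load("path/to/tokenizer.model")?;
    /// let ids = tok.tokenize_to_ids("Hello world");
    /// assert_eq!(ids[0], 1); // start token always comes first
    /// ```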
    pub fn tokenize_to_ids<S: AsRef<str>>(&self, s: S) -> Vec<TokenId> {
        let mut s: String = format!("▁{}", s.as_ref());
        // Replace all space characters with the SentencePiece whitespace marker.
        s = s.replace(' ', "▁");
        let pieces = self.tokenize_to_pieces(s);
        let mut result = Vec::new();
        result.push(1); // start token
        for piece in pieces {
            // Every piece returned by tokenize_to_pieces is a key of self.pieces,
            // so this lookup cannot fail.
            let piece_info = self.pieces.get(piece).unwrap();
            result.push(piece_info.idx as TokenId);
        }
        result
    }
}
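
// A minimal sketch of the tokenizer's behavior using a hand-built toy vocabulary
// instead of a real `tokenizer.model` file. The pieces and ids below are made up
// purely for illustration; real vocabularies and ids come from the loaded
// SentencePiece model.
#[cfg(test)]
mod tests {
    use super::*;

    fn toy_tokenizer() -> Tokenizer {
        let mut pieces = BTreeMap::new();
        // A piece's token id is its position in this list.
        for (idx, piece) in ["<s>", "▁", "he", "llo", "▁he"].iter().enumerate() {
            pieces.insert(
                piece.to_string(),
                Piece {
                    _tp: PieceType::Normal,
                    _score: 0.0,
                    idx,
                },
            );
        }
        Tokenizer { pieces }
    }

    #[test]
    fn greedy_longest_prefix_match() {
        let tok = toy_tokenizer();
        // "▁he" wins over "▁" because matching always prefers the longest prefix.
        assert_eq!(tok.tokenize_to_pieces("▁hello"), vec!["▁he", "llo"]);
        // Characters with no matching piece are skipped.
        assert_eq!(tok.tokenize_to_pieces("x"), Vec::<&str>::new());
    }

    #[test]
    fn ids_start_with_hardcoded_start_token() {
        let tok = toy_tokenizer();
        // tokenize_to_ids prefixes "▁" and prepends the hard-coded start token id 1.
        assert_eq!(tok.tokenize_to_ids("hello"), vec![1, 4, 3]);
        assert_eq!(tok.id_to_str(4), "▁he");
    }
}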