From de477314edf6ff863f007fc536fe6e98699570a1 Mon Sep 17 00:00:00 2001 From: Mikko Juola Date: Mon, 13 Mar 2023 22:33:46 -0700 Subject: [PATCH] Fix newlines not recognized when feeding newlines in the prompt. Tokenizer would misinterpret the newlines. In general, the non-printable control characters don't seem to be tokenized correctly at the moment. I added band-aid for newlines but should maybe fix the others too. --- src/tokenizer.rs | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 3642a83..c836c1e 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -105,6 +105,16 @@ impl Tokenizer { panic!("id out of range"); } + // Tries to find a token from dictionary. + pub fn str_to_id(&self, s: &str) -> Option { + for (piece_str, piece_info) in self.pieces.iter() { + if piece_str == s { + return Some(piece_info.idx as i32); + } + } + None + } + // Converts a string to a Vec<&str> // You may want to use tokenize_to_ids instead. // @@ -121,11 +131,23 @@ impl Tokenizer { let mut best_candidate: &str = ""; let mut best_candidate_len: usize = 0; let mut skip_s: &str = ""; - for (piece_str, _piece_info) in self.pieces.iter() { - if s.starts_with(piece_str) && best_candidate_len < piece_str.len() { - best_candidate = piece_str; - best_candidate_len = piece_str.len(); - skip_s = &s[piece_str.len()..]; + // Specially recognize newline. Otherwise it matches something we don't actually + // want. + if s.starts_with("\n") { + if self.str_to_id("<0x0A>").is_some() { + best_candidate = "<0x0A>"; + best_candidate_len = best_candidate.len(); + skip_s = &s[1..]; + } else { + best_candidate = "\\n"; + } + } else { + for (piece_str, _piece_info) in self.pieces.iter() { + if s.starts_with(piece_str) && best_candidate_len < piece_str.len() { + best_candidate = piece_str; + best_candidate_len = piece_str.len(); + skip_s = &s[piece_str.len()..]; + } } } if best_candidate_len == 0 {