|
|
|
|
@ -4,7 +4,7 @@ use crate::tokenizer::{TokenId, Tokenizer};
|
|
|
|
|
use crate::transformer::Transformer;
|
|
|
|
|
use crate::unpickler;
|
|
|
|
|
use clap::Parser;
|
|
|
|
|
use std::io::Read;
|
|
|
|
|
use std::io::{Read, Write};
|
|
|
|
|
|
|
|
|
|
#[derive(Parser)]
|
|
|
|
|
#[command(author, version, about, long_about = None)]
|
|
|
|
|
@ -107,16 +107,21 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
|
|
|
|
|
let highest_pred_idx = token_sampler.sample(&preds);
|
|
|
|
|
toks_id.push(highest_pred_idx as TokenId);
|
|
|
|
|
prev_pos = toks_id.len() - 1;
|
|
|
|
|
|
|
|
|
|
let mut tok_str: String = "".to_string();
|
|
|
|
|
for tok_id in toks_id.iter() {
|
|
|
|
|
for tok_id in toks_id[prev_pos + 1..].iter() {
|
|
|
|
|
if *tok_id == 1 {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
let tok = tok.id_to_str(*tok_id);
|
|
|
|
|
tok_str += tok.replace('▁', " ").as_str();
|
|
|
|
|
if tok == "<0x0A>" {
|
|
|
|
|
tok_str += "\n";
|
|
|
|
|
} else {
|
|
|
|
|
tok_str += tok.replace('▁', " ").as_str();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
println!("{}", tok_str);
|
|
|
|
|
print!("{}", tok_str);
|
|
|
|
|
let _ = std::io::stdout().flush();
|
|
|
|
|
prev_pos = toks_id.len() - 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|