@ -53,6 +53,7 @@ def main(
max_batch_size: int = 32,
):
# torch.manual_seed(1)
# torch.set_default_dtype(torch.bfloat16)
generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)