diff --git a/example-cpu.py b/example-cpu.py index 7311308..3d915b9 100644 --- a/example-cpu.py +++ b/example-cpu.py @@ -53,6 +53,7 @@ def main( max_batch_size: int = 32, ): # torch.manual_seed(1) + # torch.set_default_dtype(torch.bfloat16) generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)