diff --git a/hf-inference-example.py b/hf-inference-example.py index 5ce2558..faacf47 100644 --- a/hf-inference-example.py +++ b/hf-inference-example.py @@ -1,5 +1,9 @@ import llamahf +# To save memory, use bfloat16 on CPU: +# import torch +# torch.set_default_dtype(torch.bfloat16) + MODEL = 'decapoda-research/llama-7b-hf' # MODEL = 'decapoda-research/llama-13b-hf' # MODEL = 'decapoda-research/llama-30b-hf'