By default, torch uses Float32 precision while running on CPU, which leads, for example, to use 44 GB of RAM for 7B model. We may use Bfloat16 precision on CPU too, which decreases RAM consumption/2, down to 22 GB for 7B model, but inference processing much slower.
By default, torch uses Float32 precision while running on CPU, which leads, for example, to use 44 GB of RAM for 7B model. We may use Bfloat16 precision on CPU too, which decreases RAM consumption/2, down to 22 GB for 7B model, but inference processing much slower.
Uncomment this line in the example-cpu.py or example-chat.py to enable Bfloat16 and save memory.
An optimized checkpoints loader breaks compatibility with Bfloat16, so I decided to add example-bfloat16.py runner.
To use Bfloat16 precision, first you need to unshard checkpoints to a single one.
This will create merged.pth file in the root folder of this repo. Place this file and corresponding params.json of model into [/model] folder. Now you are ready to go.