From b9be485610edfbb3bfc97633bfa72c436f7a5516 Mon Sep 17 00:00:00 2001 From: Mikko Juola Date: Mon, 20 Mar 2023 18:24:02 -0700 Subject: [PATCH] Add simple HTTP API support. It took annoyingly a lot of effort just to make this simple server. I tried rouille web framework first, but it didn't support getting chunked output to the client line-by-line. (seems that if it exposed more details about the underlying tiny-http package I could have hacked it to work). I went with Rocket because it had less async stuff and seemed decent. I got weird issues where it seemed as if memory use kept increasing and increasing. I may have got that fixed but I couldn't figure out what made it use so much memory, even tools like valgrind and heaptrack told me there isn't that much memory allocated but I can see RES increasing in `htop`. Switched to MiMalloc as it seems to slightly decrease memory use. Added details about the inference server to README.md. And also added an example Python script of it. I want to use this feature to later investigate how much do quantizations or f16/f32 affect output. Easier to do such things on Python. --- Cargo.lock | 661 ++++++++++++++++++++++++++++++++++-- Cargo.toml | 5 + README.md | 86 +++++ examples/api_hello_world.py | 25 ++ src/lib.rs | 5 + src/main.rs | 5 + src/rllama_main.rs | 513 ++++++++++++++++++++++++++-- src/tensor.rs | 17 +- src/token_sampler.rs | 15 + src/transformer.rs | 57 +++- 10 files changed, 1335 insertions(+), 54 deletions(-) create mode 100755 examples/api_hello_world.py diff --git a/Cargo.lock b/Cargo.lock index 2eb5405..a15feda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,60 @@ # It is not intended for manual editing. 
version = 3 +[[package]] +name = "aead" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fc95d1bdb8e6666b2b217308eeeb09f2d6728d104be3e31916cc74d15420331" +dependencies = [ + "generic-array", +] + +[[package]] +name = "aes" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "884391ef1066acaa41e766ba8f596341b96e93ce34f9a43e7d24bf0a0eaf0561" +dependencies = [ + "aes-soft", + "aesni", + "cipher", +] + +[[package]] +name = "aes-gcm" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5278b5fabbb9bd46e24aa69b2fdea62c99088e0a950a9be40e3e0101298f88da" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + +[[package]] +name = "aes-soft" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be14c7498ea50828a38d0e24a765ed2effe92a705885b57d029cd67d45744072" +dependencies = [ + "cipher", + "opaque-debug", +] + +[[package]] +name = "aesni" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea2e11f5e94c2f7d386164cc2aa1f97823fed6f259e486940a71c174dd01b0ce" +dependencies = [ + "cipher", + "opaque-debug", +] + [[package]] name = "aho-corasick" version = "0.7.20" @@ -49,18 +103,49 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "base64" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "489d6c0ed21b11d038c31b6ceccca973e65d73ba3bd8ecb9a2babf5546164643" +dependencies = [ + "byteorder", + "safemem", +] + +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "bitflags" version = "1.3.2" 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "block-buffer" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + [[package]] name = "cast" version = "0.3.0" @@ -106,6 +191,15 @@ dependencies = [ "half 1.8.2", ] +[[package]] +name = "cipher" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12f8e7987cbd042a63249497f41aed09f8e65add917ea6566effbc56578d6801" +dependencies = [ + "generic-array", +] + [[package]] name = "cl-sys" version = "0.4.2" @@ -150,8 +244,8 @@ checksum = "fddf67631444a3a3e3e5ac51c36a5e01335302de677bd78759eaa90ab1f46644" dependencies = [ "heck", "proc-macro-error", - "proc-macro2", - "quote", + "proc-macro2 1.0.52", + "quote 1.0.26", "syn 1.0.109", ] @@ -197,6 +291,37 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "cookie" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be2018768ed1d848cc4d347d551546474025ba820e5db70e4c9aaa349f678bd7" +dependencies = [ + "aes-gcm", + "base64 0.13.1", + "hkdf", + "hmac", + "percent-encoding 2.2.0", + "rand", + "sha2", + "time", +] + +[[package]] +name = "cpufeatures" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320" +dependencies = [ + 
"libc", +] + +[[package]] +name = "cpuid-bool" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcb25d077389e53838a8158c8e99174c5a9d902dee4904320db714f3c653ffba" + [[package]] name = "criterion" version = "0.4.0" @@ -306,6 +431,66 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "crypto-mac" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bff07008ec701e8028e2ceb8f83f0e4274ee62bd2dbdc4fefff2e9a91824081a" +dependencies = [ + "generic-array", + "subtle", +] + +[[package]] +name = "ctr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb4a30d54f7443bf3d6191dcd486aca19e67cb3c49fa7a06a319966346707e7f" +dependencies = [ + "cipher", +] + +[[package]] +name = "devise" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd716c4a507adc5a2aa7c2a372d06c7497727e0892b243d3036bc7478a13e526" +dependencies = [ + "devise_codegen", + "devise_core", +] + +[[package]] +name = "devise_codegen" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea7b8290d118127c08e3669da20b331bed56b09f20be5945b7da6c116d8fab53" +dependencies = [ + "devise_core", + "quote 0.6.13", +] + +[[package]] +name = "devise_core" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1053e9d5d5aade9bcedb5ab53b78df2b56ff9408a3138ce77eaaef87f932373" +dependencies = [ + "bitflags", + "proc-macro2 0.4.30", + "quote 0.6.13", + "syn 0.15.44", +] + +[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array", +] + [[package]] name = "either" 
version = "1.8.1" @@ -387,6 +572,16 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d758ba1b47b00caf47f24925c0074ecb20d6dfcffe7f6d53395c0465674841a" +[[package]] +name = "generic-array" +version = "0.14.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9" +dependencies = [ + "typenum", + "version_check 0.9.4", +] + [[package]] name = "getrandom" version = "0.2.8" @@ -395,9 +590,25 @@ checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", ] +[[package]] +name = "ghash" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97304e4cd182c3846f7575ced3890c53012ce534ad9114046b0a9e00bb30a375" +dependencies = [ + "opaque-debug", + "polyval", +] + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "half" version = "1.8.2" @@ -449,6 +660,62 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +[[package]] +name = "hkdf" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51ab2f639c231793c5f6114bdb9bbe50a7dbbfcd7c7c6bd8475dec2d991e964f" +dependencies = [ + "digest", + "hmac", +] + +[[package]] +name = "hmac" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1441c6b1e930e2817404b5046f1f989899143a12bf92de603b69f4e0aee1e15" +dependencies = [ + "crypto-mac", + "digest", +] + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + +[[package]] +name = "hyper" +version = "0.10.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a0652d9a2609a968c14be1a9ea00bf4b1d64e2e1f53a1b51b6fff3a6e829273" +dependencies = [ + "base64 0.9.3", + "httparse", + "language-tags", + "log 0.3.9", + "mime", + "num_cpus", + "time", + "traitobject", + "typeable", + "unicase", + "url", +] + +[[package]] +name = "idna" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f09e0f0b1fb55fdee1f17470ad800da77af5186a1a76c026b679358b7e844e" +dependencies = [ + "matches", + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "indexmap" version = "1.9.2" @@ -527,6 +794,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "language-tags" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a91d884b6667cd606bb5a69aa0c99ba811a115fc68915e7056ec08a46e93199a" + [[package]] name = "lazy_static" version = "1.4.0" @@ -539,12 +812,31 @@ version = "0.2.140" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" +[[package]] +name = "libmimalloc-sys" +version = "0.1.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8c7cbf8b89019683667e347572e6d55a7df7ea36b0c4ce69961b0cde67b174" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "linux-raw-sys" version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" +[[package]] +name = "log" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e19e8d5c34a3e0e2223db8e060f9e8264aeeb5c5fc64a4ee9965c062211c024b" +dependencies = [ + "log 0.4.17", +] + [[package]] name = "log" version = "0.4.17" @@ -554,6 +846,12 @@ 
dependencies = [ "cfg-if", ] +[[package]] +name = "matches" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" + [[package]] name = "memchr" version = "2.5.0" @@ -569,6 +867,24 @@ dependencies = [ "autocfg", ] +[[package]] +name = "mimalloc" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dcb174b18635f7561a0c6c9fc2ce57218ac7523cf72c50af80e2d79ab8f3ba1" +dependencies = [ + "libmimalloc-sys", +] + +[[package]] +name = "mime" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba626b8a6de5da682e1caa06bdb42a335aee5a84db8e5046a3e8ab17ba0a3ae0" +dependencies = [ + "log 0.3.9", +] + [[package]] name = "nodrop" version = "0.1.14" @@ -669,12 +985,52 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "os_str_bytes" version = "6.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" +[[package]] +name = "pear" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32dfa7458144c6af7f9ce6a137ef975466aa68ffa44d4d816ee5934018ba960a" +dependencies = [ + "pear_codegen", +] + +[[package]] +name = "pear_codegen" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0288ba5d581afbc93e2bbd931c1013584c15ecf46b1cdb927edc7abddbc8ca6" +dependencies = [ + "proc-macro2 0.4.30", + "quote 0.6.13", + "syn 0.15.44", + "version_check 0.9.4", + "yansi", +] + +[[package]] +name = 
"percent-encoding" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31010dd2e1ac33d5b46a5b413495239882813e0369f8ed8a5e266f173602f831" + +[[package]] +name = "percent-encoding" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" + [[package]] name = "plotters" version = "0.3.4" @@ -703,6 +1059,17 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "polyval" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eebcc4aa140b9abd2bc40d9c3f7ccec842679cd79045ac3a7ac698c1a064b7cd" +dependencies = [ + "cpuid-bool", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "0.3.19" @@ -722,10 +1089,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" dependencies = [ "proc-macro-error-attr", - "proc-macro2", - "quote", + "proc-macro2 1.0.52", + "quote 1.0.26", "syn 1.0.109", - "version_check", + "version_check 0.9.4", ] [[package]] @@ -734,9 +1101,18 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" dependencies = [ - "proc-macro2", - "quote", - "version_check", + "proc-macro2 1.0.52", + "quote 1.0.26", + "version_check 0.9.4", +] + +[[package]] +name = "proc-macro2" +version = "0.4.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759" +dependencies = [ + "unicode-xid", ] [[package]] @@ -782,7 +1158,7 @@ checksum = "9d39b14605eaa1f6a340aec7f320b34064feb26c93aec35d6a9a2272a8ddfa49" dependencies = [ "anyhow", "indexmap", - "log", + "log 0.4.17", "protobuf", "protobuf-support", "tempfile", @@ -799,13 +1175,22 @@ dependencies = [ 
"thiserror", ] +[[package]] +name = "quote" +version = "0.6.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce23b6b870e8f94f81fb0a363d65d86675884b34a09043c81e5562f11c1f8e1" +dependencies = [ + "proc-macro2 0.4.30", +] + [[package]] name = "quote" version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ - "proc-macro2", + "proc-macro2 1.0.52", ] [[package]] @@ -907,6 +1292,8 @@ dependencies = [ "embedded-profiling", "half 2.2.1", "indicatif", + "lazy_static", + "mimalloc", "num-complex", "ocl", "protobuf", @@ -914,11 +1301,65 @@ dependencies = [ "protobuf-parse", "rand", "rayon", + "rocket", "serde", "serde_json", "thiserror", ] +[[package]] +name = "rocket" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83b9d9dc08c5dcc1d8126a9dd615545e6a358f8c13c883c8dfed8c0376fa355e" +dependencies = [ + "atty", + "base64 0.13.1", + "log 0.4.17", + "memchr", + "num_cpus", + "pear", + "rocket_codegen", + "rocket_http", + "state", + "time", + "toml", + "version_check 0.9.4", + "yansi", +] + +[[package]] +name = "rocket_codegen" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2810037b5820098af97bd4fdd309e76a8101ceb178147de775c835a2537284fe" +dependencies = [ + "devise", + "glob", + "indexmap", + "quote 0.6.13", + "rocket_http", + "version_check 0.9.4", + "yansi", +] + +[[package]] +name = "rocket_http" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf9cbd128e1f321a2d0bebd2b7cf0aafd89ca43edf69e49b56a5c46e48eb19f" +dependencies = [ + "cookie", + "hyper", + "indexmap", + "pear", + "percent-encoding 1.0.1", + "smallvec", + "state", + "time", + "unicode-xid", +] + [[package]] name = "rustc_version" version = "0.4.0" @@ -948,6 +1389,12 @@ version = "1.0.13" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +[[package]] +name = "safemem" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef703b7cb59335eae2eb93ceb664c0eb7ea6bf567079d843e09420219668e072" + [[package]] name = "same-file" version = "1.0.6" @@ -984,8 +1431,8 @@ version = "1.0.157" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78997f4555c22a7971214540c4a661291970619afd56de19f77e0de86296e1e5" dependencies = [ - "proc-macro2", - "quote", + "proc-macro2 1.0.52", + "quote 1.0.26", "syn 2.0.0", ] @@ -1000,20 +1447,62 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" +dependencies = [ + "block-buffer", + "cfg-if", + "cpufeatures", + "digest", + "opaque-debug", +] + +[[package]] +name = "smallvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" + +[[package]] +name = "state" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3015a7d0a5fd5105c91c3710d42f9ccf0abfb287d62206484dcc67f9569a6483" + [[package]] name = "strsim" version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "subtle" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" + +[[package]] +name = "syn" +version = "0.15.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ca4b3b69a77cbe1ffc9e198781b7acb0c7365a883670e8f1c1bc66fba79a5c5" +dependencies = 
[ + "proc-macro2 0.4.30", + "quote 0.6.13", + "unicode-xid", +] + [[package]] name = "syn" version = "1.0.109" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ - "proc-macro2", - "quote", + "proc-macro2 1.0.52", + "quote 1.0.26", "unicode-ident", ] @@ -1023,8 +1512,8 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4cff13bb1732bccfe3b246f3fdb09edfd51c01d6f5299b7ccd9457c2e4e37774" dependencies = [ - "proc-macro2", - "quote", + "proc-macro2 1.0.52", + "quote 1.0.26", "unicode-ident", ] @@ -1071,11 +1560,22 @@ version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ - "proc-macro2", - "quote", + "proc-macro2 1.0.52", + "quote 1.0.26", "syn 2.0.0", ] +[[package]] +name = "time" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" +dependencies = [ + "libc", + "wasi 0.10.0+wasi-snapshot-preview1", + "winapi", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -1086,18 +1586,117 @@ dependencies = [ "serde_json", ] +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "toml" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "758664fc71a3a69038656bee8b6be6477d2a6c315a6b81f7081f591bffa4111f" +dependencies = [ + "serde", +] + 
+[[package]] +name = "traitobject" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efd1f82c56340fdf16f2a953d7bda4f8fdffba13d93b00844c25572110b26079" + +[[package]] +name = "typeable" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1410f6f91f21d1612654e7cc69193b0334f909dcf2c790c4826254fbb86f8887" + +[[package]] +name = "typenum" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" + +[[package]] +name = "unicase" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4765f83163b74f957c797ad9253caf97f103fb064d3999aea9568d09fc8a33" +dependencies = [ + "version_check 0.1.5", +] + +[[package]] +name = "unicode-bidi" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d502c968c6a838ead8e69b2ee18ec708802f99db92a0d156705ec9ef801993b" + [[package]] name = "unicode-ident" version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" +[[package]] +name = "unicode-normalization" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-width" version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +[[package]] +name = "unicode-xid" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc" + +[[package]] +name = "universal-hash" +version = "0.4.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f214e8f697e925001e66ec2c6e37a4ef93f0f78c2eed7814394e10c62025b05" +dependencies = [ + "generic-array", + "subtle", +] + +[[package]] +name = "url" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd4e7c0d531266369519a4aa4f399d748bd37043b00bde1e4ff1f60a120b355a" +dependencies = [ + "idna", + "matches", + "percent-encoding 1.0.1", +] + +[[package]] +name = "version_check" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd" + [[package]] name = "version_check" version = "0.9.4" @@ -1114,6 +1713,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "wasi" +version = "0.10.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1137,10 +1742,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" dependencies = [ "bumpalo", - "log", + "log 0.4.17", "once_cell", - "proc-macro2", - "quote", + "proc-macro2 1.0.52", + "quote 1.0.26", "syn 1.0.109", "wasm-bindgen-shared", ] @@ -1151,7 +1756,7 @@ version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" dependencies = [ - "quote", + "quote 1.0.26", "wasm-bindgen-macro-support", ] @@ -1161,8 +1766,8 @@ version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ - "proc-macro2", - "quote", + "proc-macro2 1.0.52", + "quote 1.0.26", "syn 1.0.109", "wasm-bindgen-backend", "wasm-bindgen-shared", @@ 
-1306,3 +1911,9 @@ name = "windows_x86_64_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "yansi" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" diff --git a/Cargo.toml b/Cargo.toml index ee5c44d..3074cfa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,10 +32,14 @@ indicatif = "0.17" colored = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" +mimalloc = "0.1" ocl = { version = "0.19", optional = true } +rocket = { version = "0.4", features = ["sse"], optional = true } +lazy_static = "1.4" [features] opencl = ["ocl"] +server = ["rocket"] # We need protobuf compiler [build-dependencies] @@ -46,6 +50,7 @@ protobuf-parse = "3.2" criterion = "0.4" [profile.release] +panic = 'abort' debug = true [[bench]] diff --git a/README.md b/README.md index fb94f5e..99ff638 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,92 @@ rllama --tokenizer-model /path/to/tokenizer.model \ Use `rllama --help` to see all the options. +## Inference server + +`rllama` can run in an inference server mode with a simple HTTP JSON API. + +The command line flags for this are: + + * `--inference-server` using this will turn on the inference server. + * `--inference-server-port` sets the port. Default port is 8080. + * `--inference-server-host` sets the host. The default host is 127.0.0.1. + * `--inference-server-max-concurrent-inferences` sets how many concurrent + requests are allowed to be actively doing inference at the same time. The + default is 5. + * `--inference-server-api-path` sets which path servers the API requests. The + default path is `/rllama/v1/inference` + * `--inference-server-prompt-cache-size` sets how many previous prompt + calculations should be cached. Default is 1000. 
This speeds up token + generation for prompts that were already requested before. + +Prompts and flags related to token sampling are all ignored in inference server +mode. Instead, they are obtained from each HTTP JSON API request. + +### Inference server API + +There is an `examples/api_hello_world.py` for a minimal API use example. + +``` +POST /rllama/v1/inference +``` + +Expects a JSON body and `Accept: application/json` or `Accept: text/jsonl`. + +The expected JSON is as follows: + +```json + { + "temperature": + "top_k": + "top_p": + "repetition_penalty": + "stop_at_end_token": + "max_seq_len": + "max_new_tokens": + "no_token_sampling": + "prompt": + } +``` + +The form of the response depends on whether `no_token_sampling` is set to true or false. The +response is in JSONL, i.e. multiple JSON dictionaries, separated by newlines. + +`no_token_sampling` can turn off `rllama`'s own token sampling. In this case, +the probabilities for every token are returned instead. + +When no\_token\_sampling = false: + +```json +{: {"p": , "is_end_token": bool, might not be present}} +``` + + * `token` contains the new token to be appended to output. It does not + include the string you fed to the system originally. + * `p` is the probability that this token was chosen. For example, if this + value is 0.1, it means that this particular token had 10% chance of being + selected with the current token sampling settings. + * `is_end_token` is `true` if the given token signifies end of output. This + field is not present otherwise. + +When no\_token\_sampling = true: + +```json +{: {"p": , "is_end_token": bool, might not be present} \ +,: {"p": , "is_end_token": bool, might not be present} \ +,...} +``` + +Tokens where `p = 0` will not be present in the JSON output. + +If you want to implement your own token sampling, you may want to set +`max_new_tokens=1` and `stop_at_end_token=false` to suppress rllama's own +sampling behavior entirely.
+ +`rllama` internally caches recently queried prompts and the intermediate +computations so that it's able to continue off quickly if you issue a query +that is either the same as a previous query or a continuation of one. + ## How to turn on OpenCL Use `opencl` Cargo feature. diff --git a/examples/api_hello_world.py b/examples/api_hello_world.py new file mode 100755 index 0000000..a7b54e4 --- /dev/null +++ b/examples/api_hello_world.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 + +""" +This script uses the rllama API to generate tokens. + +It does not print the tokens nicely. +""" + +import requests + +def main(): + url = 'http://127.0.0.1:8080/rllama/v1/inference' + req = { + 'prompt': 'Hello world!', + 'max_seq_len': 1024, + 'max_new_tokens': 200, + 'no_token_sampling': False + } + res = requests.post(url, json=req, stream=True) + for line in res.iter_lines(): + print(line.decode('utf-8')) + + +if __name__ == '__main__': + main() diff --git a/src/lib.rs b/src/lib.rs index 2deb2f4..3232002 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,10 @@ #![feature(stdsimd)] +#![feature(decl_macro)] pub mod embedding; pub mod protomodels; pub mod rllama_main; +pub mod semaphore; pub mod simd_support; pub mod tensor; #[cfg(feature = "opencl")] @@ -11,3 +13,6 @@ pub mod token_sampler; pub mod tokenizer; pub mod transformer; pub mod unpickler; +#[cfg(feature = "server")] +#[macro_use] +extern crate rocket; diff --git a/src/main.rs b/src/main.rs index c612460..8b7bd52 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,11 @@ compile_error!("This library assumes availability of AVX and must be compiled wi #[cfg(not(target_feature = "avx"))] compile_error!("This library assumes availability of AVX and must be compiled with -C target-feature=+sse2,+avx,+fma,+avx2"); +use mimalloc::MiMalloc; + +#[global_allocator] +static GLOBAL: MiMalloc = MiMalloc; + pub fn main() -> Result<(), Box> { rllama::rllama_main::main() } diff --git a/src/rllama_main.rs b/src/rllama_main.rs index 
7a33760..45a2228 100644 --- a/src/rllama_main.rs +++ b/src/rllama_main.rs @@ -1,18 +1,23 @@ use crate::embedding::Embedding; +use crate::semaphore::Semaphore; #[cfg(feature = "opencl")] use crate::tensor_opencl_support::OpenCL; use crate::token_sampler::TokenSampler; use crate::tokenizer::{TokenId, Tokenizer}; -use crate::transformer::{DataSettings, Transformer}; +use crate::transformer::{DataSettings, Transformer, TransformerCaches}; use crate::unpickler; use crate::unpickler::Value; use clap::Parser; use colored::Colorize; +#[cfg(feature = "server")] +use rocket::{response::status, response::Stream, Data, State}; use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; use std::io::{Read, Write}; use std::path::PathBuf; +use std::sync::{Arc, RwLock}; -#[derive(Parser)] +#[derive(Parser, Clone)] #[command(author, version, about, long_about = None)] struct Cli { #[arg(long)] @@ -51,6 +56,24 @@ struct Cli { #[cfg(feature = "opencl")] #[arg(long)] opencl_device: Option, + + #[arg(long, action)] + inference_server: bool, + + #[arg(long)] + inference_server_port: Option, + + #[arg(long)] + inference_server_host: Option, + + #[arg(long)] + inference_server_max_concurrent_inferences: Option, + + #[arg(long)] + inference_server_api_path: Option, + + #[arg(long)] + inference_server_prompt_cache_size: Option, } #[derive(Clone, Serialize, Deserialize)] @@ -65,9 +88,15 @@ struct ModelParams { pub fn main() -> Result<(), Box> { let cli = Cli::parse(); - let model_path = cli.model_path; - let tokenizer_path = cli.tokenizer_path; - let param_path = cli.param_path; + let model_path = cli.model_path.clone(); + let tokenizer_path = cli.tokenizer_path.clone(); + let param_path = cli.param_path.clone(); + + #[cfg(not(feature = "server"))] + if cli.inference_server { + eprintln!("Inference server is not enabled in this build."); + return Err("Inference server is not enabled in this build.".into()); + } let max_threads: usize = match cli.max_threads { None => 
rayon::current_num_threads(), @@ -91,6 +120,15 @@ pub fn main() -> Result<(), Box> { colored::control::SHOULD_COLORIZE.set_override(false); } + // Custom println-like macro that respects be_quiet + macro_rules! pln { + ($($arg:tt)*) => { + if !be_quiet { + std::println!($($arg)*); + } + }; + } + #[cfg(feature = "opencl")] let opencl: Option = { let opencl_device = cli.opencl_device.unwrap_or(0); @@ -107,15 +145,6 @@ pub fn main() -> Result<(), Box> { } }; - // Custom println-like macro that respects be_quiet - macro_rules! pln { - ($($arg:tt)*) => { - if !be_quiet { - std::println!($($arg)*); - } - }; - } - // Read ModelParams from param_path, we expect it to be JSON let mut fs = std::fs::File::open(¶m_path)?; let mut bs = Vec::new(); @@ -124,12 +153,12 @@ pub fn main() -> Result<(), Box> { let params: ModelParams = serde_json::from_slice(&bs)?; pln!("Loaded model parameters from {}.", param_path); - let prompt: String = match (cli.prompt, cli.prompt_file) { - (Some(prompt), None) => { + let prompt: String = match (&cli.prompt, &cli.prompt_file) { + (Some(ref prompt), None) => { pln!("Using prompt: {}", prompt); - prompt + prompt.clone() } - (None, Some(prompt_file)) => { + (None, Some(ref prompt_file)) => { pln!("Using prompt file: {}", prompt_file); let mut fs = std::fs::File::open(prompt_file)?; let mut bs = Vec::new(); @@ -138,8 +167,12 @@ pub fn main() -> Result<(), Box> { String::from_utf8(bs)? } _ => { - eprintln!("Please provide either a prompt or a prompt file."); - return Err("Please provide either a prompt or a prompt file.".into()); + if cli.inference_server { + "".to_string() + } else { + eprintln!("Please provide either a prompt or a prompt file."); + return Err("Please provide either a prompt or a prompt file.".into()); + } } }; @@ -212,13 +245,445 @@ pub fn main() -> Result<(), Box> { )?; pln!("All is loaded. 
Starting inference."); + let tr: Arc = Arc::new(tr); + let tok: Arc = Arc::new(tok); + + if cli.inference_server { + #[cfg(feature = "server")] + { + server_inference(cli, tr, tok, be_quiet, max_seq_len, params, max_threads) + } + #[cfg(not(feature = "server"))] + { + eprintln!("The inference server feature is not enabled."); + eprintln!("Please enable it with the \"inference-server\" feature."); + Err("The inference server feature is not enabled.".into()) + } + } else { + command_line_inference( + cli.clone(), + tr.clone(), + tok.clone(), + prompt.clone(), + be_quiet, + max_seq_len, + params.clone(), + max_threads, + ) + } +} + +#[cfg(feature = "server")] +fn server_inference( + cli: Cli, + tr: Arc, + tok: Arc, + be_quiet: bool, + max_seq_len: usize, + _params: ModelParams, + _max_threads: usize, +) -> Result<(), Box> { + macro_rules! pln { + ($($arg:tt)*) => { + if !be_quiet { + std::println!($($arg)*); + } + }; + } + + let inference_server_port = cli.inference_server_port.unwrap_or(8080); + let inference_server_host = cli + .inference_server_host + .clone() + .unwrap_or("127.0.0.1".to_string()); + let inference_server_max_concurrent_inferences = + cli.inference_server_max_concurrent_inferences.unwrap_or(5); + let inference_server_api_path = cli + .inference_server_api_path + .clone() + .unwrap_or("/rllama/v1/inference".to_string()); + let inference_server_prompt_cache_size = cli.inference_server_prompt_cache_size.unwrap_or(50); + + pln!( + "Maximum concurrent inferences: {}", + inference_server_max_concurrent_inferences + ); + pln!("Prompt cache size: {}", inference_server_prompt_cache_size); + pln!("Maximum sequence length: {}", max_seq_len); + pln!( + "--- Starting HTTP server on {}:{}, answering to requests at {} ---", + inference_server_host, + inference_server_port, + inference_server_api_path + ); + + // If there are too many connections, they will hang until they get their turn. + // Maybe can later implement return 503 slow down or something similar. 
+ let concurrent_requests_semaphore = Semaphore::new(inference_server_max_concurrent_inferences); + + let rocket_conf = rocket::Config::build(rocket::config::Environment::Production) + .address(inference_server_host) + .port(inference_server_port) + .finalize() + .unwrap(); + + let app = rocket::custom(rocket_conf) + .mount(&inference_server_api_path, routes![handle_request]) + .manage(InferenceServerState { + transformer: tr, + tokenizer: tok, + max_seq_len, + concurrent_requests_semaphore, + attention_cache_repository: Arc::new(RwLock::new(AttentionCacheRepository::empty( + inference_server_prompt_cache_size, + ))), + }); + + app.launch(); + panic!("Starting web server failed."); +} + +fn is_false(b: &bool) -> bool { + !b +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +struct InferenceRequest { + temperature: Option, + top_k: Option, + top_p: Option, + repetition_penalty: Option, + max_seq_len: Option, + max_new_tokens: Option, + no_token_sampling: Option, + stop_at_end_token: Option, + prompt: String, +} + +#[cfg(feature = "server")] +#[derive(Serialize, Deserialize, Clone, Debug)] +struct PredResult { + p: f32, + #[serde(skip_serializing_if = "is_false")] + is_end_token: bool, +} + +#[cfg(feature = "server")] +struct GeneratingSession { + transformer: Arc, + token_sampler: TokenSampler, + tokenizer: Arc, + attention_cache_repository: Arc>, + tokens: Vec, + req_max_seq_len: usize, + req_max_new_tokens: usize, + new_tokens_generated: usize, + prev_pos: usize, + no_token_sampling: bool, + stop_at_end_token: bool, + sent_stuff_last_time: bool, + result: Vec, // stores JSONL lines to be returned from read() +} + +#[cfg(feature = "server")] +impl GeneratingSession { + fn read_from_result(&mut self, buf: &mut [u8]) -> usize { + if !self.result.is_empty() { + if self.result.len() <= buf.len() { + for idx in 0..self.result.len() { + buf[idx] = self.result[idx]; + } + let len = self.result.len(); + self.sent_stuff_last_time = true; + self.result.truncate(0); + 
return len; + } else { + for idx in 0..buf.len() { + buf[idx] = self.result[idx]; + } + self.result = self.result[buf.len()..].to_vec(); + self.sent_stuff_last_time = true; + return buf.len(); + } + } + return 0; + } +} + +#[cfg(feature = "server")] +impl Read for GeneratingSession { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + if self.sent_stuff_last_time && self.result.is_empty() { + // If we return WouldBlock every time we send something, it'll cause Rocket to + // flush available data. + self.sent_stuff_last_time = false; + return Err(std::io::Error::new( + std::io::ErrorKind::WouldBlock, + "WouldBlock", + )); + } + + // Push more data to the upstream if we have something stored. + let bytes_read = self.read_from_result(buf); + if bytes_read > 0 { + return Ok(bytes_read); + } + if self.tokens.len() >= self.req_max_seq_len { + return Ok(0); + } + if self.new_tokens_generated >= self.req_max_new_tokens { + return Ok(0); + } + + let (mut caches, update_pos) = { + let mut ac = self.attention_cache_repository.write().unwrap(); + match ac.get(&self.tokens) { + Some((c, pos)) if pos >= self.prev_pos => (c.true_clone(), pos), + Some(_) => { + std::mem::drop(ac); + (self.transformer.make_caches(), 0) + } + None => { + let caches = self.transformer.make_caches(); + ac.put(self.tokens.clone(), caches.true_clone(), self.prev_pos); + (caches, self.prev_pos) + } + } + }; + if update_pos > self.prev_pos { + self.prev_pos = update_pos; + } + + assert!(self.result.is_empty()); + let predictions = + self.transformer + .forward(&self.tokens[self.prev_pos..], self.prev_pos, &mut caches); + self.prev_pos = self.tokens.len(); + let (highest_pred_idx, token_prob) = + self.token_sampler + .sample(&predictions, self.tokenizer.as_ref(), &self.tokens); + self.tokens.push(highest_pred_idx as TokenId); + { + let mut ac = self.attention_cache_repository.write().unwrap(); + ac.put(self.tokens.clone(), caches, self.prev_pos); + } + self.new_tokens_generated += 1; + let token: 
&str = self.tokenizer.id_to_str(highest_pred_idx as TokenId); + let mut is_end_token: bool = false; + if token == "" && self.stop_at_end_token { + self.new_tokens_generated = self.req_max_new_tokens; + is_end_token = true; + } + + let mut result: BTreeMap = BTreeMap::new(); + if self.no_token_sampling { + // All predictions go the line. + let probs = self + .token_sampler + .logits_to_btreemap(&predictions, self.tokenizer.as_ref()); + for (k, v) in probs.into_iter() { + let mut is_end_token: bool = false; + if k == "" { + is_end_token = true; + } + result.insert( + k, + PredResult { + p: v, + is_end_token: is_end_token, + }, + ); + } + // Convert to JSON + let json = serde_json::to_string(&result).unwrap(); + self.result.extend(json.as_bytes()); + self.result.push(b'\n'); + return Ok(self.read_from_result(buf)); + } else { + result.insert( + token.to_string(), + PredResult { + p: token_prob, + is_end_token, + }, + ); + let json = serde_json::to_string(&result).unwrap(); + self.result.extend(json.as_bytes()); + self.result.push(b'\n'); + return Ok(self.read_from_result(buf)); + } + } +} + +#[cfg(feature = "server")] +struct AttentionCacheRepository { + caches: BTreeMap, (TransformerCaches, usize, std::time::Instant)>, + max_sz: usize, +} + +#[cfg(feature = "server")] +impl AttentionCacheRepository { + fn empty(max_size: usize) -> AttentionCacheRepository { + AttentionCacheRepository { + caches: BTreeMap::new(), + max_sz: max_size, + } + } + + /// Makes sure the cache repository is not larger than sz, evicts any older items. 
+ fn limit_size(&mut self, sz: usize) { + if sz == 0 { + self.caches = BTreeMap::new(); + return; + } + // Slow algorithm but I guess our cache will never be unimaginably large so it's probably + // fine + while self.caches.len() > sz { + let mut oldest_time = None; + let mut oldest_key: Option<&Vec> = None; + for (k, (_, _, time)) in self.caches.iter() { + if oldest_time.is_none() || time < oldest_time.unwrap() { + oldest_time = Some(time); + oldest_key = Some(k); + } + } + let oldest_key = oldest_key.unwrap().clone(); + self.caches.remove(&oldest_key); + } + } + + fn get(&self, tokens: &[TokenId]) -> Option<(&TransformerCaches, usize)> { + if let Some((caches, pos, _)) = self.caches.get(tokens) { + Some((caches, *pos)) + } else { + None + } + } + + fn put(&mut self, tokens: Vec, caches: TransformerCaches, prev_pos: usize) { + self.caches + .insert(tokens, (caches, prev_pos, std::time::Instant::now())); + self.limit_size(self.max_sz); + } +} + +#[cfg(feature = "server")] +#[derive(Clone)] +struct InferenceServerState { + transformer: Arc, + tokenizer: Arc, + max_seq_len: usize, + concurrent_requests_semaphore: Semaphore, + attention_cache_repository: Arc>, +} + +#[cfg(feature = "server")] +#[post("/", data = "")] +fn handle_request( + state: State, + input: Data, +) -> Result, status::BadRequest> { + let _lock = state.concurrent_requests_semaphore.acquire(); + let tr = state.transformer.clone(); + let tok = state.tokenizer.clone(); + + let mut data = input.open(); + let mut databuf: Vec = Vec::new(); + data.read_to_end(&mut databuf).unwrap(); + + // Parse the JSON out of the request + let request: InferenceRequest = match serde_json::from_slice(&databuf) { + Err(_e) => { + return Err(status::BadRequest(Some("Invalid JSON.".to_string()))); + } + Ok(ir) => ir, + }; + + let stop_at_end_token = request.stop_at_end_token.unwrap_or(true); + let temperature = request.temperature.unwrap_or(1.0); + let top_k = request.top_k.unwrap_or(20); + let top_p = 
request.top_p.unwrap_or(1.0); + let repetition_penalty = request.repetition_penalty.unwrap_or(1.0); + let mut req_max_seq_len = request.max_seq_len.unwrap_or(state.max_seq_len); + if req_max_seq_len > state.max_seq_len { + req_max_seq_len = state.max_seq_len; + } + let req_max_new_tokens = request.max_new_tokens.unwrap_or(20); + let no_token_sampling = request.no_token_sampling.unwrap_or(false); + let prompt = request.prompt; + + if temperature.is_nan() { + return Err(status::BadRequest(Some( + "Temperature must be a number.".to_string(), + ))); + } + if top_k == 0 { + return Err(status::BadRequest(Some( + "Top-k must be greater than 0.".to_string(), + ))); + } + if top_p.is_nan() { + return Err(status::BadRequest(Some( + "Top-p must be a number.".to_string(), + ))); + } + if repetition_penalty.is_nan() { + return Err(status::BadRequest(Some( + "Repetition penalty must be a number.".to_string(), + ))); + } + + let token_sampler = TokenSampler::new() + .temperature(temperature) + .top_p(top_p) + .top_k(top_k) + .repetition_penalty(repetition_penalty); + let toks_id: Vec = tok.tokenize_to_ids(prompt.clone()); + let gsession = GeneratingSession { + transformer: tr, + tokenizer: tok, + attention_cache_repository: state.attention_cache_repository.clone(), + token_sampler: token_sampler, + tokens: toks_id, + req_max_seq_len: req_max_seq_len, + req_max_new_tokens: req_max_new_tokens, + new_tokens_generated: 0, + prev_pos: 0, + no_token_sampling: no_token_sampling, + stop_at_end_token: stop_at_end_token, + sent_stuff_last_time: false, + result: Vec::new(), + }; + + return Ok(rocket::response::Stream::chunked(gsession, 1024)); +} + +fn command_line_inference( + cli: Cli, + tr: Arc, + tok: Arc, + prompt: String, + be_quiet: bool, + max_seq_len: usize, + params: ModelParams, + max_threads: usize, +) -> Result<(), Box> { + // Custom println-like macro that respects be_quiet + macro_rules! 
pln { + ($($arg:tt)*) => { + if !be_quiet { + std::println!($($arg)*); + } + }; + } + let mut toks_id: Vec = tok.tokenize_to_ids(prompt.clone()); let mut prev_pos = 0; let mut token_sampler = TokenSampler::new() - .temperature(0.8) - .top_p(0.9) - .top_k(50) - .repetition_penalty(0.8); + .temperature(1.0) + .top_p(1.0) + .top_k(20) + .repetition_penalty(1.0); if let Some(temperature) = cli.temperature { token_sampler = token_sampler.temperature(temperature); diff --git a/src/tensor.rs b/src/tensor.rs index 8a99834..d36c476 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -23,6 +23,7 @@ use crate::tensor_opencl_support::{OpenCL, OpenCLError, OpenCLEvent, OpenCLTenso use crate::unpickler; use crate::unpickler::UnpicklingError; use half::f16; +use lazy_static::lazy_static; use rand::Rng; use rayon::prelude::*; use std::alloc::Layout; @@ -123,12 +124,21 @@ impl Clone for Tensor { } } +// Tracks how many bytes are allocated for tensors globally on CPU. +// I've used this to debug memory leaks and monitor memory usage. +lazy_static! { + static ref TENSORS_BYTES_ALLOCATED: std::sync::atomic::AtomicUsize = + std::sync::atomic::AtomicUsize::new(0); +} + impl Drop for Tensor { fn drop(&mut self) { #[cfg(feature = "opencl")] self.process_waiting_for_data_mut(); unsafe { if !self.data.is_null() { + TENSORS_BYTES_ALLOCATED + .fetch_sub(self.layout.size(), std::sync::atomic::Ordering::Relaxed); std::alloc::dealloc(self.data, self.layout); } } @@ -342,6 +352,7 @@ impl Tensor { if data.is_null() { panic!("Failed to allocate tensor"); } + TENSORS_BYTES_ALLOCATED.fetch_add(layout.size(), std::sync::atomic::Ordering::Relaxed); // Even though we are uninitialized, we should zero out the extra space between the // columns. // Otherwise there might be problems later as other operations assume it is zeroed. 
@@ -1384,12 +1395,14 @@ impl Tensor { as *const I16x8, ), ) - } else { + } else if row < nrows { ( load_i16x8(ptr.add(row * cols_capacity + column) as *const I16x8), i16x8_zero(), ) + } else { + (i16x8_zero(), i16x8_zero()) }; let left: F32x8 = i16x8_as_f16_to_f32x8(left); let right: F32x8 = i16x8_as_f16_to_f32x8(right); @@ -1840,6 +1853,7 @@ impl Tensor { if data.is_null() { panic!("Failed to allocate tensor"); } + TENSORS_BYTES_ALLOCATED.fetch_add(layout.size(), std::sync::atomic::Ordering::Relaxed); Self { data, #[cfg(feature = "opencl")] @@ -2005,6 +2019,7 @@ impl Tensor { if data.is_null() { panic!("to_cpu_inplace: Failed to allocate tensor"); } + TENSORS_BYTES_ALLOCATED.fetch_add(layout.size(), std::sync::atomic::Ordering::Relaxed); let ev = od.as_mut().unwrap().data_u16_from_gpu(data as *mut u16)?; self.data = data as *mut u16 as *mut u8; self.waiting_for_data = Some(ev); diff --git a/src/token_sampler.rs b/src/token_sampler.rs index 1b760ad..138aaf3 100644 --- a/src/token_sampler.rs +++ b/src/token_sampler.rs @@ -65,6 +65,21 @@ impl TokenSampler { } } + pub fn logits_to_btreemap( + &self, + logits: &Tensor, + tokenizer: &Tokenizer, + ) -> BTreeMap { + let mut result = BTreeMap::new(); + for token_idx in 0..logits.rows() { + result.insert( + tokenizer.id_to_str(token_idx as TokenId).to_string(), + logits.get_f32(token_idx, 0), + ); + } + result + } + pub fn sample( &self, logits: &Tensor, diff --git a/src/transformer.rs b/src/transformer.rs index 7836f6c..feb99d1 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -28,6 +28,8 @@ pub struct Transformer { output: Tensor, layers: Vec, + + data_settings: DataSettings, } // Clone is cheap @@ -94,25 +96,59 @@ pub struct TransformerBlock { pub struct AttentionCache { cache_k: Vec>>, cache_v: Vec>>, + data_settings: DataSettings, } impl AttentionCache { - fn new(max_seq_len: usize, n_local_heads: usize, head_dim: usize) -> Self { + fn new( + max_seq_len: usize, + n_local_heads: usize, + head_dim: usize, + 
data_settings: &DataSettings, + ) -> Self { let mut cache_k = Vec::with_capacity(n_local_heads); let mut cache_v = Vec::with_capacity(n_local_heads); + + let dtype = if data_settings.force_f16 { + TensorDType::Float16 + } else { + TensorDType::Float32 + }; for _ in 0..n_local_heads { cache_k.push(Arc::new(RwLock::new(Tensor::zeros( head_dim as i64, max_seq_len as i64, - TensorDType::Float32, + dtype, )))); cache_v.push(Arc::new(RwLock::new(Tensor::zeros( head_dim as i64, max_seq_len as i64, - TensorDType::Float32, + dtype, )))); } - AttentionCache { cache_k, cache_v } + AttentionCache { + cache_k, + cache_v, + data_settings: data_settings.clone(), + } + } + + /// Cloning AttentionCache normally just makes new references to the same cache. + /// This creates a true clone with copied tensors. + fn true_clone(&self) -> AttentionCache { + let mut cache_k = Vec::with_capacity(self.cache_k.len()); + let mut cache_v = Vec::with_capacity(self.cache_v.len()); + for idx in 0..self.cache_k.len() { + let old_k = self.cache_k[idx].read().unwrap(); + cache_k.push(Arc::new(RwLock::new(old_k.clone()))); + let old_v = self.cache_v[idx].read().unwrap(); + cache_v.push(Arc::new(RwLock::new(old_v.clone()))); + } + AttentionCache { + cache_k, + cache_v, + data_settings: self.data_settings.clone(), + } } fn shift_left(&mut self, shifts: usize) { @@ -141,6 +177,14 @@ impl TransformerCaches { layer.shift_left(shifts); } } + + pub fn true_clone(&self) -> TransformerCaches { + let mut layer_caches = Vec::with_capacity(self.layer_caches.len()); + for layer in self.layer_caches.iter() { + layer_caches.push(layer.true_clone()); + } + TransformerCaches { layer_caches } + } } pub struct RMSNorm { @@ -218,6 +262,7 @@ impl Transformer { Ok(Transformer { freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len, 10000.0), + data_settings: data_settings.clone(), emb, dim, n_layers, @@ -240,6 +285,7 @@ impl Transformer { self.max_seq_len, self.n_local_heads, self.head_dim, + &self.data_settings, )); } 
TransformerCaches { @@ -664,6 +710,9 @@ impl Attention { let keys = cache_k.clip_cols(start_pos + seq_len as usize); let values = cache_v.clip_cols(start_pos + seq_len as usize); + let keys = keys.into_same_type(&xq_row); + let values = values.into_same_type(&xq_row); + let m = xq_row .matrix_mul(&keys) .scalar_multiply_f32(1.0 / (self.head_dim as f32).sqrt());