Add simple HTTP API support.

It took an annoying amount of effort just to make this simple server.

I tried the rouille web framework first, but it didn't support sending
chunked output to the client line-by-line. (It seems that if it exposed
more details about the underlying tiny-http package, I could have hacked
it to work.)

I went with Rocket because it had less async stuff and seemed decent.

I ran into weird issues where it seemed as if memory use kept increasing
and increasing. I may have fixed that, but I couldn't figure out what
made it use so much memory: even tools like valgrind and heaptrack told
me there isn't that much memory allocated, yet I can see RES increasing
in `htop`.

Switched to MiMalloc as it seems to slightly decrease memory use.

Added details about the inference server to README.md, and also added an
example Python script that uses it.

I want to use this feature later to investigate how much quantization or
f16/f32 affects output. It's easier to do such things in Python.
master
Mikko Juola 3 years ago
parent 9c86c17318
commit b9be485610

661
Cargo.lock generated

@ -2,6 +2,60 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 3
[[package]]
name = "aead"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fc95d1bdb8e6666b2b217308eeeb09f2d6728d104be3e31916cc74d15420331"
dependencies = [
"generic-array",
]
[[package]]
name = "aes"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "884391ef1066acaa41e766ba8f596341b96e93ce34f9a43e7d24bf0a0eaf0561"
dependencies = [
"aes-soft",
"aesni",
"cipher",
]
[[package]]
name = "aes-gcm"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5278b5fabbb9bd46e24aa69b2fdea62c99088e0a950a9be40e3e0101298f88da"
dependencies = [
"aead",
"aes",
"cipher",
"ctr",
"ghash",
"subtle",
]
[[package]]
name = "aes-soft"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be14c7498ea50828a38d0e24a765ed2effe92a705885b57d029cd67d45744072"
dependencies = [
"cipher",
"opaque-debug",
]
[[package]]
name = "aesni"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea2e11f5e94c2f7d386164cc2aa1f97823fed6f259e486940a71c174dd01b0ce"
dependencies = [
"cipher",
"opaque-debug",
]
[[package]] [[package]]
name = "aho-corasick" name = "aho-corasick"
version = "0.7.20" version = "0.7.20"
@ -49,18 +103,49 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "base64"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "489d6c0ed21b11d038c31b6ceccca973e65d73ba3bd8ecb9a2babf5546164643"
dependencies = [
"byteorder",
"safemem",
]
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "block-buffer"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4"
dependencies = [
"generic-array",
]
[[package]] [[package]]
name = "bumpalo" name = "bumpalo"
version = "3.12.0" version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]] [[package]]
name = "cast" name = "cast"
version = "0.3.0" version = "0.3.0"
@ -106,6 +191,15 @@ dependencies = [
"half 1.8.2", "half 1.8.2",
] ]
[[package]]
name = "cipher"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12f8e7987cbd042a63249497f41aed09f8e65add917ea6566effbc56578d6801"
dependencies = [
"generic-array",
]
[[package]] [[package]]
name = "cl-sys" name = "cl-sys"
version = "0.4.2" version = "0.4.2"
@ -150,8 +244,8 @@ checksum = "fddf67631444a3a3e3e5ac51c36a5e01335302de677bd78759eaa90ab1f46644"
dependencies = [ dependencies = [
"heck", "heck",
"proc-macro-error", "proc-macro-error",
"proc-macro2", "proc-macro2 1.0.52",
"quote", "quote 1.0.26",
"syn 1.0.109", "syn 1.0.109",
] ]
@ -197,6 +291,37 @@ dependencies = [
"windows-sys 0.42.0", "windows-sys 0.42.0",
] ]
[[package]]
name = "cookie"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be2018768ed1d848cc4d347d551546474025ba820e5db70e4c9aaa349f678bd7"
dependencies = [
"aes-gcm",
"base64 0.13.1",
"hkdf",
"hmac",
"percent-encoding 2.2.0",
"rand",
"sha2",
"time",
]
[[package]]
name = "cpufeatures"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320"
dependencies = [
"libc",
]
[[package]]
name = "cpuid-bool"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcb25d077389e53838a8158c8e99174c5a9d902dee4904320db714f3c653ffba"
[[package]] [[package]]
name = "criterion" name = "criterion"
version = "0.4.0" version = "0.4.0"
@ -306,6 +431,66 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "crypto-mac"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bff07008ec701e8028e2ceb8f83f0e4274ee62bd2dbdc4fefff2e9a91824081a"
dependencies = [
"generic-array",
"subtle",
]
[[package]]
name = "ctr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb4a30d54f7443bf3d6191dcd486aca19e67cb3c49fa7a06a319966346707e7f"
dependencies = [
"cipher",
]
[[package]]
name = "devise"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd716c4a507adc5a2aa7c2a372d06c7497727e0892b243d3036bc7478a13e526"
dependencies = [
"devise_codegen",
"devise_core",
]
[[package]]
name = "devise_codegen"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea7b8290d118127c08e3669da20b331bed56b09f20be5945b7da6c116d8fab53"
dependencies = [
"devise_core",
"quote 0.6.13",
]
[[package]]
name = "devise_core"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1053e9d5d5aade9bcedb5ab53b78df2b56ff9408a3138ce77eaaef87f932373"
dependencies = [
"bitflags",
"proc-macro2 0.4.30",
"quote 0.6.13",
"syn 0.15.44",
]
[[package]]
name = "digest"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066"
dependencies = [
"generic-array",
]
[[package]] [[package]]
name = "either" name = "either"
version = "1.8.1" version = "1.8.1"
@ -387,6 +572,16 @@ version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d758ba1b47b00caf47f24925c0074ecb20d6dfcffe7f6d53395c0465674841a" checksum = "1d758ba1b47b00caf47f24925c0074ecb20d6dfcffe7f6d53395c0465674841a"
[[package]]
name = "generic-array"
version = "0.14.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9"
dependencies = [
"typenum",
"version_check 0.9.4",
]
[[package]] [[package]]
name = "getrandom" name = "getrandom"
version = "0.2.8" version = "0.2.8"
@ -395,9 +590,25 @@ checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
"wasi", "wasi 0.11.0+wasi-snapshot-preview1",
] ]
[[package]]
name = "ghash"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97304e4cd182c3846f7575ced3890c53012ce534ad9114046b0a9e00bb30a375"
dependencies = [
"opaque-debug",
"polyval",
]
[[package]]
name = "glob"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]] [[package]]
name = "half" name = "half"
version = "1.8.2" version = "1.8.2"
@ -449,6 +660,62 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
[[package]]
name = "hkdf"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51ab2f639c231793c5f6114bdb9bbe50a7dbbfcd7c7c6bd8475dec2d991e964f"
dependencies = [
"digest",
"hmac",
]
[[package]]
name = "hmac"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1441c6b1e930e2817404b5046f1f989899143a12bf92de603b69f4e0aee1e15"
dependencies = [
"crypto-mac",
"digest",
]
[[package]]
name = "httparse"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904"
[[package]]
name = "hyper"
version = "0.10.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a0652d9a2609a968c14be1a9ea00bf4b1d64e2e1f53a1b51b6fff3a6e829273"
dependencies = [
"base64 0.9.3",
"httparse",
"language-tags",
"log 0.3.9",
"mime",
"num_cpus",
"time",
"traitobject",
"typeable",
"unicase",
"url",
]
[[package]]
name = "idna"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38f09e0f0b1fb55fdee1f17470ad800da77af5186a1a76c026b679358b7e844e"
dependencies = [
"matches",
"unicode-bidi",
"unicode-normalization",
]
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "1.9.2" version = "1.9.2"
@ -527,6 +794,12 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "language-tags"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a91d884b6667cd606bb5a69aa0c99ba811a115fc68915e7056ec08a46e93199a"
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.4.0" version = "1.4.0"
@ -539,12 +812,31 @@ version = "0.2.140"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c"
[[package]]
name = "libmimalloc-sys"
version = "0.1.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd8c7cbf8b89019683667e347572e6d55a7df7ea36b0c4ce69961b0cde67b174"
dependencies = [
"cc",
"libc",
]
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.1.4" version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4"
[[package]]
name = "log"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e19e8d5c34a3e0e2223db8e060f9e8264aeeb5c5fc64a4ee9965c062211c024b"
dependencies = [
"log 0.4.17",
]
[[package]] [[package]]
name = "log" name = "log"
version = "0.4.17" version = "0.4.17"
@ -554,6 +846,12 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "matches"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5"
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.5.0" version = "2.5.0"
@ -569,6 +867,24 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "mimalloc"
version = "0.1.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dcb174b18635f7561a0c6c9fc2ce57218ac7523cf72c50af80e2d79ab8f3ba1"
dependencies = [
"libmimalloc-sys",
]
[[package]]
name = "mime"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba626b8a6de5da682e1caa06bdb42a335aee5a84db8e5046a3e8ab17ba0a3ae0"
dependencies = [
"log 0.3.9",
]
[[package]] [[package]]
name = "nodrop" name = "nodrop"
version = "0.1.14" version = "0.1.14"
@ -669,12 +985,52 @@ version = "11.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
[[package]]
name = "opaque-debug"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]] [[package]]
name = "os_str_bytes" name = "os_str_bytes"
version = "6.4.1" version = "6.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee"
[[package]]
name = "pear"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32dfa7458144c6af7f9ce6a137ef975466aa68ffa44d4d816ee5934018ba960a"
dependencies = [
"pear_codegen",
]
[[package]]
name = "pear_codegen"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0288ba5d581afbc93e2bbd931c1013584c15ecf46b1cdb927edc7abddbc8ca6"
dependencies = [
"proc-macro2 0.4.30",
"quote 0.6.13",
"syn 0.15.44",
"version_check 0.9.4",
"yansi",
]
[[package]]
name = "percent-encoding"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31010dd2e1ac33d5b46a5b413495239882813e0369f8ed8a5e266f173602f831"
[[package]]
name = "percent-encoding"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
[[package]] [[package]]
name = "plotters" name = "plotters"
version = "0.3.4" version = "0.3.4"
@ -703,6 +1059,17 @@ dependencies = [
"plotters-backend", "plotters-backend",
] ]
[[package]]
name = "polyval"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eebcc4aa140b9abd2bc40d9c3f7ccec842679cd79045ac3a7ac698c1a064b7cd"
dependencies = [
"cpuid-bool",
"opaque-debug",
"universal-hash",
]
[[package]] [[package]]
name = "portable-atomic" name = "portable-atomic"
version = "0.3.19" version = "0.3.19"
@ -722,10 +1089,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [ dependencies = [
"proc-macro-error-attr", "proc-macro-error-attr",
"proc-macro2", "proc-macro2 1.0.52",
"quote", "quote 1.0.26",
"syn 1.0.109", "syn 1.0.109",
"version_check", "version_check 0.9.4",
] ]
[[package]] [[package]]
@ -734,9 +1101,18 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2 1.0.52",
"quote", "quote 1.0.26",
"version_check", "version_check 0.9.4",
]
[[package]]
name = "proc-macro2"
version = "0.4.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759"
dependencies = [
"unicode-xid",
] ]
[[package]] [[package]]
@ -782,7 +1158,7 @@ checksum = "9d39b14605eaa1f6a340aec7f320b34064feb26c93aec35d6a9a2272a8ddfa49"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"indexmap", "indexmap",
"log", "log 0.4.17",
"protobuf", "protobuf",
"protobuf-support", "protobuf-support",
"tempfile", "tempfile",
@ -799,13 +1175,22 @@ dependencies = [
"thiserror", "thiserror",
] ]
[[package]]
name = "quote"
version = "0.6.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce23b6b870e8f94f81fb0a363d65d86675884b34a09043c81e5562f11c1f8e1"
dependencies = [
"proc-macro2 0.4.30",
]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.26" version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2 1.0.52",
] ]
[[package]] [[package]]
@ -907,6 +1292,8 @@ dependencies = [
"embedded-profiling", "embedded-profiling",
"half 2.2.1", "half 2.2.1",
"indicatif", "indicatif",
"lazy_static",
"mimalloc",
"num-complex", "num-complex",
"ocl", "ocl",
"protobuf", "protobuf",
@ -914,11 +1301,65 @@ dependencies = [
"protobuf-parse", "protobuf-parse",
"rand", "rand",
"rayon", "rayon",
"rocket",
"serde", "serde",
"serde_json", "serde_json",
"thiserror", "thiserror",
] ]
[[package]]
name = "rocket"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83b9d9dc08c5dcc1d8126a9dd615545e6a358f8c13c883c8dfed8c0376fa355e"
dependencies = [
"atty",
"base64 0.13.1",
"log 0.4.17",
"memchr",
"num_cpus",
"pear",
"rocket_codegen",
"rocket_http",
"state",
"time",
"toml",
"version_check 0.9.4",
"yansi",
]
[[package]]
name = "rocket_codegen"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2810037b5820098af97bd4fdd309e76a8101ceb178147de775c835a2537284fe"
dependencies = [
"devise",
"glob",
"indexmap",
"quote 0.6.13",
"rocket_http",
"version_check 0.9.4",
"yansi",
]
[[package]]
name = "rocket_http"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bf9cbd128e1f321a2d0bebd2b7cf0aafd89ca43edf69e49b56a5c46e48eb19f"
dependencies = [
"cookie",
"hyper",
"indexmap",
"pear",
"percent-encoding 1.0.1",
"smallvec",
"state",
"time",
"unicode-xid",
]
[[package]] [[package]]
name = "rustc_version" name = "rustc_version"
version = "0.4.0" version = "0.4.0"
@ -948,6 +1389,12 @@ version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
[[package]]
name = "safemem"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef703b7cb59335eae2eb93ceb664c0eb7ea6bf567079d843e09420219668e072"
[[package]] [[package]]
name = "same-file" name = "same-file"
version = "1.0.6" version = "1.0.6"
@ -984,8 +1431,8 @@ version = "1.0.157"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78997f4555c22a7971214540c4a661291970619afd56de19f77e0de86296e1e5" checksum = "78997f4555c22a7971214540c4a661291970619afd56de19f77e0de86296e1e5"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2 1.0.52",
"quote", "quote 1.0.26",
"syn 2.0.0", "syn 2.0.0",
] ]
@ -1000,20 +1447,62 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "sha2"
version = "0.9.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800"
dependencies = [
"block-buffer",
"cfg-if",
"cpufeatures",
"digest",
"opaque-debug",
]
[[package]]
name = "smallvec"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
[[package]]
name = "state"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3015a7d0a5fd5105c91c3710d42f9ccf0abfb287d62206484dcc67f9569a6483"
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.10.0" version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "subtle"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601"
[[package]]
name = "syn"
version = "0.15.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ca4b3b69a77cbe1ffc9e198781b7acb0c7365a883670e8f1c1bc66fba79a5c5"
dependencies = [
"proc-macro2 0.4.30",
"quote 0.6.13",
"unicode-xid",
]
[[package]] [[package]]
name = "syn" name = "syn"
version = "1.0.109" version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2 1.0.52",
"quote", "quote 1.0.26",
"unicode-ident", "unicode-ident",
] ]
@ -1023,8 +1512,8 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cff13bb1732bccfe3b246f3fdb09edfd51c01d6f5299b7ccd9457c2e4e37774" checksum = "4cff13bb1732bccfe3b246f3fdb09edfd51c01d6f5299b7ccd9457c2e4e37774"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2 1.0.52",
"quote", "quote 1.0.26",
"unicode-ident", "unicode-ident",
] ]
@ -1071,11 +1560,22 @@ version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2 1.0.52",
"quote", "quote 1.0.26",
"syn 2.0.0", "syn 2.0.0",
] ]
[[package]]
name = "time"
version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a"
dependencies = [
"libc",
"wasi 0.10.0+wasi-snapshot-preview1",
"winapi",
]
[[package]] [[package]]
name = "tinytemplate" name = "tinytemplate"
version = "1.2.1" version = "1.2.1"
@ -1086,18 +1586,117 @@ dependencies = [
"serde_json", "serde_json",
] ]
[[package]]
name = "tinyvec"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "toml"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "758664fc71a3a69038656bee8b6be6477d2a6c315a6b81f7081f591bffa4111f"
dependencies = [
"serde",
]
[[package]]
name = "traitobject"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efd1f82c56340fdf16f2a953d7bda4f8fdffba13d93b00844c25572110b26079"
[[package]]
name = "typeable"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1410f6f91f21d1612654e7cc69193b0334f909dcf2c790c4826254fbb86f8887"
[[package]]
name = "typenum"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba"
[[package]]
name = "unicase"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f4765f83163b74f957c797ad9253caf97f103fb064d3999aea9568d09fc8a33"
dependencies = [
"version_check 0.1.5",
]
[[package]]
name = "unicode-bidi"
version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d502c968c6a838ead8e69b2ee18ec708802f99db92a0d156705ec9ef801993b"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.8" version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4"
[[package]]
name = "unicode-normalization"
version = "0.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921"
dependencies = [
"tinyvec",
]
[[package]] [[package]]
name = "unicode-width" name = "unicode-width"
version = "0.1.10" version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b"
[[package]]
name = "unicode-xid"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc"
[[package]]
name = "universal-hash"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f214e8f697e925001e66ec2c6e37a4ef93f0f78c2eed7814394e10c62025b05"
dependencies = [
"generic-array",
"subtle",
]
[[package]]
name = "url"
version = "1.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd4e7c0d531266369519a4aa4f399d748bd37043b00bde1e4ff1f60a120b355a"
dependencies = [
"idna",
"matches",
"percent-encoding 1.0.1",
]
[[package]]
name = "version_check"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd"
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.4" version = "0.9.4"
@ -1114,6 +1713,12 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "wasi"
version = "0.10.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f"
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.11.0+wasi-snapshot-preview1" version = "0.11.0+wasi-snapshot-preview1"
@ -1137,10 +1742,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9"
dependencies = [ dependencies = [
"bumpalo", "bumpalo",
"log", "log 0.4.17",
"once_cell", "once_cell",
"proc-macro2", "proc-macro2 1.0.52",
"quote", "quote 1.0.26",
"syn 1.0.109", "syn 1.0.109",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -1151,7 +1756,7 @@ version = "0.2.84"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5"
dependencies = [ dependencies = [
"quote", "quote 1.0.26",
"wasm-bindgen-macro-support", "wasm-bindgen-macro-support",
] ]
@ -1161,8 +1766,8 @@ version = "0.2.84"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2 1.0.52",
"quote", "quote 1.0.26",
"syn 1.0.109", "syn 1.0.109",
"wasm-bindgen-backend", "wasm-bindgen-backend",
"wasm-bindgen-shared", "wasm-bindgen-shared",
@ -1306,3 +1911,9 @@ name = "windows_x86_64_msvc"
version = "0.42.2" version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
[[package]]
name = "yansi"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec"

@ -32,10 +32,14 @@ indicatif = "0.17"
colored = "2" colored = "2"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"
mimalloc = "0.1"
ocl = { version = "0.19", optional = true } ocl = { version = "0.19", optional = true }
rocket = { version = "0.4", features = ["sse"], optional = true }
lazy_static = "1.4"
[features] [features]
opencl = ["ocl"] opencl = ["ocl"]
server = ["rocket"]
# We need protobuf compiler # We need protobuf compiler
[build-dependencies] [build-dependencies]
@ -46,6 +50,7 @@ protobuf-parse = "3.2"
criterion = "0.4" criterion = "0.4"
[profile.release] [profile.release]
panic = 'abort'
debug = true debug = true
[[bench]] [[bench]]

@ -90,6 +90,92 @@ rllama --tokenizer-model /path/to/tokenizer.model \
Use `rllama --help` to see all the options. Use `rllama --help` to see all the options.
## Inference server
`rllama` can run in an inference server mode with a simple HTTP JSON API.
The command line flags for this are:
* `--inference-server` using this will turn on the inference server.
* `--inference-server-port` sets the port. Default port is 8080.
* `--inference-server-host` sets the host. The default host is 127.0.0.1.
* `--inference-server-max-concurrent-inferences` sets how many concurrent
requests are allowed to be actively doing inference at the same time. The
default is 5.
* `--inference-server-api-path` sets which path serves the API requests. The
  default path is `/rllama/v1/inference`
* `--inference-server-prompt-cache-size` sets how many previous prompt
  calculations should be cached. Default is 50. This speeds up token
  generation for prompts that were already requested before.
Prompts and flags related to token sampling are all ignored in inference server
mode. Instead, they are obtained from each HTTP JSON API request.
### Inference server API
There is an `examples/api_hello_world.py` for a minimal API use example.
```
POST /rllama/v1/inference
```
Expects a JSON body and `Accept: application/json` or `Accept: text/jsonl`.
The expected JSON is as follows:
```json
{
"temperature": <number, optional>
"top_k": <integer, optional, default 20>
"top_p": <number, optional, default: 1.0>
"repetition_penalty": <number, optional, default: 1.0>
"stop_at_end_token": <bool, optional, default: true>
  "max_seq_len": <integer, optional, defaults to the value of the
  --max-seq-len command line option, and is clamped to be at most that value.>
  "max_new_tokens": <integer, optional, default: 20>
"no_token_sampling": <bool, optional, default: false>
"prompt": <string, required>
}
```
The form of the response depends on if `no_token_sampling` is set to true or false. The
response is in JSONL, i.e. multiple JSON dictionaries, separated by newlines.
`no_token_sampling` can turn off `rllama`'s own token sampling. In this case,
the probabilities for every token are returned instead.
When no\_token\_sampling = false:
```json
{<token string>: {"p": <number>, "is_end_token": bool, might not be present}}
```
* `<token string>` is the new token to be appended to the output. It does not
  include the string you fed to the system originally.
* `p` is the probability that this token was chosen. For example, if this
value is 0.1, it means that this particular token had 10% chance of being
selected with the current token sampling settings.
* `is_end_token` is `true` if the given token signifies end of output. This
  field is not present otherwise.
When no\_token\_sampling = true:
```json
{<token string>: {"p": <number>, "is_end_token": bool, might not be present} \
,<token string>: {"p": <number>, "is_end_token": bool, might not be present} \
,...}
```
Tokens where `p = 0` will not be present in the JSON output.
If you want to implement your own token sampling, you may want to set
`max_new_tokens=1` and `stop_at_end_token=false` to suppress rllama's own
sampling behavior entirely.
`rllama` internally caches recently queried prompts and the intermediate
computations so that it's able to continue off quickly if you issue a query
that is either the same as a previous query or a continuation of one.
## How to turn on OpenCL ## How to turn on OpenCL
Use `opencl` Cargo feature. Use `opencl` Cargo feature.

@ -0,0 +1,25 @@
#!/usr/bin/env python3
"""
This script uses the rllama API to generate tokens.
It does not print the tokens nicely.
"""
import requests
def main():
url = 'http://127.0.0.1:8080/rllama/v1/inference'
req = {
'prompt': 'Hello world!',
'max_seq_len': 1024,
'max_new_tokens': 200,
'no_token_sampling': False
}
res = requests.post(url, json=req, stream=True)
for line in res.iter_lines():
print(line.decode('utf-8'))
if __name__ == '__main__':
main()

@ -1,8 +1,10 @@
#![feature(stdsimd)] #![feature(stdsimd)]
#![feature(decl_macro)]
pub mod embedding; pub mod embedding;
pub mod protomodels; pub mod protomodels;
pub mod rllama_main; pub mod rllama_main;
pub mod semaphore;
pub mod simd_support; pub mod simd_support;
pub mod tensor; pub mod tensor;
#[cfg(feature = "opencl")] #[cfg(feature = "opencl")]
@ -11,3 +13,6 @@ pub mod token_sampler;
pub mod tokenizer; pub mod tokenizer;
pub mod transformer; pub mod transformer;
pub mod unpickler; pub mod unpickler;
#[cfg(feature = "server")]
#[macro_use]
extern crate rocket;

@ -7,6 +7,11 @@ compile_error!("This library assumes availability of AVX and must be compiled wi
#[cfg(not(target_feature = "avx"))] #[cfg(not(target_feature = "avx"))]
compile_error!("This library assumes availability of AVX and must be compiled with -C target-feature=+sse2,+avx,+fma,+avx2"); compile_error!("This library assumes availability of AVX and must be compiled with -C target-feature=+sse2,+avx,+fma,+avx2");
use mimalloc::MiMalloc;
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
pub fn main() -> Result<(), Box<dyn std::error::Error>> { pub fn main() -> Result<(), Box<dyn std::error::Error>> {
rllama::rllama_main::main() rllama::rllama_main::main()
} }

@ -1,18 +1,23 @@
use crate::embedding::Embedding; use crate::embedding::Embedding;
use crate::semaphore::Semaphore;
#[cfg(feature = "opencl")] #[cfg(feature = "opencl")]
use crate::tensor_opencl_support::OpenCL; use crate::tensor_opencl_support::OpenCL;
use crate::token_sampler::TokenSampler; use crate::token_sampler::TokenSampler;
use crate::tokenizer::{TokenId, Tokenizer}; use crate::tokenizer::{TokenId, Tokenizer};
use crate::transformer::{DataSettings, Transformer}; use crate::transformer::{DataSettings, Transformer, TransformerCaches};
use crate::unpickler; use crate::unpickler;
use crate::unpickler::Value; use crate::unpickler::Value;
use clap::Parser; use clap::Parser;
use colored::Colorize; use colored::Colorize;
#[cfg(feature = "server")]
use rocket::{response::status, response::Stream, Data, State};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::{Arc, RwLock};
#[derive(Parser)] #[derive(Parser, Clone)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
struct Cli { struct Cli {
#[arg(long)] #[arg(long)]
@ -51,6 +56,24 @@ struct Cli {
#[cfg(feature = "opencl")] #[cfg(feature = "opencl")]
#[arg(long)] #[arg(long)]
opencl_device: Option<usize>, opencl_device: Option<usize>,
#[arg(long, action)]
inference_server: bool,
#[arg(long)]
inference_server_port: Option<u16>,
#[arg(long)]
inference_server_host: Option<String>,
#[arg(long)]
inference_server_max_concurrent_inferences: Option<usize>,
#[arg(long)]
inference_server_api_path: Option<String>,
#[arg(long)]
inference_server_prompt_cache_size: Option<usize>,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
@ -65,9 +88,15 @@ struct ModelParams {
pub fn main() -> Result<(), Box<dyn std::error::Error>> { pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse(); let cli = Cli::parse();
let model_path = cli.model_path; let model_path = cli.model_path.clone();
let tokenizer_path = cli.tokenizer_path; let tokenizer_path = cli.tokenizer_path.clone();
let param_path = cli.param_path; let param_path = cli.param_path.clone();
#[cfg(not(feature = "server"))]
if cli.inference_server {
eprintln!("Inference server is not enabled in this build.");
return Err("Inference server is not enabled in this build.".into());
}
let max_threads: usize = match cli.max_threads { let max_threads: usize = match cli.max_threads {
None => rayon::current_num_threads(), None => rayon::current_num_threads(),
@ -91,6 +120,15 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
colored::control::SHOULD_COLORIZE.set_override(false); colored::control::SHOULD_COLORIZE.set_override(false);
} }
// Custom println-like macro that respects be_quiet
macro_rules! pln {
($($arg:tt)*) => {
if !be_quiet {
std::println!($($arg)*);
}
};
}
#[cfg(feature = "opencl")] #[cfg(feature = "opencl")]
let opencl: Option<OpenCL> = { let opencl: Option<OpenCL> = {
let opencl_device = cli.opencl_device.unwrap_or(0); let opencl_device = cli.opencl_device.unwrap_or(0);
@ -107,15 +145,6 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
}; };
// Custom println-like macro that respects be_quiet
macro_rules! pln {
($($arg:tt)*) => {
if !be_quiet {
std::println!($($arg)*);
}
};
}
// Read ModelParams from param_path, we expect it to be JSON // Read ModelParams from param_path, we expect it to be JSON
let mut fs = std::fs::File::open(&param_path)?; let mut fs = std::fs::File::open(&param_path)?;
let mut bs = Vec::new(); let mut bs = Vec::new();
@ -124,12 +153,12 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let params: ModelParams = serde_json::from_slice(&bs)?; let params: ModelParams = serde_json::from_slice(&bs)?;
pln!("Loaded model parameters from {}.", param_path); pln!("Loaded model parameters from {}.", param_path);
let prompt: String = match (cli.prompt, cli.prompt_file) { let prompt: String = match (&cli.prompt, &cli.prompt_file) {
(Some(prompt), None) => { (Some(ref prompt), None) => {
pln!("Using prompt: {}", prompt); pln!("Using prompt: {}", prompt);
prompt prompt.clone()
} }
(None, Some(prompt_file)) => { (None, Some(ref prompt_file)) => {
pln!("Using prompt file: {}", prompt_file); pln!("Using prompt file: {}", prompt_file);
let mut fs = std::fs::File::open(prompt_file)?; let mut fs = std::fs::File::open(prompt_file)?;
let mut bs = Vec::new(); let mut bs = Vec::new();
@ -138,9 +167,13 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
String::from_utf8(bs)? String::from_utf8(bs)?
} }
_ => { _ => {
if cli.inference_server {
"".to_string()
} else {
eprintln!("Please provide either a prompt or a prompt file."); eprintln!("Please provide either a prompt or a prompt file.");
return Err("Please provide either a prompt or a prompt file.".into()); return Err("Please provide either a prompt or a prompt file.".into());
} }
}
}; };
pln!("Starting up. Loading tokenizer from {}...", tokenizer_path); pln!("Starting up. Loading tokenizer from {}...", tokenizer_path);
@ -212,13 +245,445 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
)?; )?;
pln!("All is loaded. Starting inference."); pln!("All is loaded. Starting inference.");
let tr: Arc<Transformer> = Arc::new(tr);
let tok: Arc<Tokenizer> = Arc::new(tok);
if cli.inference_server {
#[cfg(feature = "server")]
{
server_inference(cli, tr, tok, be_quiet, max_seq_len, params, max_threads)
}
#[cfg(not(feature = "server"))]
{
eprintln!("The inference server feature is not enabled.");
eprintln!("Please enable it with the \"inference-server\" feature.");
Err("The inference server feature is not enabled.".into())
}
} else {
command_line_inference(
cli.clone(),
tr.clone(),
tok.clone(),
prompt.clone(),
be_quiet,
max_seq_len,
params.clone(),
max_threads,
)
}
}
#[cfg(feature = "server")]
fn server_inference(
cli: Cli,
tr: Arc<Transformer>,
tok: Arc<Tokenizer>,
be_quiet: bool,
max_seq_len: usize,
_params: ModelParams,
_max_threads: usize,
) -> Result<(), Box<dyn std::error::Error>> {
macro_rules! pln {
($($arg:tt)*) => {
if !be_quiet {
std::println!($($arg)*);
}
};
}
let inference_server_port = cli.inference_server_port.unwrap_or(8080);
let inference_server_host = cli
.inference_server_host
.clone()
.unwrap_or("127.0.0.1".to_string());
let inference_server_max_concurrent_inferences =
cli.inference_server_max_concurrent_inferences.unwrap_or(5);
let inference_server_api_path = cli
.inference_server_api_path
.clone()
.unwrap_or("/rllama/v1/inference".to_string());
let inference_server_prompt_cache_size = cli.inference_server_prompt_cache_size.unwrap_or(50);
pln!(
"Maximum concurrent inferences: {}",
inference_server_max_concurrent_inferences
);
pln!("Prompt cache size: {}", inference_server_prompt_cache_size);
pln!("Maximum sequence length: {}", max_seq_len);
pln!(
"--- Starting HTTP server on {}:{}, answering to requests at {} ---",
inference_server_host,
inference_server_port,
inference_server_api_path
);
// If there are too many connections, they will hang until they get their turn.
// Maybe can later implement return 503 slow down or something similar.
let concurrent_requests_semaphore = Semaphore::new(inference_server_max_concurrent_inferences);
let rocket_conf = rocket::Config::build(rocket::config::Environment::Production)
.address(inference_server_host)
.port(inference_server_port)
.finalize()
.unwrap();
let app = rocket::custom(rocket_conf)
.mount(&inference_server_api_path, routes![handle_request])
.manage(InferenceServerState {
transformer: tr,
tokenizer: tok,
max_seq_len,
concurrent_requests_semaphore,
attention_cache_repository: Arc::new(RwLock::new(AttentionCacheRepository::empty(
inference_server_prompt_cache_size,
))),
});
app.launch();
panic!("Starting web server failed.");
}
fn is_false(b: &bool) -> bool {
!b
}
#[derive(Serialize, Deserialize, Clone, Debug)]
struct InferenceRequest {
temperature: Option<f32>,
top_k: Option<usize>,
top_p: Option<f32>,
repetition_penalty: Option<f32>,
max_seq_len: Option<usize>,
max_new_tokens: Option<usize>,
no_token_sampling: Option<bool>,
stop_at_end_token: Option<bool>,
prompt: String,
}
#[cfg(feature = "server")]
#[derive(Serialize, Deserialize, Clone, Debug)]
struct PredResult {
p: f32,
#[serde(skip_serializing_if = "is_false")]
is_end_token: bool,
}
#[cfg(feature = "server")]
struct GeneratingSession {
transformer: Arc<Transformer>,
token_sampler: TokenSampler,
tokenizer: Arc<Tokenizer>,
attention_cache_repository: Arc<RwLock<AttentionCacheRepository>>,
tokens: Vec<TokenId>,
req_max_seq_len: usize,
req_max_new_tokens: usize,
new_tokens_generated: usize,
prev_pos: usize,
no_token_sampling: bool,
stop_at_end_token: bool,
sent_stuff_last_time: bool,
result: Vec<u8>, // stores JSONL lines to be returned from read()
}
#[cfg(feature = "server")]
impl GeneratingSession {
fn read_from_result(&mut self, buf: &mut [u8]) -> usize {
if !self.result.is_empty() {
if self.result.len() <= buf.len() {
for idx in 0..self.result.len() {
buf[idx] = self.result[idx];
}
let len = self.result.len();
self.sent_stuff_last_time = true;
self.result.truncate(0);
return len;
} else {
for idx in 0..buf.len() {
buf[idx] = self.result[idx];
}
self.result = self.result[buf.len()..].to_vec();
self.sent_stuff_last_time = true;
return buf.len();
}
}
return 0;
}
}
#[cfg(feature = "server")]
impl Read for GeneratingSession {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
if self.sent_stuff_last_time && self.result.is_empty() {
// If we return WouldBlock every time we send something, it'll cause Rocket to
// flush available data.
self.sent_stuff_last_time = false;
return Err(std::io::Error::new(
std::io::ErrorKind::WouldBlock,
"WouldBlock",
));
}
// Push more data to the upstream if we have something stored.
let bytes_read = self.read_from_result(buf);
if bytes_read > 0 {
return Ok(bytes_read);
}
if self.tokens.len() >= self.req_max_seq_len {
return Ok(0);
}
if self.new_tokens_generated >= self.req_max_new_tokens {
return Ok(0);
}
let (mut caches, update_pos) = {
let mut ac = self.attention_cache_repository.write().unwrap();
match ac.get(&self.tokens) {
Some((c, pos)) if pos >= self.prev_pos => (c.true_clone(), pos),
Some(_) => {
std::mem::drop(ac);
(self.transformer.make_caches(), 0)
}
None => {
let caches = self.transformer.make_caches();
ac.put(self.tokens.clone(), caches.true_clone(), self.prev_pos);
(caches, self.prev_pos)
}
}
};
if update_pos > self.prev_pos {
self.prev_pos = update_pos;
}
assert!(self.result.is_empty());
let predictions =
self.transformer
.forward(&self.tokens[self.prev_pos..], self.prev_pos, &mut caches);
self.prev_pos = self.tokens.len();
let (highest_pred_idx, token_prob) =
self.token_sampler
.sample(&predictions, self.tokenizer.as_ref(), &self.tokens);
self.tokens.push(highest_pred_idx as TokenId);
{
let mut ac = self.attention_cache_repository.write().unwrap();
ac.put(self.tokens.clone(), caches, self.prev_pos);
}
self.new_tokens_generated += 1;
let token: &str = self.tokenizer.id_to_str(highest_pred_idx as TokenId);
let mut is_end_token: bool = false;
if token == "</s>" && self.stop_at_end_token {
self.new_tokens_generated = self.req_max_new_tokens;
is_end_token = true;
}
let mut result: BTreeMap<String, PredResult> = BTreeMap::new();
if self.no_token_sampling {
// All predictions go the line.
let probs = self
.token_sampler
.logits_to_btreemap(&predictions, self.tokenizer.as_ref());
for (k, v) in probs.into_iter() {
let mut is_end_token: bool = false;
if k == "</s>" {
is_end_token = true;
}
result.insert(
k,
PredResult {
p: v,
is_end_token: is_end_token,
},
);
}
// Convert to JSON
let json = serde_json::to_string(&result).unwrap();
self.result.extend(json.as_bytes());
self.result.push(b'\n');
return Ok(self.read_from_result(buf));
} else {
result.insert(
token.to_string(),
PredResult {
p: token_prob,
is_end_token,
},
);
let json = serde_json::to_string(&result).unwrap();
self.result.extend(json.as_bytes());
self.result.push(b'\n');
return Ok(self.read_from_result(buf));
}
}
}
#[cfg(feature = "server")]
struct AttentionCacheRepository {
caches: BTreeMap<Vec<TokenId>, (TransformerCaches, usize, std::time::Instant)>,
max_sz: usize,
}
#[cfg(feature = "server")]
impl AttentionCacheRepository {
fn empty(max_size: usize) -> AttentionCacheRepository {
AttentionCacheRepository {
caches: BTreeMap::new(),
max_sz: max_size,
}
}
/// Makes sure the cache repository is not larger than sz, evicts any older items.
fn limit_size(&mut self, sz: usize) {
if sz == 0 {
self.caches = BTreeMap::new();
return;
}
// Slow algorithm but I guess our cache will never be unimaginably large so it's probably
// fine
while self.caches.len() > sz {
let mut oldest_time = None;
let mut oldest_key: Option<&Vec<TokenId>> = None;
for (k, (_, _, time)) in self.caches.iter() {
if oldest_time.is_none() || time < oldest_time.unwrap() {
oldest_time = Some(time);
oldest_key = Some(k);
}
}
let oldest_key = oldest_key.unwrap().clone();
self.caches.remove(&oldest_key);
}
}
fn get(&self, tokens: &[TokenId]) -> Option<(&TransformerCaches, usize)> {
if let Some((caches, pos, _)) = self.caches.get(tokens) {
Some((caches, *pos))
} else {
None
}
}
fn put(&mut self, tokens: Vec<TokenId>, caches: TransformerCaches, prev_pos: usize) {
self.caches
.insert(tokens, (caches, prev_pos, std::time::Instant::now()));
self.limit_size(self.max_sz);
}
}
#[cfg(feature = "server")]
#[derive(Clone)]
struct InferenceServerState {
transformer: Arc<Transformer>,
tokenizer: Arc<Tokenizer>,
max_seq_len: usize,
concurrent_requests_semaphore: Semaphore,
attention_cache_repository: Arc<RwLock<AttentionCacheRepository>>,
}
#[cfg(feature = "server")]
#[post("/", data = "<input>")]
fn handle_request(
state: State<InferenceServerState>,
input: Data,
) -> Result<Stream<GeneratingSession>, status::BadRequest<String>> {
let _lock = state.concurrent_requests_semaphore.acquire();
let tr = state.transformer.clone();
let tok = state.tokenizer.clone();
let mut data = input.open();
let mut databuf: Vec<u8> = Vec::new();
data.read_to_end(&mut databuf).unwrap();
// Parse the JSON out of the request
let request: InferenceRequest = match serde_json::from_slice(&databuf) {
Err(_e) => {
return Err(status::BadRequest(Some("Invalid JSON.".to_string())));
}
Ok(ir) => ir,
};
let stop_at_end_token = request.stop_at_end_token.unwrap_or(true);
let temperature = request.temperature.unwrap_or(1.0);
let top_k = request.top_k.unwrap_or(20);
let top_p = request.top_p.unwrap_or(1.0);
let repetition_penalty = request.repetition_penalty.unwrap_or(1.0);
let mut req_max_seq_len = request.max_seq_len.unwrap_or(state.max_seq_len);
if req_max_seq_len > state.max_seq_len {
req_max_seq_len = state.max_seq_len;
}
let req_max_new_tokens = request.max_new_tokens.unwrap_or(20);
let no_token_sampling = request.no_token_sampling.unwrap_or(false);
let prompt = request.prompt;
if temperature.is_nan() {
return Err(status::BadRequest(Some(
"Temperature must be a number.".to_string(),
)));
}
if top_k == 0 {
return Err(status::BadRequest(Some(
"Top-k must be greater than 0.".to_string(),
)));
}
if top_p.is_nan() {
return Err(status::BadRequest(Some(
"Top-p must be a number.".to_string(),
)));
}
if repetition_penalty.is_nan() {
return Err(status::BadRequest(Some(
"Repetition penalty must be a number.".to_string(),
)));
}
let token_sampler = TokenSampler::new()
.temperature(temperature)
.top_p(top_p)
.top_k(top_k)
.repetition_penalty(repetition_penalty);
let toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
let gsession = GeneratingSession {
transformer: tr,
tokenizer: tok,
attention_cache_repository: state.attention_cache_repository.clone(),
token_sampler: token_sampler,
tokens: toks_id,
req_max_seq_len: req_max_seq_len,
req_max_new_tokens: req_max_new_tokens,
new_tokens_generated: 0,
prev_pos: 0,
no_token_sampling: no_token_sampling,
stop_at_end_token: stop_at_end_token,
sent_stuff_last_time: false,
result: Vec::new(),
};
return Ok(rocket::response::Stream::chunked(gsession, 1024));
}
fn command_line_inference(
cli: Cli,
tr: Arc<Transformer>,
tok: Arc<Tokenizer>,
prompt: String,
be_quiet: bool,
max_seq_len: usize,
params: ModelParams,
max_threads: usize,
) -> Result<(), Box<dyn std::error::Error>> {
// Custom println-like macro that respects be_quiet
macro_rules! pln {
($($arg:tt)*) => {
if !be_quiet {
std::println!($($arg)*);
}
};
}
let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone()); let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
let mut prev_pos = 0; let mut prev_pos = 0;
let mut token_sampler = TokenSampler::new() let mut token_sampler = TokenSampler::new()
.temperature(0.8) .temperature(1.0)
.top_p(0.9) .top_p(1.0)
.top_k(50) .top_k(20)
.repetition_penalty(0.8); .repetition_penalty(1.0);
if let Some(temperature) = cli.temperature { if let Some(temperature) = cli.temperature {
token_sampler = token_sampler.temperature(temperature); token_sampler = token_sampler.temperature(temperature);

@ -23,6 +23,7 @@ use crate::tensor_opencl_support::{OpenCL, OpenCLError, OpenCLEvent, OpenCLTenso
use crate::unpickler; use crate::unpickler;
use crate::unpickler::UnpicklingError; use crate::unpickler::UnpicklingError;
use half::f16; use half::f16;
use lazy_static::lazy_static;
use rand::Rng; use rand::Rng;
use rayon::prelude::*; use rayon::prelude::*;
use std::alloc::Layout; use std::alloc::Layout;
@ -123,12 +124,21 @@ impl Clone for Tensor {
} }
} }
// Tracks how many bytes are allocated for tensors globally on CPU.
// I've used this to debug memory leaks and monitor memory usage.
lazy_static! {
static ref TENSORS_BYTES_ALLOCATED: std::sync::atomic::AtomicUsize =
std::sync::atomic::AtomicUsize::new(0);
}
impl Drop for Tensor { impl Drop for Tensor {
fn drop(&mut self) { fn drop(&mut self) {
#[cfg(feature = "opencl")] #[cfg(feature = "opencl")]
self.process_waiting_for_data_mut(); self.process_waiting_for_data_mut();
unsafe { unsafe {
if !self.data.is_null() { if !self.data.is_null() {
TENSORS_BYTES_ALLOCATED
.fetch_sub(self.layout.size(), std::sync::atomic::Ordering::Relaxed);
std::alloc::dealloc(self.data, self.layout); std::alloc::dealloc(self.data, self.layout);
} }
} }
@ -342,6 +352,7 @@ impl Tensor {
if data.is_null() { if data.is_null() {
panic!("Failed to allocate tensor"); panic!("Failed to allocate tensor");
} }
TENSORS_BYTES_ALLOCATED.fetch_add(layout.size(), std::sync::atomic::Ordering::Relaxed);
// Even though we are uninitialized, we should zero out the extra space between the // Even though we are uninitialized, we should zero out the extra space between the
// columns. // columns.
// Otherwise there might be problems later as other operations assume it is zeroed. // Otherwise there might be problems later as other operations assume it is zeroed.
@ -1384,12 +1395,14 @@ impl Tensor {
as *const I16x8, as *const I16x8,
), ),
) )
} else { } else if row < nrows {
( (
load_i16x8(ptr.add(row * cols_capacity + column) load_i16x8(ptr.add(row * cols_capacity + column)
as *const I16x8), as *const I16x8),
i16x8_zero(), i16x8_zero(),
) )
} else {
(i16x8_zero(), i16x8_zero())
}; };
let left: F32x8 = i16x8_as_f16_to_f32x8(left); let left: F32x8 = i16x8_as_f16_to_f32x8(left);
let right: F32x8 = i16x8_as_f16_to_f32x8(right); let right: F32x8 = i16x8_as_f16_to_f32x8(right);
@ -1840,6 +1853,7 @@ impl Tensor {
if data.is_null() { if data.is_null() {
panic!("Failed to allocate tensor"); panic!("Failed to allocate tensor");
} }
TENSORS_BYTES_ALLOCATED.fetch_add(layout.size(), std::sync::atomic::Ordering::Relaxed);
Self { Self {
data, data,
#[cfg(feature = "opencl")] #[cfg(feature = "opencl")]
@ -2005,6 +2019,7 @@ impl Tensor {
if data.is_null() { if data.is_null() {
panic!("to_cpu_inplace: Failed to allocate tensor"); panic!("to_cpu_inplace: Failed to allocate tensor");
} }
TENSORS_BYTES_ALLOCATED.fetch_add(layout.size(), std::sync::atomic::Ordering::Relaxed);
let ev = od.as_mut().unwrap().data_u16_from_gpu(data as *mut u16)?; let ev = od.as_mut().unwrap().data_u16_from_gpu(data as *mut u16)?;
self.data = data as *mut u16 as *mut u8; self.data = data as *mut u16 as *mut u8;
self.waiting_for_data = Some(ev); self.waiting_for_data = Some(ev);

@ -65,6 +65,21 @@ impl TokenSampler {
} }
} }
pub fn logits_to_btreemap(
&self,
logits: &Tensor,
tokenizer: &Tokenizer,
) -> BTreeMap<String, f32> {
let mut result = BTreeMap::new();
for token_idx in 0..logits.rows() {
result.insert(
tokenizer.id_to_str(token_idx as TokenId).to_string(),
logits.get_f32(token_idx, 0),
);
}
result
}
pub fn sample( pub fn sample(
&self, &self,
logits: &Tensor, logits: &Tensor,

@ -28,6 +28,8 @@ pub struct Transformer {
output: Tensor, output: Tensor,
layers: Vec<TransformerBlock>, layers: Vec<TransformerBlock>,
data_settings: DataSettings,
} }
// Clone is cheap // Clone is cheap
@ -94,25 +96,59 @@ pub struct TransformerBlock {
pub struct AttentionCache { pub struct AttentionCache {
cache_k: Vec<Arc<RwLock<Tensor>>>, cache_k: Vec<Arc<RwLock<Tensor>>>,
cache_v: Vec<Arc<RwLock<Tensor>>>, cache_v: Vec<Arc<RwLock<Tensor>>>,
data_settings: DataSettings,
} }
impl AttentionCache { impl AttentionCache {
fn new(max_seq_len: usize, n_local_heads: usize, head_dim: usize) -> Self { fn new(
max_seq_len: usize,
n_local_heads: usize,
head_dim: usize,
data_settings: &DataSettings,
) -> Self {
let mut cache_k = Vec::with_capacity(n_local_heads); let mut cache_k = Vec::with_capacity(n_local_heads);
let mut cache_v = Vec::with_capacity(n_local_heads); let mut cache_v = Vec::with_capacity(n_local_heads);
let dtype = if data_settings.force_f16 {
TensorDType::Float16
} else {
TensorDType::Float32
};
for _ in 0..n_local_heads { for _ in 0..n_local_heads {
cache_k.push(Arc::new(RwLock::new(Tensor::zeros( cache_k.push(Arc::new(RwLock::new(Tensor::zeros(
head_dim as i64, head_dim as i64,
max_seq_len as i64, max_seq_len as i64,
TensorDType::Float32, dtype,
)))); ))));
cache_v.push(Arc::new(RwLock::new(Tensor::zeros( cache_v.push(Arc::new(RwLock::new(Tensor::zeros(
head_dim as i64, head_dim as i64,
max_seq_len as i64, max_seq_len as i64,
TensorDType::Float32, dtype,
)))); ))));
} }
AttentionCache { cache_k, cache_v } AttentionCache {
cache_k,
cache_v,
data_settings: data_settings.clone(),
}
}
/// Cloning AttentionCache normally just makes new references to the same cache.
/// This creates a true clone with copied tensors.
fn true_clone(&self) -> AttentionCache {
let mut cache_k = Vec::with_capacity(self.cache_k.len());
let mut cache_v = Vec::with_capacity(self.cache_v.len());
for idx in 0..self.cache_k.len() {
let old_k = self.cache_k[idx].read().unwrap();
cache_k.push(Arc::new(RwLock::new(old_k.clone())));
let old_v = self.cache_v[idx].read().unwrap();
cache_v.push(Arc::new(RwLock::new(old_v.clone())));
}
AttentionCache {
cache_k,
cache_v,
data_settings: self.data_settings.clone(),
}
} }
fn shift_left(&mut self, shifts: usize) { fn shift_left(&mut self, shifts: usize) {
@ -141,6 +177,14 @@ impl TransformerCaches {
layer.shift_left(shifts); layer.shift_left(shifts);
} }
} }
pub fn true_clone(&self) -> TransformerCaches {
let mut layer_caches = Vec::with_capacity(self.layer_caches.len());
for layer in self.layer_caches.iter() {
layer_caches.push(layer.true_clone());
}
TransformerCaches { layer_caches }
}
} }
pub struct RMSNorm { pub struct RMSNorm {
@ -218,6 +262,7 @@ impl Transformer {
Ok(Transformer { Ok(Transformer {
freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len, 10000.0), freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len, 10000.0),
data_settings: data_settings.clone(),
emb, emb,
dim, dim,
n_layers, n_layers,
@ -240,6 +285,7 @@ impl Transformer {
self.max_seq_len, self.max_seq_len,
self.n_local_heads, self.n_local_heads,
self.head_dim, self.head_dim,
&self.data_settings,
)); ));
} }
TransformerCaches { TransformerCaches {
@ -664,6 +710,9 @@ impl Attention {
let keys = cache_k.clip_cols(start_pos + seq_len as usize); let keys = cache_k.clip_cols(start_pos + seq_len as usize);
let values = cache_v.clip_cols(start_pos + seq_len as usize); let values = cache_v.clip_cols(start_pos + seq_len as usize);
let keys = keys.into_same_type(&xq_row);
let values = values.into_same_type(&xq_row);
let m = xq_row let m = xq_row
.matrix_mul(&keys) .matrix_mul(&keys)
.scalar_multiply_f32(1.0 / (self.head_dim as f32).sqrt()); .scalar_multiply_f32(1.0 / (self.head_dim as f32).sqrt());

Loading…
Cancel
Save