commit 3b8f904f13f3051274479e38a198219367882de3 Author: Mikko Juola Date: Sat Mar 11 00:31:40 2023 -0800 First commit. LLaMA works now. It is not pretty but it does generate text from prompts. Yay. diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..4e256e8 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,1155 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "aho-corasick" +version = "0.7.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" +dependencies = [ + "memchr", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anyhow" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bumpalo" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cc" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "ciborium" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c137568cc60b904a7724001b35ce2630fd00d5d84805fbb608ab89509d788f" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346de753af073cc87b52b2083a506b38ac176a44cfb05497b622e27be899b369" + +[[package]] +name = "ciborium-ll" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213030a2b5a4e0c0892b6652260cf6ccac84827b83a85a534e178e3906c4cf1b" +dependencies = [ + "ciborium-io", + "half 1.8.2", +] + +[[package]] +name = "clap" +version = "3.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71655c45cb9845d3270c9d6df84ebe72b4dad3c2ba3f7023ad47c144e4e473a5" +dependencies = [ + "bitflags", + "clap_lex 0.2.4", + "indexmap", + "textwrap", +] + +[[package]] +name = "clap" +version = "4.1.8" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5" +dependencies = [ + "bitflags", + "clap_derive", + "clap_lex 0.3.2", + "is-terminal", + "once_cell", + "strsim", + "termcolor", +] + +[[package]] +name = "clap_derive" +version = "4.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0" +dependencies = [ + "heck", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5" +dependencies = [ + "os_str_bytes", +] + +[[package]] +name = "clap_lex" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350b9cf31731f9957399229e9b2adc51eeabdfbe9d71d9a0552275fd12710d09" +dependencies = [ + "os_str_bytes", +] + +[[package]] +name = "console" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d79fbe8970a77e3e34151cc13d3b3e248aa0faaecb9f6091fa07ebefe5ad60" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.42.0", +] + +[[package]] +name = "criterion" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c76e09c1aae2bc52b3d2f29e13c6572553b30c4aa1b8a49fd70de6412654cb" +dependencies = [ + "anes", + "atty", + "cast", + "ciborium", + "clap 3.2.23", + "criterion-plot", + "itertools", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf2b3e8478797446514c91ef04bafcb59faba183e621ad488df88983cc14128c" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "either" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" + +[[package]] +name = "embedded-profiling" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23294c92851bb57c68a88c4f1546c848eefd3eaeeff5f7a79922abf53afa22b2" +dependencies = [ + "fugit", +] + +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + +[[package]] +name = "errno" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +dependencies = [ + "errno-dragonfly", + "libc", + "winapi", +] + +[[package]] +name = "errno-dragonfly" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "fastrand" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" +dependencies = [ + "instant", +] + +[[package]] +name = "fugit" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ab17bb279def6720d058cb6c052249938e7f99260ab534879281a95367a87e5" +dependencies = [ + "gcd", +] + +[[package]] +name = "gcd" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d758ba1b47b00caf47f24925c0074ecb20d6dfcffe7f6d53395c0465674841a" + +[[package]] +name = "getrandom" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + +[[package]] +name = "half" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0" +dependencies = [ + "crunchy", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + +[[package]] +name = "hermit-abi" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" +dependencies = [ + "libc", +] + +[[package]] +name = "hermit-abi" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" + +[[package]] +name = "indexmap" +version = "1.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885e79c1fc4b10f0e172c475f458b7f7b93061064d98c3293e98c5ba0c8b399" +dependencies = [ + "autocfg", + "hashbrown", +] + +[[package]] +name = "indicatif" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cef509aa9bc73864d6756f0d34d35504af3cf0844373afe9b8669a5b8005a729" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "io-lifetimes" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1abeb7a0dd0f8181267ff8adc397075586500b81b28a73e8a0208b00fc170fb3" +dependencies = [ + "libc", + 
"windows-sys 0.45.0", +] + +[[package]] +name = "is-terminal" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b6b32576413a8e69b90e952e4a026476040d81017b80445deda5f2d3921857" +dependencies = [ + "hermit-abi 0.3.1", + "io-lifetimes", + "rustix", + "windows-sys 0.45.0", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" + +[[package]] +name = "js-sys" +version = "0.3.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "libc" +version = "0.2.139" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" + +[[package]] +name = "linux-raw-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" + +[[package]] +name = "log" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "memchr" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" + +[[package]] +name = "memoffset" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num-complex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num_cpus" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +dependencies = [ + "hermit-abi 0.2.6", + "libc", +] + +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + +[[package]] +name = "once_cell" +version = "1.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" + +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + +[[package]] +name = "os_str_bytes" +version = "6.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" + +[[package]] +name = "plotters" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2538b639e642295546c50fcd545198c9d64ee2a38620a628724a3b266d5fbf97" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "193228616381fecdc1224c62e96946dfbc73ff4384fba576e052ff8c1bea8142" + +[[package]] +name = "plotters-svg" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9a81d2759aae1dae668f783c308bc5c8ebd191ff4184aaa1b37f65a6ae5a56f" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "portable-atomic" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + +[[package]] +name = "proc-macro2" +version = "1.0.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "protobuf" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b55bad9126f378a853655831eb7363b7b01b81d19f8cb1218861086ca4a1a61e" +dependencies = [ + "once_cell", + "protobuf-support", + "thiserror", +] + +[[package]] +name = "protobuf-codegen" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd418ac3c91caa4032d37cb80ff0d44e2ebe637b2fb243b6234bf89cdac4901" +dependencies = [ + "anyhow", + "once_cell", + "protobuf", + "protobuf-parse", + "regex", + "tempfile", + "thiserror", +] + +[[package]] +name = "protobuf-parse" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d39b14605eaa1f6a340aec7f320b34064feb26c93aec35d6a9a2272a8ddfa49" +dependencies = [ + "anyhow", + "indexmap", + "log", + "protobuf", + "protobuf-support", + "tempfile", + "thiserror", + "which", +] + +[[package]] +name = "protobuf-support" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d4d7b8601c814cfb36bcebb79f0e61e45e1e93640cf778837833bbed05c372" +dependencies = [ + "thiserror", +] + +[[package]] +name = "quote" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" 
+dependencies = [ + "getrandom", +] + +[[package]] +name = "rayon" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "num_cpus", +] + +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.6.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" + +[[package]] +name = "rllama" +version = "0.1.0" +dependencies = [ + "approx", + "clap 4.1.8", + "criterion", + "embedded-profiling", + "half 2.2.1", + "indicatif", + "num-complex", + "protobuf", + "protobuf-codegen", + "protobuf-parse", + "rand", + "rayon", + "thiserror", +] + +[[package]] +name = "rustix" +version = "0.36.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43abb88211988493c1abb44a70efa56ff0ce98f233b7b276146f1f3f7ba9644" +dependencies = [ + "bitflags", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys", + "windows-sys 0.45.0", +] + +[[package]] +name = "ryu" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + +[[package]] +name = "serde" +version = "1.0.152" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.152" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tempfile" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af18f7ae1acd354b992402e9ec5864359d693cd8a79dcbef59f76891701c1e95" +dependencies = [ + "cfg-if", + "fastrand", + "redox_syscall", + "rustix", 
+ "windows-sys 0.42.0", +] + +[[package]] +name = "termcolor" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "textwrap" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" + +[[package]] +name = "thiserror" +version = "1.0.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "unicode-ident" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" + +[[package]] +name = "unicode-width" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "walkdir" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" +dependencies = [ + "same-file", + "winapi", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31f8dcbc21f30d9b8f2ea926ecb58f6b91192c17e9d33594b3df58b2007ca53b" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" + +[[package]] +name = "web-sys" +version = "0.3.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e33b99f4b23ba3eec1a53ac264e35a755f00e966e0065077d6027c0f575b0b97" +dependencies = [ + "js-sys", + "wasm-bindgen", +] 
+ +[[package]] +name = "which" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2441c784c52b289a054b7201fc93253e288f094e2f4be9058343127c4226a269" +dependencies = [ + "either", + "libc", + "once_cell", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..9f24c01 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "rllama" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "rllama" +path = 
"src/main.rs" + +[dependencies] +protobuf = "3.2" +thiserror = "1.0" +half = "2.2" +num-complex = "0.4" +embedded-profiling = "0.3" +rand = "0.8" +approx = "0.5" +rayon = "1.7" +clap = { version = "4.1", features = ["derive"] } +indicatif = "0.17" + +# We need protobuf compiler +[build-dependencies] +protobuf-codegen = "3.2" +protobuf-parse = "3.2" + +[dev-dependencies] +criterion = "0.4" + +[profile.release] +debug = true + +[[bench]] +path = "src/benches/benchmark.rs" +name = "benchmark" +harness = false diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2beb9e1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,662 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. 
+ + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". 
"Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. 
+ + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. 
Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. 
+ + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. 
This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. 
+ + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. 
If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. 
If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. 
+ + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. 
For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. 
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. 
+ + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. 
+ + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code.
There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +<https://www.gnu.org/licenses/>. + diff --git a/LICENSE.third_parties b/LICENSE.third_parties new file mode 100644 index 0000000..6a95af5 --- /dev/null +++ b/LICENSE.third_parties @@ -0,0 +1,208 @@ +proto/ directory contains a protobuf file from Google's +https://github.com/google/sentencepiece repository. + +Here is their license: (note rllama as a whole is AGPL3) +----- + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files.
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..c3e2f31 --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ +# AdeonLLaMA + +This is my attempt at getting the LLaMA language model working in a pure Rust +CPU implementation. + +As of this writing, it can run LLaMA-7B at roughly 1 token per second, using +something like 1.5 threads because I haven't yet properly figured out how to +multithread this. + +It uses AVX2 intrinsics to speed itself up. + +# How to run + +You will need the LLaMA-7B weights first. Refer to https://github.com/facebookresearch/llama/ + +Once you have the 7B weights and the `tokenizer.model` that comes with them, you can make +it generate tokens: + +```shell +cargo run --release -- --tokenizer-model /path/to/tokenizer.model --model-path /path/to/LLaMA/7B +``` diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..6309499 --- /dev/null +++ b/build.rs @@ -0,0 +1,9 @@ +fn main() { + protobuf_codegen::Codegen::new() + .pure() + .out_dir("src/protomodels") + .include("proto") + .input("proto/sentencepiece_model.proto") + .run() + .unwrap(); +} diff --git a/proto/sentencepiece_model.proto b/proto/sentencepiece_model.proto new file mode 100644 index 0000000..b6c1224 --- /dev/null +++ b/proto/sentencepiece_model.proto @@ -0,0 +1,321 @@ +// Copyright 2016 Google Inc.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.! + +syntax = "proto2"; + +// TODO(taku): Needs to use LITE RUNTIME in OSS release. +option optimize_for = LITE_RUNTIME; + +package sentencepiece; + +// TrainerSpec encodes various parameters for SentencePiece training. +// Next id: 53 +message TrainerSpec { + /////////////////////////////////////////////////////////////////// + // General parameters + // + // Input corpus files. + // Trainer accepts the following two formats: + // A) Monolingual: plain text, one sentence per line. + // B) Bilingual: TSV, source sentence <tab> target sentence + // When bilingual data is passed, shared vocabulary model is built. + // Note that the input file must be raw corpus, not a preprocessed corpus. + // Trainer only loads the first `input_sentence_size` sentences specified + // with this parameter. + repeated string input = 1; + + // Input corpus format: + // "text": one-sentence-per-line text format (default) + // "tsv": sentence <tab> freq + optional string input_format = 7; + + // Output model file prefix. + // <model_name>.model and <model_name>.vocab are generated. + optional string model_prefix = 2; + + // Model type. only have UNIGRAM now. + enum ModelType { + UNIGRAM = 1; // Unigram language model with dynamic algorithm + BPE = 2; // Byte Pair Encoding + WORD = 3; // Delimited by whitespace. + CHAR = 4; // tokenizes into character sequence + } + optional ModelType model_type = 3 [default = UNIGRAM]; + + // Vocabulary size. 8k is the default size.
+ optional int32 vocab_size = 4 [default = 8000]; + + // List of the languages this model can accept. + // Since the model is language-agnostic, this field is used as a reference. + repeated string accept_language = 5; + + // Size of self-test samples, which are encoded in the model file. + optional int32 self_test_sample_size = 6 [default = 0]; + + // Whether to use DP version of sentencepiece. Use it with TSV input format + // (requires precomputed word tab counts to work). + optional bool enable_differential_privacy = 50 [default = false]; + // Set these parameters if you need DP version of sentencepiece. + // std of noise to add. + optional float differential_privacy_noise_level = 51 [default = 0.0]; + // Clipping threshold to apply after adding noise. All the words with + // frequency less than this value are dropped. + optional uint64 differential_privacy_clipping_threshold = 52 [default = 0]; + + /////////////////////////////////////////////////////////////////// + // Training parameters. + // + // Uses characters which cover the corpus with the ratio of `chars_coverage`. + // This parameter determines the set of basic Alphabet of sentence piece. + // 1.0 - `chars_coverage` characters are treated as UNK. + // See also required_chars field. + optional float character_coverage = 10 [default = 0.9995]; + + // Maximum size of sentences the trainer loads from `input` parameter. + // Trainer simply loads the `input` files in sequence. + // It is better to shuffle the input corpus randomly. + optional uint64 input_sentence_size = 11 [default = 0]; + optional bool shuffle_input_sentence = 19 [default = true]; + + // Maximum size of sentences to make seed sentence pieces. + // Extended suffix array is constructed to extract frequent + // sub-strings from the corpus. This uses 20N working space, + // where N is the size of corpus. + optional int32 mining_sentence_size = 12 [deprecated = true]; + + // Maximum size of sentences to train sentence pieces. 
+ optional int32 training_sentence_size = 13 [deprecated = true]; + + // The size of seed sentencepieces. + // `seed_sentencepiece_size` must be larger than `vocab_size`. + optional int32 seed_sentencepiece_size = 14 [default = 1000000]; + + // In every EM sub-iterations, keeps top + // `shrinking_factor` * `current sentencepieces size` with respect to + // the loss of the sentence piece. This value should be smaller than 1.0. + optional float shrinking_factor = 15 [default = 0.75]; + + // The maximum sentence length in byte. The sentences with the length + // larger than `max_sentence_length` is simply ignored. + // Longer input tends to bring the following risks: + // * Overflow during EM training (unigram language model only) + // * Performance drop because of O(n log n) cost in BPE. + optional int32 max_sentence_length = 18 [default = 4192]; + + // Number of threads in the training. + optional int32 num_threads = 16 [default = 16]; + + // Number of EM sub iterations. + optional int32 num_sub_iterations = 17 [default = 2]; + + /////////////////////////////////////////////////////////////////// + // SentencePiece parameters which control the shapes of sentence piece. + // + // Maximum length of sentencepiece. + optional int32 max_sentencepiece_length = 20 [default = 16]; + + // Uses Unicode script to split sentence pieces. + // When `split_by_unicode_script` is true, we do not allow sentence piece to + // include multiple Unicode scripts, e.g. "F1" is not a valid piece. + // Exception: CJ characters (Hiragana/Katakana/Han) are all handled + // as one script type, since Japanese word can consist of multiple scripts. + // This exception is always applied regardless of the accept-language + // parameter. + optional bool split_by_unicode_script = 21 [default = true]; + + // When `split_by_number` is true, put a boundary between number and + // non-number transition. If we want to treat "F1" is one token, set this flag + // to be false. 
+ optional bool split_by_number = 23 [default = true]; + + // Use a white space to split sentence pieces. + // When `split_by_whitespace` is false, we may have the piece containing + // a white space in the middle. e.g., "in_the". + optional bool split_by_whitespace = 22 [default = true]; + + // Adds whitespace symbol (_) as a suffix instead of prefix. e.g., _hello => + // hello_. When `treat_whitespace_as_suffix` is true, + // NormalizerSpec::add_dummy_prefix will add the dummy whitespace to the end + // of sentence. + optional bool treat_whitespace_as_suffix = 24 [default = false]; + + // Allows pieces that only contain whitespaces instead of appearing only as + // prefix or suffix of other pieces. + optional bool allow_whitespace_only_pieces = 26 [default = false]; + + // Split all digits (0-9) into separate pieces. + optional bool split_digits = 25 [default = false]; + + /////////////////////////////////////////////////////////////////// + // Vocabulary management + // + // Defines control symbols used as an indicator to + // change the behavior of the decoder. <s> and </s> are pre-defined. + // We can use this field to encode various meta information, + // including language indicator in multilingual model. + // These symbols are not visible to users, but visible to + // the decoder. Note that when the input sentence contains control symbols, + // they are not treated as one token, but segmented into normal pieces. + // Control symbols must be inserted independently from the segmentation. + repeated string control_symbols = 30; + + // Defines user defined symbols. + // These symbols are added with extremely high score + // so they are always treated as one unique symbol in any context. + // Typical usage of user_defined_symbols is placeholder for named entities. + repeated string user_defined_symbols = 31; + + // Defines required characters. Each UTF8 character in this string is included + // in the character set regardless of character_coverage value.
Unlike + // user_defined_symbols, these characters have scores based on the frequency + // on input sentences, and the model can form subwords using characters + // in this field. + optional string required_chars = 36; + + // Decomposes unknown pieces into UTF-8 bytes. + optional bool byte_fallback = 35 [default = false]; + + // When creating the vocabulary file, defines whether or not to additionally + // output the score for each piece. + optional bool vocabulary_output_piece_score = 32 [default = true]; + + // `vocab_size` is treated as hard limit. Crash if + // the model can not produce the vocab of size `vocab_size`, + // When `hard_vocab_limit` is false, vocab_size is treated + // as soft limit. Note that when model_type=char, + // always assumes hard_vocab_limit = false. + optional bool hard_vocab_limit = 33 [default = true]; + + // use all symbols for vocab extraction. This flag is valid + // if model type is either CHAR or WORD + optional bool use_all_vocab = 34 [default = false]; + + /////////////////////////////////////////////////////////////////// + // Reserved special meta tokens. + // * -1 is not used. + // * unk_id must not be -1. + // Id must start with 0 and be contiguous. + optional int32 unk_id = 40 [default = 0]; // <unk> + optional int32 bos_id = 41 [default = 1]; // <s> + optional int32 eos_id = 42 [default = 2]; // </s> + optional int32 pad_id = 43 [default = -1]; // <pad> (padding) + optional string unk_piece = 45 [default = "<unk>"]; + optional string bos_piece = 46 [default = "<s>"]; + optional string eos_piece = 47 [default = "</s>"]; + optional string pad_piece = 48 [default = "<pad>"]; + + // Encodes <unk> into U+2047 (DOUBLE QUESTION MARK), + // since this character can be useful both for user and + // developer. We can easily figure out that <unk> is emitted. + optional string unk_surface = 44 [default = " \xE2\x81\x87 "]; + + // Increase bit depth to allow unigram model training on large + // (>10M sentences) corpora.
A Side-effect of enabling this flag + // is increased memory usage. + optional bool train_extremely_large_corpus = 49 [default = false]; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// NormalizerSpec encodes a various parameters for string normalizaiton +message NormalizerSpec { + // name of normalization rule. + optional string name = 1; + + // Pre-compiled normalization rule created by + // Builder::GetPrecompiledCharsMap() or Builder::CompileCharsMap() method. + // Usually this field is set by Builder::GetNormalizerSpec() method. + optional bytes precompiled_charsmap = 2; + + // Adds dummy whitespace at the beginning of text in order to + // treat "world" in "world" and "hello world" in the same way. + optional bool add_dummy_prefix = 3 [default = true]; + + // Removes leading, trailing, and duplicate internal whitespace. + optional bool remove_extra_whitespaces = 4 [default = true]; + + // Replaces whitespace with meta symbol. + // This field must be true to train sentence piece model. + optional bool escape_whitespaces = 5 [default = true]; + + // Custom normalization rule file in TSV format. + // https://github.com/google/sentencepiece/blob/master/doc/normalization.md + // This field is only used in SentencePieceTrainer::Train() method, which + // compiles the rule into the binary rule stored in `precompiled_charsmap`. + optional string normalization_rule_tsv = 6; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// Proto to store samples for self-testing. +message SelfTestData { + message Sample { + optional string input = 1; + optional string expected = 2; + } + repeated Sample samples = 1; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// ModelProto stores model parameters. 
+// SentencePieceProcessor is supposed to be self-contained. +// All settings/parameters which may change the behavior must be encoded +// in ModelProto. +message ModelProto { + message SentencePiece { + enum Type { + NORMAL = 1; // normal symbol + UNKNOWN = 2; // unknown symbol. only <unk> for now. + CONTROL = 3; // control symbols. </s>, <s>, <2ja> etc. + USER_DEFINED = 4; // user defined symbols. + // Typical usage of USER_DEFINED symbol + // is placeholder. + BYTE = 6; // byte symbols. Used when `byte_fallback` is true. + UNUSED = 5; // this piece is not used. + } + optional string piece = 1; // piece must not be empty. + optional float score = 2; + optional Type type = 3 [default = NORMAL]; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; + } + + // Sentence pieces with scores. + repeated SentencePiece pieces = 1; + + // Spec used to generate this model file. + optional TrainerSpec trainer_spec = 2; + + // Spec for text normalization. + optional NormalizerSpec normalizer_spec = 3; + + // Stores sample input and its expected segmentation to verify the model. + optional SelfTestData self_test_data = 4; + + // Spec for text de-normalization. + optional NormalizerSpec denormalizer_spec = 5; + + // Customized extensions: the range of field numbers + // are open to third-party extensions.
+ extensions 200 to max; +} diff --git a/src/benches/benchmark.rs b/src/benches/benchmark.rs new file mode 100644 index 0000000..a92c472 --- /dev/null +++ b/src/benches/benchmark.rs @@ -0,0 +1,85 @@ +extern crate rllama; + +use rllama::tensor::{Tensor, TensorDType}; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +pub fn tensor_benchmarks(c: &mut Criterion) { + let orig16_1 = Tensor::full(16, 32, TensorDType::Float16, 3.0); + let orig16_2 = Tensor::full(32, 512, TensorDType::Float16, -1.33); + + let orig32_1 = Tensor::full(16, 32, TensorDType::Float32, 3.0); + let orig32_2 = Tensor::full(32, 512, TensorDType::Float32, -1.33); + let orig32_2_transposed = orig32_2.transpose(); + + let mut result_16 = Tensor::zeros(16, 512, TensorDType::Float16); + let mut result_32 = Tensor::zeros(16, 512, TensorDType::Float32); + + let orig_84096_1 = Tensor::zeros(8, 4096, TensorDType::Float32); + let orig_84096_2 = Tensor::zeros(4096, 4096, TensorDType::Float32); + let mut result_84096 = Tensor::zeros(8, 4096, TensorDType::Float32); + + c.bench_function( + "matrix multiplication 8x4096 @ 4096x4096 f32 in-place", + |b| { + b.iter(|| { + let _ = result_84096 + .matrix_mul_inplace(black_box(&orig_84096_1), black_box(&orig_84096_2)); + }) + }, + ); + + c.bench_function( + "matrix multiplication 8x4096 @ 4096x4096 f32 in-place, transposed", + |b| { + b.iter(|| { + let _ = result_84096.matrix_mul_inplace_transposed( + black_box(&orig_84096_1), + black_box(&orig_84096_2), + ); + }) + }, + ); + + c.bench_function("matrix multiplication f32 not in-place", |b| { + b.iter(|| { + let _ = black_box(&orig32_1).matrix_mul(black_box(&orig32_2)); + }) + }); + c.bench_function("matrix multiplication f32 naive", |b| { + b.iter(|| { + let _ = black_box(&orig32_1).matrix_mul_naive(black_box(&orig32_2)); + }) + }); + c.bench_function("matrix multiplication f16 not in-place", |b| { + b.iter(|| { + let _ = black_box(&orig16_1).matrix_mul(black_box(&orig16_2)); + }) + }); + 
c.bench_function("matrix multiplication f16 naive", |b| { + b.iter(|| { + let _ = black_box(&orig16_1).matrix_mul_naive(black_box(&orig16_2)); + }) + }); + c.bench_function("matrix multiplication f16 in-place", |b| { + b.iter(|| { + let _ = result_16.matrix_mul_inplace(black_box(&orig16_1), black_box(&orig16_2)); + }) + }); + c.bench_function("matrix multiplication f32 in-place", |b| { + b.iter(|| { + let _ = result_32.matrix_mul_inplace(black_box(&orig32_1), black_box(&orig32_2)); + }) + }); + c.bench_function("matrix multiplication f32 in-place, transposed", |b| { + b.iter(|| { + let _ = result_32.matrix_mul_inplace_transposed( + black_box(&orig32_1), + black_box(&orig32_2_transposed), + ); + }) + }); +} + +criterion_group!(benches, tensor_benchmarks); +criterion_main!(benches); diff --git a/src/embedding.rs b/src/embedding.rs new file mode 100644 index 0000000..dccbf6e --- /dev/null +++ b/src/embedding.rs @@ -0,0 +1,45 @@ +use crate::tensor::Tensor; +use crate::unpickler; +use crate::unpickler::*; +use std::collections::BTreeMap; +use std::path::Path; + +pub struct Embedding { + wgts: BTreeMap, +} + +impl Embedding { + pub fn from_unpickled>( + unpickled: &unpickler::Value, + data_dir: P, + ) -> Result { + let data_dir: &Path = data_dir.as_ref(); + + let val = match unpickled.get_str_key("tok_embeddings.weight") { + Some(val) => val, + None => { + return Err(UnpicklingError::MissingField( + "tok_embeddings.weight".to_string(), + )) + } + }; + let tensor = val + .to_tensor_builder() + .ok_or(UnpicklingError::InvalidTensorData)?; + let tensor = tensor.load(data_dir)?; + + let num_embeddings = tensor.rows(); + + let mut table: BTreeMap = BTreeMap::new(); + for key in 0..num_embeddings { + let row = tensor.row(key); + table.insert(key as usize, row); + } + + Ok(Self { wgts: table }) + } + + pub fn get_embedding(&self, idx: usize) -> &Tensor { + self.wgts.get(&idx).unwrap() + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..4ce1c6b --- 
/dev/null +++ b/src/lib.rs @@ -0,0 +1,10 @@ +#![feature(stdsimd)] + +pub mod embedding; +pub mod protomodels; +pub mod rllama_main; +pub mod tensor; +pub mod token_sampler; +pub mod tokenizer; +pub mod transformer; +pub mod unpickler; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..1fe1682 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,3 @@ +pub fn main() -> Result<(), Box> { + rllama::rllama_main::main() +} diff --git a/src/protomodels/mod.rs b/src/protomodels/mod.rs new file mode 100644 index 0000000..ad74275 --- /dev/null +++ b/src/protomodels/mod.rs @@ -0,0 +1,3 @@ +// @generated + +pub mod sentencepiece_model; diff --git a/src/protomodels/sentencepiece_model.rs b/src/protomodels/sentencepiece_model.rs new file mode 100644 index 0000000..294fb20 --- /dev/null +++ b/src/protomodels/sentencepiece_model.rs @@ -0,0 +1,2643 @@ +// This file is generated by rust-protobuf 3.2.0. Do not edit +// .proto file is parsed by pure +// @generated + +// https://github.com/rust-lang/rust-clippy/issues/702 +#![allow(unknown_lints)] +#![allow(clippy::all)] + +#![allow(unused_attributes)] +#![cfg_attr(rustfmt, rustfmt::skip)] + +#![allow(box_pointers)] +#![allow(dead_code)] +#![allow(missing_docs)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(non_upper_case_globals)] +#![allow(trivial_casts)] +#![allow(unused_results)] +#![allow(unused_mut)] + +//! Generated file from `sentencepiece_model.proto` +// Generated for lite runtime + +/// Generated files are compatible only with the same version +/// of protobuf runtime. 
+const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_3_2_0; + +#[derive(PartialEq,Clone,Default,Debug)] +// @@protoc_insertion_point(message:sentencepiece.TrainerSpec) +pub struct TrainerSpec { + // message fields + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.input) + pub input: ::std::vec::Vec<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.input_format) + pub input_format: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.model_prefix) + pub model_prefix: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.model_type) + pub model_type: ::std::option::Option<::protobuf::EnumOrUnknown>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.vocab_size) + pub vocab_size: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.accept_language) + pub accept_language: ::std::vec::Vec<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.self_test_sample_size) + pub self_test_sample_size: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.enable_differential_privacy) + pub enable_differential_privacy: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.differential_privacy_noise_level) + pub differential_privacy_noise_level: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.differential_privacy_clipping_threshold) + pub differential_privacy_clipping_threshold: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.character_coverage) + pub character_coverage: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.input_sentence_size) + pub input_sentence_size: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.shuffle_input_sentence) + pub 
shuffle_input_sentence: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.mining_sentence_size) + pub mining_sentence_size: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.training_sentence_size) + pub training_sentence_size: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.seed_sentencepiece_size) + pub seed_sentencepiece_size: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.shrinking_factor) + pub shrinking_factor: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.max_sentence_length) + pub max_sentence_length: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.num_threads) + pub num_threads: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.num_sub_iterations) + pub num_sub_iterations: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.max_sentencepiece_length) + pub max_sentencepiece_length: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.split_by_unicode_script) + pub split_by_unicode_script: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.split_by_number) + pub split_by_number: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.split_by_whitespace) + pub split_by_whitespace: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.treat_whitespace_as_suffix) + pub treat_whitespace_as_suffix: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.allow_whitespace_only_pieces) + pub allow_whitespace_only_pieces: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.split_digits) + pub split_digits: ::std::option::Option, + // 
@@protoc_insertion_point(field:sentencepiece.TrainerSpec.control_symbols) + pub control_symbols: ::std::vec::Vec<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.user_defined_symbols) + pub user_defined_symbols: ::std::vec::Vec<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.required_chars) + pub required_chars: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.byte_fallback) + pub byte_fallback: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.vocabulary_output_piece_score) + pub vocabulary_output_piece_score: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.hard_vocab_limit) + pub hard_vocab_limit: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.use_all_vocab) + pub use_all_vocab: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.unk_id) + pub unk_id: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.bos_id) + pub bos_id: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.eos_id) + pub eos_id: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.pad_id) + pub pad_id: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.unk_piece) + pub unk_piece: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.bos_piece) + pub bos_piece: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.eos_piece) + pub eos_piece: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.pad_piece) + pub pad_piece: ::std::option::Option<::std::string::String>, + // 
@@protoc_insertion_point(field:sentencepiece.TrainerSpec.unk_surface) + pub unk_surface: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.TrainerSpec.train_extremely_large_corpus) + pub train_extremely_large_corpus: ::std::option::Option, + // special fields + // @@protoc_insertion_point(special_field:sentencepiece.TrainerSpec.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a TrainerSpec { + fn default() -> &'a TrainerSpec { + ::default_instance() + } +} + +impl TrainerSpec { + pub fn new() -> TrainerSpec { + ::std::default::Default::default() + } + + // optional string input_format = 7; + + pub fn input_format(&self) -> &str { + match self.input_format.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_input_format(&mut self) { + self.input_format = ::std::option::Option::None; + } + + pub fn has_input_format(&self) -> bool { + self.input_format.is_some() + } + + // Param is passed by value, moved + pub fn set_input_format(&mut self, v: ::std::string::String) { + self.input_format = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. 
+ pub fn mut_input_format(&mut self) -> &mut ::std::string::String { + if self.input_format.is_none() { + self.input_format = ::std::option::Option::Some(::std::string::String::new()); + } + self.input_format.as_mut().unwrap() + } + + // Take field + pub fn take_input_format(&mut self) -> ::std::string::String { + self.input_format.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional string model_prefix = 2; + + pub fn model_prefix(&self) -> &str { + match self.model_prefix.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_model_prefix(&mut self) { + self.model_prefix = ::std::option::Option::None; + } + + pub fn has_model_prefix(&self) -> bool { + self.model_prefix.is_some() + } + + // Param is passed by value, moved + pub fn set_model_prefix(&mut self, v: ::std::string::String) { + self.model_prefix = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. 
+ pub fn mut_model_prefix(&mut self) -> &mut ::std::string::String { + if self.model_prefix.is_none() { + self.model_prefix = ::std::option::Option::Some(::std::string::String::new()); + } + self.model_prefix.as_mut().unwrap() + } + + // Take field + pub fn take_model_prefix(&mut self) -> ::std::string::String { + self.model_prefix.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional .sentencepiece.TrainerSpec.ModelType model_type = 3; + + pub fn model_type(&self) -> trainer_spec::ModelType { + match self.model_type { + Some(e) => e.enum_value_or(trainer_spec::ModelType::UNIGRAM), + None => trainer_spec::ModelType::UNIGRAM, + } + } + + pub fn clear_model_type(&mut self) { + self.model_type = ::std::option::Option::None; + } + + pub fn has_model_type(&self) -> bool { + self.model_type.is_some() + } + + // Param is passed by value, moved + pub fn set_model_type(&mut self, v: trainer_spec::ModelType) { + self.model_type = ::std::option::Option::Some(::protobuf::EnumOrUnknown::new(v)); + } + + // optional int32 vocab_size = 4; + + pub fn vocab_size(&self) -> i32 { + self.vocab_size.unwrap_or(8000i32) + } + + pub fn clear_vocab_size(&mut self) { + self.vocab_size = ::std::option::Option::None; + } + + pub fn has_vocab_size(&self) -> bool { + self.vocab_size.is_some() + } + + // Param is passed by value, moved + pub fn set_vocab_size(&mut self, v: i32) { + self.vocab_size = ::std::option::Option::Some(v); + } + + // optional int32 self_test_sample_size = 6; + + pub fn self_test_sample_size(&self) -> i32 { + self.self_test_sample_size.unwrap_or(0i32) + } + + pub fn clear_self_test_sample_size(&mut self) { + self.self_test_sample_size = ::std::option::Option::None; + } + + pub fn has_self_test_sample_size(&self) -> bool { + self.self_test_sample_size.is_some() + } + + // Param is passed by value, moved + pub fn set_self_test_sample_size(&mut self, v: i32) { + self.self_test_sample_size = ::std::option::Option::Some(v); + } + + // optional bool 
enable_differential_privacy = 50; + + pub fn enable_differential_privacy(&self) -> bool { + self.enable_differential_privacy.unwrap_or(false) + } + + pub fn clear_enable_differential_privacy(&mut self) { + self.enable_differential_privacy = ::std::option::Option::None; + } + + pub fn has_enable_differential_privacy(&self) -> bool { + self.enable_differential_privacy.is_some() + } + + // Param is passed by value, moved + pub fn set_enable_differential_privacy(&mut self, v: bool) { + self.enable_differential_privacy = ::std::option::Option::Some(v); + } + + // optional float differential_privacy_noise_level = 51; + + pub fn differential_privacy_noise_level(&self) -> f32 { + self.differential_privacy_noise_level.unwrap_or(0.0f32) + } + + pub fn clear_differential_privacy_noise_level(&mut self) { + self.differential_privacy_noise_level = ::std::option::Option::None; + } + + pub fn has_differential_privacy_noise_level(&self) -> bool { + self.differential_privacy_noise_level.is_some() + } + + // Param is passed by value, moved + pub fn set_differential_privacy_noise_level(&mut self, v: f32) { + self.differential_privacy_noise_level = ::std::option::Option::Some(v); + } + + // optional uint64 differential_privacy_clipping_threshold = 52; + + pub fn differential_privacy_clipping_threshold(&self) -> u64 { + self.differential_privacy_clipping_threshold.unwrap_or(0u64) + } + + pub fn clear_differential_privacy_clipping_threshold(&mut self) { + self.differential_privacy_clipping_threshold = ::std::option::Option::None; + } + + pub fn has_differential_privacy_clipping_threshold(&self) -> bool { + self.differential_privacy_clipping_threshold.is_some() + } + + // Param is passed by value, moved + pub fn set_differential_privacy_clipping_threshold(&mut self, v: u64) { + self.differential_privacy_clipping_threshold = ::std::option::Option::Some(v); + } + + // optional float character_coverage = 10; + + pub fn character_coverage(&self) -> f32 { + 
self.character_coverage.unwrap_or(0.9994999766349792f32) + } + + pub fn clear_character_coverage(&mut self) { + self.character_coverage = ::std::option::Option::None; + } + + pub fn has_character_coverage(&self) -> bool { + self.character_coverage.is_some() + } + + // Param is passed by value, moved + pub fn set_character_coverage(&mut self, v: f32) { + self.character_coverage = ::std::option::Option::Some(v); + } + + // optional uint64 input_sentence_size = 11; + + pub fn input_sentence_size(&self) -> u64 { + self.input_sentence_size.unwrap_or(0u64) + } + + pub fn clear_input_sentence_size(&mut self) { + self.input_sentence_size = ::std::option::Option::None; + } + + pub fn has_input_sentence_size(&self) -> bool { + self.input_sentence_size.is_some() + } + + // Param is passed by value, moved + pub fn set_input_sentence_size(&mut self, v: u64) { + self.input_sentence_size = ::std::option::Option::Some(v); + } + + // optional bool shuffle_input_sentence = 19; + + pub fn shuffle_input_sentence(&self) -> bool { + self.shuffle_input_sentence.unwrap_or(true) + } + + pub fn clear_shuffle_input_sentence(&mut self) { + self.shuffle_input_sentence = ::std::option::Option::None; + } + + pub fn has_shuffle_input_sentence(&self) -> bool { + self.shuffle_input_sentence.is_some() + } + + // Param is passed by value, moved + pub fn set_shuffle_input_sentence(&mut self, v: bool) { + self.shuffle_input_sentence = ::std::option::Option::Some(v); + } + + // optional int32 mining_sentence_size = 12; + + pub fn mining_sentence_size(&self) -> i32 { + self.mining_sentence_size.unwrap_or(0) + } + + pub fn clear_mining_sentence_size(&mut self) { + self.mining_sentence_size = ::std::option::Option::None; + } + + pub fn has_mining_sentence_size(&self) -> bool { + self.mining_sentence_size.is_some() + } + + // Param is passed by value, moved + pub fn set_mining_sentence_size(&mut self, v: i32) { + self.mining_sentence_size = ::std::option::Option::Some(v); + } + + // optional int32 
training_sentence_size = 13; + + pub fn training_sentence_size(&self) -> i32 { + self.training_sentence_size.unwrap_or(0) + } + + pub fn clear_training_sentence_size(&mut self) { + self.training_sentence_size = ::std::option::Option::None; + } + + pub fn has_training_sentence_size(&self) -> bool { + self.training_sentence_size.is_some() + } + + // Param is passed by value, moved + pub fn set_training_sentence_size(&mut self, v: i32) { + self.training_sentence_size = ::std::option::Option::Some(v); + } + + // optional int32 seed_sentencepiece_size = 14; + + pub fn seed_sentencepiece_size(&self) -> i32 { + self.seed_sentencepiece_size.unwrap_or(1000000i32) + } + + pub fn clear_seed_sentencepiece_size(&mut self) { + self.seed_sentencepiece_size = ::std::option::Option::None; + } + + pub fn has_seed_sentencepiece_size(&self) -> bool { + self.seed_sentencepiece_size.is_some() + } + + // Param is passed by value, moved + pub fn set_seed_sentencepiece_size(&mut self, v: i32) { + self.seed_sentencepiece_size = ::std::option::Option::Some(v); + } + + // optional float shrinking_factor = 15; + + pub fn shrinking_factor(&self) -> f32 { + self.shrinking_factor.unwrap_or(0.75f32) + } + + pub fn clear_shrinking_factor(&mut self) { + self.shrinking_factor = ::std::option::Option::None; + } + + pub fn has_shrinking_factor(&self) -> bool { + self.shrinking_factor.is_some() + } + + // Param is passed by value, moved + pub fn set_shrinking_factor(&mut self, v: f32) { + self.shrinking_factor = ::std::option::Option::Some(v); + } + + // optional int32 max_sentence_length = 18; + + pub fn max_sentence_length(&self) -> i32 { + self.max_sentence_length.unwrap_or(4192i32) + } + + pub fn clear_max_sentence_length(&mut self) { + self.max_sentence_length = ::std::option::Option::None; + } + + pub fn has_max_sentence_length(&self) -> bool { + self.max_sentence_length.is_some() + } + + // Param is passed by value, moved + pub fn set_max_sentence_length(&mut self, v: i32) { + 
self.max_sentence_length = ::std::option::Option::Some(v); + } + + // optional int32 num_threads = 16; + + pub fn num_threads(&self) -> i32 { + self.num_threads.unwrap_or(16i32) + } + + pub fn clear_num_threads(&mut self) { + self.num_threads = ::std::option::Option::None; + } + + pub fn has_num_threads(&self) -> bool { + self.num_threads.is_some() + } + + // Param is passed by value, moved + pub fn set_num_threads(&mut self, v: i32) { + self.num_threads = ::std::option::Option::Some(v); + } + + // optional int32 num_sub_iterations = 17; + + pub fn num_sub_iterations(&self) -> i32 { + self.num_sub_iterations.unwrap_or(2i32) + } + + pub fn clear_num_sub_iterations(&mut self) { + self.num_sub_iterations = ::std::option::Option::None; + } + + pub fn has_num_sub_iterations(&self) -> bool { + self.num_sub_iterations.is_some() + } + + // Param is passed by value, moved + pub fn set_num_sub_iterations(&mut self, v: i32) { + self.num_sub_iterations = ::std::option::Option::Some(v); + } + + // optional int32 max_sentencepiece_length = 20; + + pub fn max_sentencepiece_length(&self) -> i32 { + self.max_sentencepiece_length.unwrap_or(16i32) + } + + pub fn clear_max_sentencepiece_length(&mut self) { + self.max_sentencepiece_length = ::std::option::Option::None; + } + + pub fn has_max_sentencepiece_length(&self) -> bool { + self.max_sentencepiece_length.is_some() + } + + // Param is passed by value, moved + pub fn set_max_sentencepiece_length(&mut self, v: i32) { + self.max_sentencepiece_length = ::std::option::Option::Some(v); + } + + // optional bool split_by_unicode_script = 21; + + pub fn split_by_unicode_script(&self) -> bool { + self.split_by_unicode_script.unwrap_or(true) + } + + pub fn clear_split_by_unicode_script(&mut self) { + self.split_by_unicode_script = ::std::option::Option::None; + } + + pub fn has_split_by_unicode_script(&self) -> bool { + self.split_by_unicode_script.is_some() + } + + // Param is passed by value, moved + pub fn set_split_by_unicode_script(&mut 
self, v: bool) { + self.split_by_unicode_script = ::std::option::Option::Some(v); + } + + // optional bool split_by_number = 23; + + pub fn split_by_number(&self) -> bool { + self.split_by_number.unwrap_or(true) + } + + pub fn clear_split_by_number(&mut self) { + self.split_by_number = ::std::option::Option::None; + } + + pub fn has_split_by_number(&self) -> bool { + self.split_by_number.is_some() + } + + // Param is passed by value, moved + pub fn set_split_by_number(&mut self, v: bool) { + self.split_by_number = ::std::option::Option::Some(v); + } + + // optional bool split_by_whitespace = 22; + + pub fn split_by_whitespace(&self) -> bool { + self.split_by_whitespace.unwrap_or(true) + } + + pub fn clear_split_by_whitespace(&mut self) { + self.split_by_whitespace = ::std::option::Option::None; + } + + pub fn has_split_by_whitespace(&self) -> bool { + self.split_by_whitespace.is_some() + } + + // Param is passed by value, moved + pub fn set_split_by_whitespace(&mut self, v: bool) { + self.split_by_whitespace = ::std::option::Option::Some(v); + } + + // optional bool treat_whitespace_as_suffix = 24; + + pub fn treat_whitespace_as_suffix(&self) -> bool { + self.treat_whitespace_as_suffix.unwrap_or(false) + } + + pub fn clear_treat_whitespace_as_suffix(&mut self) { + self.treat_whitespace_as_suffix = ::std::option::Option::None; + } + + pub fn has_treat_whitespace_as_suffix(&self) -> bool { + self.treat_whitespace_as_suffix.is_some() + } + + // Param is passed by value, moved + pub fn set_treat_whitespace_as_suffix(&mut self, v: bool) { + self.treat_whitespace_as_suffix = ::std::option::Option::Some(v); + } + + // optional bool allow_whitespace_only_pieces = 26; + + pub fn allow_whitespace_only_pieces(&self) -> bool { + self.allow_whitespace_only_pieces.unwrap_or(false) + } + + pub fn clear_allow_whitespace_only_pieces(&mut self) { + self.allow_whitespace_only_pieces = ::std::option::Option::None; + } + + pub fn has_allow_whitespace_only_pieces(&self) -> bool { + 
self.allow_whitespace_only_pieces.is_some() + } + + // Param is passed by value, moved + pub fn set_allow_whitespace_only_pieces(&mut self, v: bool) { + self.allow_whitespace_only_pieces = ::std::option::Option::Some(v); + } + + // optional bool split_digits = 25; + + pub fn split_digits(&self) -> bool { + self.split_digits.unwrap_or(false) + } + + pub fn clear_split_digits(&mut self) { + self.split_digits = ::std::option::Option::None; + } + + pub fn has_split_digits(&self) -> bool { + self.split_digits.is_some() + } + + // Param is passed by value, moved + pub fn set_split_digits(&mut self, v: bool) { + self.split_digits = ::std::option::Option::Some(v); + } + + // optional string required_chars = 36; + + pub fn required_chars(&self) -> &str { + match self.required_chars.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_required_chars(&mut self) { + self.required_chars = ::std::option::Option::None; + } + + pub fn has_required_chars(&self) -> bool { + self.required_chars.is_some() + } + + // Param is passed by value, moved + pub fn set_required_chars(&mut self, v: ::std::string::String) { + self.required_chars = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. 
+ pub fn mut_required_chars(&mut self) -> &mut ::std::string::String { + if self.required_chars.is_none() { + self.required_chars = ::std::option::Option::Some(::std::string::String::new()); + } + self.required_chars.as_mut().unwrap() + } + + // Take field + pub fn take_required_chars(&mut self) -> ::std::string::String { + self.required_chars.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional bool byte_fallback = 35; + + pub fn byte_fallback(&self) -> bool { + self.byte_fallback.unwrap_or(false) + } + + pub fn clear_byte_fallback(&mut self) { + self.byte_fallback = ::std::option::Option::None; + } + + pub fn has_byte_fallback(&self) -> bool { + self.byte_fallback.is_some() + } + + // Param is passed by value, moved + pub fn set_byte_fallback(&mut self, v: bool) { + self.byte_fallback = ::std::option::Option::Some(v); + } + + // optional bool vocabulary_output_piece_score = 32; + + pub fn vocabulary_output_piece_score(&self) -> bool { + self.vocabulary_output_piece_score.unwrap_or(true) + } + + pub fn clear_vocabulary_output_piece_score(&mut self) { + self.vocabulary_output_piece_score = ::std::option::Option::None; + } + + pub fn has_vocabulary_output_piece_score(&self) -> bool { + self.vocabulary_output_piece_score.is_some() + } + + // Param is passed by value, moved + pub fn set_vocabulary_output_piece_score(&mut self, v: bool) { + self.vocabulary_output_piece_score = ::std::option::Option::Some(v); + } + + // optional bool hard_vocab_limit = 33; + + pub fn hard_vocab_limit(&self) -> bool { + self.hard_vocab_limit.unwrap_or(true) + } + + pub fn clear_hard_vocab_limit(&mut self) { + self.hard_vocab_limit = ::std::option::Option::None; + } + + pub fn has_hard_vocab_limit(&self) -> bool { + self.hard_vocab_limit.is_some() + } + + // Param is passed by value, moved + pub fn set_hard_vocab_limit(&mut self, v: bool) { + self.hard_vocab_limit = ::std::option::Option::Some(v); + } + + // optional bool use_all_vocab = 34; + + pub fn 
use_all_vocab(&self) -> bool { + self.use_all_vocab.unwrap_or(false) + } + + pub fn clear_use_all_vocab(&mut self) { + self.use_all_vocab = ::std::option::Option::None; + } + + pub fn has_use_all_vocab(&self) -> bool { + self.use_all_vocab.is_some() + } + + // Param is passed by value, moved + pub fn set_use_all_vocab(&mut self, v: bool) { + self.use_all_vocab = ::std::option::Option::Some(v); + } + + // optional int32 unk_id = 40; + + pub fn unk_id(&self) -> i32 { + self.unk_id.unwrap_or(0i32) + } + + pub fn clear_unk_id(&mut self) { + self.unk_id = ::std::option::Option::None; + } + + pub fn has_unk_id(&self) -> bool { + self.unk_id.is_some() + } + + // Param is passed by value, moved + pub fn set_unk_id(&mut self, v: i32) { + self.unk_id = ::std::option::Option::Some(v); + } + + // optional int32 bos_id = 41; + + pub fn bos_id(&self) -> i32 { + self.bos_id.unwrap_or(1i32) + } + + pub fn clear_bos_id(&mut self) { + self.bos_id = ::std::option::Option::None; + } + + pub fn has_bos_id(&self) -> bool { + self.bos_id.is_some() + } + + // Param is passed by value, moved + pub fn set_bos_id(&mut self, v: i32) { + self.bos_id = ::std::option::Option::Some(v); + } + + // optional int32 eos_id = 42; + + pub fn eos_id(&self) -> i32 { + self.eos_id.unwrap_or(2i32) + } + + pub fn clear_eos_id(&mut self) { + self.eos_id = ::std::option::Option::None; + } + + pub fn has_eos_id(&self) -> bool { + self.eos_id.is_some() + } + + // Param is passed by value, moved + pub fn set_eos_id(&mut self, v: i32) { + self.eos_id = ::std::option::Option::Some(v); + } + + // optional int32 pad_id = 43; + + pub fn pad_id(&self) -> i32 { + self.pad_id.unwrap_or(-1i32) + } + + pub fn clear_pad_id(&mut self) { + self.pad_id = ::std::option::Option::None; + } + + pub fn has_pad_id(&self) -> bool { + self.pad_id.is_some() + } + + // Param is passed by value, moved + pub fn set_pad_id(&mut self, v: i32) { + self.pad_id = ::std::option::Option::Some(v); + } + + // optional string unk_piece = 45; + + 
pub fn unk_piece(&self) -> &str { + match self.unk_piece.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_unk_piece(&mut self) { + self.unk_piece = ::std::option::Option::None; + } + + pub fn has_unk_piece(&self) -> bool { + self.unk_piece.is_some() + } + + // Param is passed by value, moved + pub fn set_unk_piece(&mut self, v: ::std::string::String) { + self.unk_piece = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_unk_piece(&mut self) -> &mut ::std::string::String { + if self.unk_piece.is_none() { + self.unk_piece = ::std::option::Option::Some(::std::string::String::new()); + } + self.unk_piece.as_mut().unwrap() + } + + // Take field + pub fn take_unk_piece(&mut self) -> ::std::string::String { + self.unk_piece.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional string bos_piece = 46; + + pub fn bos_piece(&self) -> &str { + match self.bos_piece.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_bos_piece(&mut self) { + self.bos_piece = ::std::option::Option::None; + } + + pub fn has_bos_piece(&self) -> bool { + self.bos_piece.is_some() + } + + // Param is passed by value, moved + pub fn set_bos_piece(&mut self, v: ::std::string::String) { + self.bos_piece = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. 
+ pub fn mut_bos_piece(&mut self) -> &mut ::std::string::String { + if self.bos_piece.is_none() { + self.bos_piece = ::std::option::Option::Some(::std::string::String::new()); + } + self.bos_piece.as_mut().unwrap() + } + + // Take field + pub fn take_bos_piece(&mut self) -> ::std::string::String { + self.bos_piece.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional string eos_piece = 47; + + pub fn eos_piece(&self) -> &str { + match self.eos_piece.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_eos_piece(&mut self) { + self.eos_piece = ::std::option::Option::None; + } + + pub fn has_eos_piece(&self) -> bool { + self.eos_piece.is_some() + } + + // Param is passed by value, moved + pub fn set_eos_piece(&mut self, v: ::std::string::String) { + self.eos_piece = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_eos_piece(&mut self) -> &mut ::std::string::String { + if self.eos_piece.is_none() { + self.eos_piece = ::std::option::Option::Some(::std::string::String::new()); + } + self.eos_piece.as_mut().unwrap() + } + + // Take field + pub fn take_eos_piece(&mut self) -> ::std::string::String { + self.eos_piece.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional string pad_piece = 48; + + pub fn pad_piece(&self) -> &str { + match self.pad_piece.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_pad_piece(&mut self) { + self.pad_piece = ::std::option::Option::None; + } + + pub fn has_pad_piece(&self) -> bool { + self.pad_piece.is_some() + } + + // Param is passed by value, moved + pub fn set_pad_piece(&mut self, v: ::std::string::String) { + self.pad_piece = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. 
+ pub fn mut_pad_piece(&mut self) -> &mut ::std::string::String { + if self.pad_piece.is_none() { + self.pad_piece = ::std::option::Option::Some(::std::string::String::new()); + } + self.pad_piece.as_mut().unwrap() + } + + // Take field + pub fn take_pad_piece(&mut self) -> ::std::string::String { + self.pad_piece.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional string unk_surface = 44; + + pub fn unk_surface(&self) -> &str { + match self.unk_surface.as_ref() { + Some(v) => v, + None => " \u{2047} ", + } + } + + pub fn clear_unk_surface(&mut self) { + self.unk_surface = ::std::option::Option::None; + } + + pub fn has_unk_surface(&self) -> bool { + self.unk_surface.is_some() + } + + // Param is passed by value, moved + pub fn set_unk_surface(&mut self, v: ::std::string::String) { + self.unk_surface = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_unk_surface(&mut self) -> &mut ::std::string::String { + if self.unk_surface.is_none() { + self.unk_surface = ::std::option::Option::Some(::std::string::String::new()); + } + self.unk_surface.as_mut().unwrap() + } + + // Take field + pub fn take_unk_surface(&mut self) -> ::std::string::String { + self.unk_surface.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional bool train_extremely_large_corpus = 49; + + pub fn train_extremely_large_corpus(&self) -> bool { + self.train_extremely_large_corpus.unwrap_or(false) + } + + pub fn clear_train_extremely_large_corpus(&mut self) { + self.train_extremely_large_corpus = ::std::option::Option::None; + } + + pub fn has_train_extremely_large_corpus(&self) -> bool { + self.train_extremely_large_corpus.is_some() + } + + // Param is passed by value, moved + pub fn set_train_extremely_large_corpus(&mut self, v: bool) { + self.train_extremely_large_corpus = ::std::option::Option::Some(v); + } +} + +impl ::protobuf::Message for 
TrainerSpec { + const NAME: &'static str = "TrainerSpec"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.input.push(is.read_string()?); + }, + 58 => { + self.input_format = ::std::option::Option::Some(is.read_string()?); + }, + 18 => { + self.model_prefix = ::std::option::Option::Some(is.read_string()?); + }, + 24 => { + self.model_type = ::std::option::Option::Some(is.read_enum_or_unknown()?); + }, + 32 => { + self.vocab_size = ::std::option::Option::Some(is.read_int32()?); + }, + 42 => { + self.accept_language.push(is.read_string()?); + }, + 48 => { + self.self_test_sample_size = ::std::option::Option::Some(is.read_int32()?); + }, + 400 => { + self.enable_differential_privacy = ::std::option::Option::Some(is.read_bool()?); + }, + 413 => { + self.differential_privacy_noise_level = ::std::option::Option::Some(is.read_float()?); + }, + 416 => { + self.differential_privacy_clipping_threshold = ::std::option::Option::Some(is.read_uint64()?); + }, + 85 => { + self.character_coverage = ::std::option::Option::Some(is.read_float()?); + }, + 88 => { + self.input_sentence_size = ::std::option::Option::Some(is.read_uint64()?); + }, + 152 => { + self.shuffle_input_sentence = ::std::option::Option::Some(is.read_bool()?); + }, + 96 => { + self.mining_sentence_size = ::std::option::Option::Some(is.read_int32()?); + }, + 104 => { + self.training_sentence_size = ::std::option::Option::Some(is.read_int32()?); + }, + 112 => { + self.seed_sentencepiece_size = ::std::option::Option::Some(is.read_int32()?); + }, + 125 => { + self.shrinking_factor = ::std::option::Option::Some(is.read_float()?); + }, + 144 => { + self.max_sentence_length = ::std::option::Option::Some(is.read_int32()?); + }, + 128 => { + self.num_threads = ::std::option::Option::Some(is.read_int32()?); + }, + 136 => { + 
self.num_sub_iterations = ::std::option::Option::Some(is.read_int32()?); + }, + 160 => { + self.max_sentencepiece_length = ::std::option::Option::Some(is.read_int32()?); + }, + 168 => { + self.split_by_unicode_script = ::std::option::Option::Some(is.read_bool()?); + }, + 184 => { + self.split_by_number = ::std::option::Option::Some(is.read_bool()?); + }, + 176 => { + self.split_by_whitespace = ::std::option::Option::Some(is.read_bool()?); + }, + 192 => { + self.treat_whitespace_as_suffix = ::std::option::Option::Some(is.read_bool()?); + }, + 208 => { + self.allow_whitespace_only_pieces = ::std::option::Option::Some(is.read_bool()?); + }, + 200 => { + self.split_digits = ::std::option::Option::Some(is.read_bool()?); + }, + 242 => { + self.control_symbols.push(is.read_string()?); + }, + 250 => { + self.user_defined_symbols.push(is.read_string()?); + }, + 290 => { + self.required_chars = ::std::option::Option::Some(is.read_string()?); + }, + 280 => { + self.byte_fallback = ::std::option::Option::Some(is.read_bool()?); + }, + 256 => { + self.vocabulary_output_piece_score = ::std::option::Option::Some(is.read_bool()?); + }, + 264 => { + self.hard_vocab_limit = ::std::option::Option::Some(is.read_bool()?); + }, + 272 => { + self.use_all_vocab = ::std::option::Option::Some(is.read_bool()?); + }, + 320 => { + self.unk_id = ::std::option::Option::Some(is.read_int32()?); + }, + 328 => { + self.bos_id = ::std::option::Option::Some(is.read_int32()?); + }, + 336 => { + self.eos_id = ::std::option::Option::Some(is.read_int32()?); + }, + 344 => { + self.pad_id = ::std::option::Option::Some(is.read_int32()?); + }, + 362 => { + self.unk_piece = ::std::option::Option::Some(is.read_string()?); + }, + 370 => { + self.bos_piece = ::std::option::Option::Some(is.read_string()?); + }, + 378 => { + self.eos_piece = ::std::option::Option::Some(is.read_string()?); + }, + 386 => { + self.pad_piece = ::std::option::Option::Some(is.read_string()?); + }, + 354 => { + self.unk_surface = 
::std::option::Option::Some(is.read_string()?); + }, + 392 => { + self.train_extremely_large_corpus = ::std::option::Option::Some(is.read_bool()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + for value in &self.input { + my_size += ::protobuf::rt::string_size(1, &value); + }; + if let Some(v) = self.input_format.as_ref() { + my_size += ::protobuf::rt::string_size(7, &v); + } + if let Some(v) = self.model_prefix.as_ref() { + my_size += ::protobuf::rt::string_size(2, &v); + } + if let Some(v) = self.model_type { + my_size += ::protobuf::rt::int32_size(3, v.value()); + } + if let Some(v) = self.vocab_size { + my_size += ::protobuf::rt::int32_size(4, v); + } + for value in &self.accept_language { + my_size += ::protobuf::rt::string_size(5, &value); + }; + if let Some(v) = self.self_test_sample_size { + my_size += ::protobuf::rt::int32_size(6, v); + } + if let Some(v) = self.enable_differential_privacy { + my_size += 2 + 1; + } + if let Some(v) = self.differential_privacy_noise_level { + my_size += 2 + 4; + } + if let Some(v) = self.differential_privacy_clipping_threshold { + my_size += ::protobuf::rt::uint64_size(52, v); + } + if let Some(v) = self.character_coverage { + my_size += 1 + 4; + } + if let Some(v) = self.input_sentence_size { + my_size += ::protobuf::rt::uint64_size(11, v); + } + if let Some(v) = self.shuffle_input_sentence { + my_size += 2 + 1; + } + if let Some(v) = self.mining_sentence_size { + my_size += ::protobuf::rt::int32_size(12, v); + } + if let Some(v) = self.training_sentence_size { + my_size += ::protobuf::rt::int32_size(13, v); + } + if let Some(v) = self.seed_sentencepiece_size { + my_size += ::protobuf::rt::int32_size(14, v); + } + if let Some(v) = self.shrinking_factor { + my_size += 1 + 4; + } + 
if let Some(v) = self.max_sentence_length { + my_size += ::protobuf::rt::int32_size(18, v); + } + if let Some(v) = self.num_threads { + my_size += ::protobuf::rt::int32_size(16, v); + } + if let Some(v) = self.num_sub_iterations { + my_size += ::protobuf::rt::int32_size(17, v); + } + if let Some(v) = self.max_sentencepiece_length { + my_size += ::protobuf::rt::int32_size(20, v); + } + if let Some(v) = self.split_by_unicode_script { + my_size += 2 + 1; + } + if let Some(v) = self.split_by_number { + my_size += 2 + 1; + } + if let Some(v) = self.split_by_whitespace { + my_size += 2 + 1; + } + if let Some(v) = self.treat_whitespace_as_suffix { + my_size += 2 + 1; + } + if let Some(v) = self.allow_whitespace_only_pieces { + my_size += 2 + 1; + } + if let Some(v) = self.split_digits { + my_size += 2 + 1; + } + for value in &self.control_symbols { + my_size += ::protobuf::rt::string_size(30, &value); + }; + for value in &self.user_defined_symbols { + my_size += ::protobuf::rt::string_size(31, &value); + }; + if let Some(v) = self.required_chars.as_ref() { + my_size += ::protobuf::rt::string_size(36, &v); + } + if let Some(v) = self.byte_fallback { + my_size += 2 + 1; + } + if let Some(v) = self.vocabulary_output_piece_score { + my_size += 2 + 1; + } + if let Some(v) = self.hard_vocab_limit { + my_size += 2 + 1; + } + if let Some(v) = self.use_all_vocab { + my_size += 2 + 1; + } + if let Some(v) = self.unk_id { + my_size += ::protobuf::rt::int32_size(40, v); + } + if let Some(v) = self.bos_id { + my_size += ::protobuf::rt::int32_size(41, v); + } + if let Some(v) = self.eos_id { + my_size += ::protobuf::rt::int32_size(42, v); + } + if let Some(v) = self.pad_id { + my_size += ::protobuf::rt::int32_size(43, v); + } + if let Some(v) = self.unk_piece.as_ref() { + my_size += ::protobuf::rt::string_size(45, &v); + } + if let Some(v) = self.bos_piece.as_ref() { + my_size += ::protobuf::rt::string_size(46, &v); + } + if let Some(v) = self.eos_piece.as_ref() { + my_size += 
::protobuf::rt::string_size(47, &v); + } + if let Some(v) = self.pad_piece.as_ref() { + my_size += ::protobuf::rt::string_size(48, &v); + } + if let Some(v) = self.unk_surface.as_ref() { + my_size += ::protobuf::rt::string_size(44, &v); + } + if let Some(v) = self.train_extremely_large_corpus { + my_size += 2 + 1; + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + for v in &self.input { + os.write_string(1, &v)?; + }; + if let Some(v) = self.input_format.as_ref() { + os.write_string(7, v)?; + } + if let Some(v) = self.model_prefix.as_ref() { + os.write_string(2, v)?; + } + if let Some(v) = self.model_type { + os.write_enum(3, ::protobuf::EnumOrUnknown::value(&v))?; + } + if let Some(v) = self.vocab_size { + os.write_int32(4, v)?; + } + for v in &self.accept_language { + os.write_string(5, &v)?; + }; + if let Some(v) = self.self_test_sample_size { + os.write_int32(6, v)?; + } + if let Some(v) = self.enable_differential_privacy { + os.write_bool(50, v)?; + } + if let Some(v) = self.differential_privacy_noise_level { + os.write_float(51, v)?; + } + if let Some(v) = self.differential_privacy_clipping_threshold { + os.write_uint64(52, v)?; + } + if let Some(v) = self.character_coverage { + os.write_float(10, v)?; + } + if let Some(v) = self.input_sentence_size { + os.write_uint64(11, v)?; + } + if let Some(v) = self.shuffle_input_sentence { + os.write_bool(19, v)?; + } + if let Some(v) = self.mining_sentence_size { + os.write_int32(12, v)?; + } + if let Some(v) = self.training_sentence_size { + os.write_int32(13, v)?; + } + if let Some(v) = self.seed_sentencepiece_size { + os.write_int32(14, v)?; + } + if let Some(v) = self.shrinking_factor { + os.write_float(15, v)?; + } + if let Some(v) = self.max_sentence_length { + os.write_int32(18, v)?; + } 
+ if let Some(v) = self.num_threads { + os.write_int32(16, v)?; + } + if let Some(v) = self.num_sub_iterations { + os.write_int32(17, v)?; + } + if let Some(v) = self.max_sentencepiece_length { + os.write_int32(20, v)?; + } + if let Some(v) = self.split_by_unicode_script { + os.write_bool(21, v)?; + } + if let Some(v) = self.split_by_number { + os.write_bool(23, v)?; + } + if let Some(v) = self.split_by_whitespace { + os.write_bool(22, v)?; + } + if let Some(v) = self.treat_whitespace_as_suffix { + os.write_bool(24, v)?; + } + if let Some(v) = self.allow_whitespace_only_pieces { + os.write_bool(26, v)?; + } + if let Some(v) = self.split_digits { + os.write_bool(25, v)?; + } + for v in &self.control_symbols { + os.write_string(30, &v)?; + }; + for v in &self.user_defined_symbols { + os.write_string(31, &v)?; + }; + if let Some(v) = self.required_chars.as_ref() { + os.write_string(36, v)?; + } + if let Some(v) = self.byte_fallback { + os.write_bool(35, v)?; + } + if let Some(v) = self.vocabulary_output_piece_score { + os.write_bool(32, v)?; + } + if let Some(v) = self.hard_vocab_limit { + os.write_bool(33, v)?; + } + if let Some(v) = self.use_all_vocab { + os.write_bool(34, v)?; + } + if let Some(v) = self.unk_id { + os.write_int32(40, v)?; + } + if let Some(v) = self.bos_id { + os.write_int32(41, v)?; + } + if let Some(v) = self.eos_id { + os.write_int32(42, v)?; + } + if let Some(v) = self.pad_id { + os.write_int32(43, v)?; + } + if let Some(v) = self.unk_piece.as_ref() { + os.write_string(45, v)?; + } + if let Some(v) = self.bos_piece.as_ref() { + os.write_string(46, v)?; + } + if let Some(v) = self.eos_piece.as_ref() { + os.write_string(47, v)?; + } + if let Some(v) = self.pad_piece.as_ref() { + os.write_string(48, v)?; + } + if let Some(v) = self.unk_surface.as_ref() { + os.write_string(44, v)?; + } + if let Some(v) = self.train_extremely_large_corpus { + os.write_bool(49, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + 
::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> TrainerSpec { + TrainerSpec::new() + } + + fn clear(&mut self) { + self.input.clear(); + self.input_format = ::std::option::Option::None; + self.model_prefix = ::std::option::Option::None; + self.model_type = ::std::option::Option::None; + self.vocab_size = ::std::option::Option::None; + self.accept_language.clear(); + self.self_test_sample_size = ::std::option::Option::None; + self.enable_differential_privacy = ::std::option::Option::None; + self.differential_privacy_noise_level = ::std::option::Option::None; + self.differential_privacy_clipping_threshold = ::std::option::Option::None; + self.character_coverage = ::std::option::Option::None; + self.input_sentence_size = ::std::option::Option::None; + self.shuffle_input_sentence = ::std::option::Option::None; + self.mining_sentence_size = ::std::option::Option::None; + self.training_sentence_size = ::std::option::Option::None; + self.seed_sentencepiece_size = ::std::option::Option::None; + self.shrinking_factor = ::std::option::Option::None; + self.max_sentence_length = ::std::option::Option::None; + self.num_threads = ::std::option::Option::None; + self.num_sub_iterations = ::std::option::Option::None; + self.max_sentencepiece_length = ::std::option::Option::None; + self.split_by_unicode_script = ::std::option::Option::None; + self.split_by_number = ::std::option::Option::None; + self.split_by_whitespace = ::std::option::Option::None; + self.treat_whitespace_as_suffix = ::std::option::Option::None; + self.allow_whitespace_only_pieces = ::std::option::Option::None; + self.split_digits = ::std::option::Option::None; + self.control_symbols.clear(); + self.user_defined_symbols.clear(); + self.required_chars = ::std::option::Option::None; + self.byte_fallback = 
::std::option::Option::None; + self.vocabulary_output_piece_score = ::std::option::Option::None; + self.hard_vocab_limit = ::std::option::Option::None; + self.use_all_vocab = ::std::option::Option::None; + self.unk_id = ::std::option::Option::None; + self.bos_id = ::std::option::Option::None; + self.eos_id = ::std::option::Option::None; + self.pad_id = ::std::option::Option::None; + self.unk_piece = ::std::option::Option::None; + self.bos_piece = ::std::option::Option::None; + self.eos_piece = ::std::option::Option::None; + self.pad_piece = ::std::option::Option::None; + self.unk_surface = ::std::option::Option::None; + self.train_extremely_large_corpus = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static TrainerSpec { + static instance: TrainerSpec = TrainerSpec { + input: ::std::vec::Vec::new(), + input_format: ::std::option::Option::None, + model_prefix: ::std::option::Option::None, + model_type: ::std::option::Option::None, + vocab_size: ::std::option::Option::None, + accept_language: ::std::vec::Vec::new(), + self_test_sample_size: ::std::option::Option::None, + enable_differential_privacy: ::std::option::Option::None, + differential_privacy_noise_level: ::std::option::Option::None, + differential_privacy_clipping_threshold: ::std::option::Option::None, + character_coverage: ::std::option::Option::None, + input_sentence_size: ::std::option::Option::None, + shuffle_input_sentence: ::std::option::Option::None, + mining_sentence_size: ::std::option::Option::None, + training_sentence_size: ::std::option::Option::None, + seed_sentencepiece_size: ::std::option::Option::None, + shrinking_factor: ::std::option::Option::None, + max_sentence_length: ::std::option::Option::None, + num_threads: ::std::option::Option::None, + num_sub_iterations: ::std::option::Option::None, + max_sentencepiece_length: ::std::option::Option::None, + split_by_unicode_script: ::std::option::Option::None, + split_by_number: 
::std::option::Option::None, + split_by_whitespace: ::std::option::Option::None, + treat_whitespace_as_suffix: ::std::option::Option::None, + allow_whitespace_only_pieces: ::std::option::Option::None, + split_digits: ::std::option::Option::None, + control_symbols: ::std::vec::Vec::new(), + user_defined_symbols: ::std::vec::Vec::new(), + required_chars: ::std::option::Option::None, + byte_fallback: ::std::option::Option::None, + vocabulary_output_piece_score: ::std::option::Option::None, + hard_vocab_limit: ::std::option::Option::None, + use_all_vocab: ::std::option::Option::None, + unk_id: ::std::option::Option::None, + bos_id: ::std::option::Option::None, + eos_id: ::std::option::Option::None, + pad_id: ::std::option::Option::None, + unk_piece: ::std::option::Option::None, + bos_piece: ::std::option::Option::None, + eos_piece: ::std::option::Option::None, + pad_piece: ::std::option::Option::None, + unk_surface: ::std::option::Option::None, + train_extremely_large_corpus: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +/// Nested message and enums of message `TrainerSpec` +pub mod trainer_spec { + #[derive(Clone,Copy,PartialEq,Eq,Debug,Hash)] + // @@protoc_insertion_point(enum:sentencepiece.TrainerSpec.ModelType) + pub enum ModelType { + // @@protoc_insertion_point(enum_value:sentencepiece.TrainerSpec.ModelType.UNIGRAM) + UNIGRAM = 1, + // @@protoc_insertion_point(enum_value:sentencepiece.TrainerSpec.ModelType.BPE) + BPE = 2, + // @@protoc_insertion_point(enum_value:sentencepiece.TrainerSpec.ModelType.WORD) + WORD = 3, + // @@protoc_insertion_point(enum_value:sentencepiece.TrainerSpec.ModelType.CHAR) + CHAR = 4, + } + + impl ::protobuf::Enum for ModelType { + const NAME: &'static str = "ModelType"; + + fn value(&self) -> i32 { + *self as i32 + } + + fn from_i32(value: i32) -> ::std::option::Option { + match value { + 1 => ::std::option::Option::Some(ModelType::UNIGRAM), + 2 => 
::std::option::Option::Some(ModelType::BPE), + 3 => ::std::option::Option::Some(ModelType::WORD), + 4 => ::std::option::Option::Some(ModelType::CHAR), + _ => ::std::option::Option::None + } + } + + const VALUES: &'static [ModelType] = &[ + ModelType::UNIGRAM, + ModelType::BPE, + ModelType::WORD, + ModelType::CHAR, + ]; + } + + // Note, `Default` is implemented although default value is not 0 + impl ::std::default::Default for ModelType { + fn default() -> Self { + ModelType::UNIGRAM + } + } + +} + +#[derive(PartialEq,Clone,Default,Debug)] +// @@protoc_insertion_point(message:sentencepiece.NormalizerSpec) +pub struct NormalizerSpec { + // message fields + // @@protoc_insertion_point(field:sentencepiece.NormalizerSpec.name) + pub name: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.NormalizerSpec.precompiled_charsmap) + pub precompiled_charsmap: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:sentencepiece.NormalizerSpec.add_dummy_prefix) + pub add_dummy_prefix: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.NormalizerSpec.remove_extra_whitespaces) + pub remove_extra_whitespaces: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.NormalizerSpec.escape_whitespaces) + pub escape_whitespaces: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.NormalizerSpec.normalization_rule_tsv) + pub normalization_rule_tsv: ::std::option::Option<::std::string::String>, + // special fields + // @@protoc_insertion_point(special_field:sentencepiece.NormalizerSpec.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a NormalizerSpec { + fn default() -> &'a NormalizerSpec { + ::default_instance() + } +} + +impl NormalizerSpec { + pub fn new() -> NormalizerSpec { + ::std::default::Default::default() + } + + // optional string name = 1; + + pub fn name(&self) -> &str { + match 
self.name.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_name(&mut self) { + self.name = ::std::option::Option::None; + } + + pub fn has_name(&self) -> bool { + self.name.is_some() + } + + // Param is passed by value, moved + pub fn set_name(&mut self, v: ::std::string::String) { + self.name = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_name(&mut self) -> &mut ::std::string::String { + if self.name.is_none() { + self.name = ::std::option::Option::Some(::std::string::String::new()); + } + self.name.as_mut().unwrap() + } + + // Take field + pub fn take_name(&mut self) -> ::std::string::String { + self.name.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional bytes precompiled_charsmap = 2; + + pub fn precompiled_charsmap(&self) -> &[u8] { + match self.precompiled_charsmap.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_precompiled_charsmap(&mut self) { + self.precompiled_charsmap = ::std::option::Option::None; + } + + pub fn has_precompiled_charsmap(&self) -> bool { + self.precompiled_charsmap.is_some() + } + + // Param is passed by value, moved + pub fn set_precompiled_charsmap(&mut self, v: ::std::vec::Vec) { + self.precompiled_charsmap = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. 
+ pub fn mut_precompiled_charsmap(&mut self) -> &mut ::std::vec::Vec { + if self.precompiled_charsmap.is_none() { + self.precompiled_charsmap = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.precompiled_charsmap.as_mut().unwrap() + } + + // Take field + pub fn take_precompiled_charsmap(&mut self) -> ::std::vec::Vec { + self.precompiled_charsmap.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bool add_dummy_prefix = 3; + + pub fn add_dummy_prefix(&self) -> bool { + self.add_dummy_prefix.unwrap_or(true) + } + + pub fn clear_add_dummy_prefix(&mut self) { + self.add_dummy_prefix = ::std::option::Option::None; + } + + pub fn has_add_dummy_prefix(&self) -> bool { + self.add_dummy_prefix.is_some() + } + + // Param is passed by value, moved + pub fn set_add_dummy_prefix(&mut self, v: bool) { + self.add_dummy_prefix = ::std::option::Option::Some(v); + } + + // optional bool remove_extra_whitespaces = 4; + + pub fn remove_extra_whitespaces(&self) -> bool { + self.remove_extra_whitespaces.unwrap_or(true) + } + + pub fn clear_remove_extra_whitespaces(&mut self) { + self.remove_extra_whitespaces = ::std::option::Option::None; + } + + pub fn has_remove_extra_whitespaces(&self) -> bool { + self.remove_extra_whitespaces.is_some() + } + + // Param is passed by value, moved + pub fn set_remove_extra_whitespaces(&mut self, v: bool) { + self.remove_extra_whitespaces = ::std::option::Option::Some(v); + } + + // optional bool escape_whitespaces = 5; + + pub fn escape_whitespaces(&self) -> bool { + self.escape_whitespaces.unwrap_or(true) + } + + pub fn clear_escape_whitespaces(&mut self) { + self.escape_whitespaces = ::std::option::Option::None; + } + + pub fn has_escape_whitespaces(&self) -> bool { + self.escape_whitespaces.is_some() + } + + // Param is passed by value, moved + pub fn set_escape_whitespaces(&mut self, v: bool) { + self.escape_whitespaces = ::std::option::Option::Some(v); + } + + // optional string normalization_rule_tsv = 6; + + pub 
fn normalization_rule_tsv(&self) -> &str { + match self.normalization_rule_tsv.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_normalization_rule_tsv(&mut self) { + self.normalization_rule_tsv = ::std::option::Option::None; + } + + pub fn has_normalization_rule_tsv(&self) -> bool { + self.normalization_rule_tsv.is_some() + } + + // Param is passed by value, moved + pub fn set_normalization_rule_tsv(&mut self, v: ::std::string::String) { + self.normalization_rule_tsv = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_normalization_rule_tsv(&mut self) -> &mut ::std::string::String { + if self.normalization_rule_tsv.is_none() { + self.normalization_rule_tsv = ::std::option::Option::Some(::std::string::String::new()); + } + self.normalization_rule_tsv.as_mut().unwrap() + } + + // Take field + pub fn take_normalization_rule_tsv(&mut self) -> ::std::string::String { + self.normalization_rule_tsv.take().unwrap_or_else(|| ::std::string::String::new()) + } +} + +impl ::protobuf::Message for NormalizerSpec { + const NAME: &'static str = "NormalizerSpec"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? 
{ + match tag { + 10 => { + self.name = ::std::option::Option::Some(is.read_string()?); + }, + 18 => { + self.precompiled_charsmap = ::std::option::Option::Some(is.read_bytes()?); + }, + 24 => { + self.add_dummy_prefix = ::std::option::Option::Some(is.read_bool()?); + }, + 32 => { + self.remove_extra_whitespaces = ::std::option::Option::Some(is.read_bool()?); + }, + 40 => { + self.escape_whitespaces = ::std::option::Option::Some(is.read_bool()?); + }, + 50 => { + self.normalization_rule_tsv = ::std::option::Option::Some(is.read_string()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.name.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + if let Some(v) = self.precompiled_charsmap.as_ref() { + my_size += ::protobuf::rt::bytes_size(2, &v); + } + if let Some(v) = self.add_dummy_prefix { + my_size += 1 + 1; + } + if let Some(v) = self.remove_extra_whitespaces { + my_size += 1 + 1; + } + if let Some(v) = self.escape_whitespaces { + my_size += 1 + 1; + } + if let Some(v) = self.normalization_rule_tsv.as_ref() { + my_size += ::protobuf::rt::string_size(6, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.name.as_ref() { + os.write_string(1, v)?; + } + if let Some(v) = self.precompiled_charsmap.as_ref() { + os.write_bytes(2, v)?; + } + if let Some(v) = self.add_dummy_prefix { + os.write_bool(3, v)?; + } + if let Some(v) = self.remove_extra_whitespaces { + os.write_bool(4, v)?; + } + if let Some(v) = self.escape_whitespaces { + os.write_bool(5, 
v)?; + } + if let Some(v) = self.normalization_rule_tsv.as_ref() { + os.write_string(6, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> NormalizerSpec { + NormalizerSpec::new() + } + + fn clear(&mut self) { + self.name = ::std::option::Option::None; + self.precompiled_charsmap = ::std::option::Option::None; + self.add_dummy_prefix = ::std::option::Option::None; + self.remove_extra_whitespaces = ::std::option::Option::None; + self.escape_whitespaces = ::std::option::Option::None; + self.normalization_rule_tsv = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static NormalizerSpec { + static instance: NormalizerSpec = NormalizerSpec { + name: ::std::option::Option::None, + precompiled_charsmap: ::std::option::Option::None, + add_dummy_prefix: ::std::option::Option::None, + remove_extra_whitespaces: ::std::option::Option::None, + escape_whitespaces: ::std::option::Option::None, + normalization_rule_tsv: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +#[derive(PartialEq,Clone,Default,Debug)] +// @@protoc_insertion_point(message:sentencepiece.SelfTestData) +pub struct SelfTestData { + // message fields + // @@protoc_insertion_point(field:sentencepiece.SelfTestData.samples) + pub samples: ::std::vec::Vec, + // special fields + // @@protoc_insertion_point(special_field:sentencepiece.SelfTestData.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a SelfTestData { + fn default() -> &'a SelfTestData { + ::default_instance() + } +} + +impl SelfTestData { + pub fn new() -> SelfTestData { + ::std::default::Default::default() + } +} + +impl 
::protobuf::Message for SelfTestData { + const NAME: &'static str = "SelfTestData"; + + fn is_initialized(&self) -> bool { + for v in &self.samples { + if !v.is_initialized() { + return false; + } + }; + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.samples.push(is.read_message()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + for value in &self.samples { + let len = value.compute_size(); + my_size += 1 + ::protobuf::rt::compute_raw_varint64_size(len) + len; + }; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + for v in &self.samples { + ::protobuf::rt::write_message_field_with_cached_size(1, v, os)?; + }; + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> SelfTestData { + SelfTestData::new() + } + + fn clear(&mut self) { + self.samples.clear(); + self.special_fields.clear(); + } + + fn default_instance() -> &'static SelfTestData { + static instance: SelfTestData = SelfTestData { + samples: ::std::vec::Vec::new(), + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +/// Nested message and enums of message `SelfTestData` +pub mod self_test_data { + 
#[derive(PartialEq,Clone,Default,Debug)] + // @@protoc_insertion_point(message:sentencepiece.SelfTestData.Sample) + pub struct Sample { + // message fields + // @@protoc_insertion_point(field:sentencepiece.SelfTestData.Sample.input) + pub input: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.SelfTestData.Sample.expected) + pub expected: ::std::option::Option<::std::string::String>, + // special fields + // @@protoc_insertion_point(special_field:sentencepiece.SelfTestData.Sample.special_fields) + pub special_fields: ::protobuf::SpecialFields, + } + + impl<'a> ::std::default::Default for &'a Sample { + fn default() -> &'a Sample { + ::default_instance() + } + } + + impl Sample { + pub fn new() -> Sample { + ::std::default::Default::default() + } + + // optional string input = 1; + + pub fn input(&self) -> &str { + match self.input.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_input(&mut self) { + self.input = ::std::option::Option::None; + } + + pub fn has_input(&self) -> bool { + self.input.is_some() + } + + // Param is passed by value, moved + pub fn set_input(&mut self, v: ::std::string::String) { + self.input = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. 
+ pub fn mut_input(&mut self) -> &mut ::std::string::String { + if self.input.is_none() { + self.input = ::std::option::Option::Some(::std::string::String::new()); + } + self.input.as_mut().unwrap() + } + + // Take field + pub fn take_input(&mut self) -> ::std::string::String { + self.input.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional string expected = 2; + + pub fn expected(&self) -> &str { + match self.expected.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_expected(&mut self) { + self.expected = ::std::option::Option::None; + } + + pub fn has_expected(&self) -> bool { + self.expected.is_some() + } + + // Param is passed by value, moved + pub fn set_expected(&mut self, v: ::std::string::String) { + self.expected = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_expected(&mut self) -> &mut ::std::string::String { + if self.expected.is_none() { + self.expected = ::std::option::Option::Some(::std::string::String::new()); + } + self.expected.as_mut().unwrap() + } + + // Take field + pub fn take_expected(&mut self) -> ::std::string::String { + self.expected.take().unwrap_or_else(|| ::std::string::String::new()) + } + } + + impl ::protobuf::Message for Sample { + const NAME: &'static str = "Sample"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? 
{ + match tag { + 10 => { + self.input = ::std::option::Option::Some(is.read_string()?); + }, + 18 => { + self.expected = ::std::option::Option::Some(is.read_string()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.input.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + if let Some(v) = self.expected.as_ref() { + my_size += ::protobuf::rt::string_size(2, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.input.as_ref() { + os.write_string(1, v)?; + } + if let Some(v) = self.expected.as_ref() { + os.write_string(2, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> Sample { + Sample::new() + } + + fn clear(&mut self) { + self.input = ::std::option::Option::None; + self.expected = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static Sample { + static instance: Sample = Sample { + input: ::std::option::Option::None, + expected: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } + } +} + +#[derive(PartialEq,Clone,Default,Debug)] +// @@protoc_insertion_point(message:sentencepiece.ModelProto) +pub struct ModelProto { + // message fields + // 
@@protoc_insertion_point(field:sentencepiece.ModelProto.pieces) + pub pieces: ::std::vec::Vec, + // @@protoc_insertion_point(field:sentencepiece.ModelProto.trainer_spec) + pub trainer_spec: ::protobuf::MessageField, + // @@protoc_insertion_point(field:sentencepiece.ModelProto.normalizer_spec) + pub normalizer_spec: ::protobuf::MessageField, + // @@protoc_insertion_point(field:sentencepiece.ModelProto.self_test_data) + pub self_test_data: ::protobuf::MessageField, + // @@protoc_insertion_point(field:sentencepiece.ModelProto.denormalizer_spec) + pub denormalizer_spec: ::protobuf::MessageField, + // special fields + // @@protoc_insertion_point(special_field:sentencepiece.ModelProto.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ModelProto { + fn default() -> &'a ModelProto { + ::default_instance() + } +} + +impl ModelProto { + pub fn new() -> ModelProto { + ::std::default::Default::default() + } +} + +impl ::protobuf::Message for ModelProto { + const NAME: &'static str = "ModelProto"; + + fn is_initialized(&self) -> bool { + for v in &self.pieces { + if !v.is_initialized() { + return false; + } + }; + for v in &self.trainer_spec { + if !v.is_initialized() { + return false; + } + }; + for v in &self.normalizer_spec { + if !v.is_initialized() { + return false; + } + }; + for v in &self.self_test_data { + if !v.is_initialized() { + return false; + } + }; + for v in &self.denormalizer_spec { + if !v.is_initialized() { + return false; + } + }; + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? 
{ + match tag { + 10 => { + self.pieces.push(is.read_message()?); + }, + 18 => { + ::protobuf::rt::read_singular_message_into_field(is, &mut self.trainer_spec)?; + }, + 26 => { + ::protobuf::rt::read_singular_message_into_field(is, &mut self.normalizer_spec)?; + }, + 34 => { + ::protobuf::rt::read_singular_message_into_field(is, &mut self.self_test_data)?; + }, + 42 => { + ::protobuf::rt::read_singular_message_into_field(is, &mut self.denormalizer_spec)?; + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + for value in &self.pieces { + let len = value.compute_size(); + my_size += 1 + ::protobuf::rt::compute_raw_varint64_size(len) + len; + }; + if let Some(v) = self.trainer_spec.as_ref() { + let len = v.compute_size(); + my_size += 1 + ::protobuf::rt::compute_raw_varint64_size(len) + len; + } + if let Some(v) = self.normalizer_spec.as_ref() { + let len = v.compute_size(); + my_size += 1 + ::protobuf::rt::compute_raw_varint64_size(len) + len; + } + if let Some(v) = self.self_test_data.as_ref() { + let len = v.compute_size(); + my_size += 1 + ::protobuf::rt::compute_raw_varint64_size(len) + len; + } + if let Some(v) = self.denormalizer_spec.as_ref() { + let len = v.compute_size(); + my_size += 1 + ::protobuf::rt::compute_raw_varint64_size(len) + len; + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + for v in &self.pieces { + ::protobuf::rt::write_message_field_with_cached_size(1, v, os)?; + }; + if let Some(v) = self.trainer_spec.as_ref() { + ::protobuf::rt::write_message_field_with_cached_size(2, v, os)?; + 
} + if let Some(v) = self.normalizer_spec.as_ref() { + ::protobuf::rt::write_message_field_with_cached_size(3, v, os)?; + } + if let Some(v) = self.self_test_data.as_ref() { + ::protobuf::rt::write_message_field_with_cached_size(4, v, os)?; + } + if let Some(v) = self.denormalizer_spec.as_ref() { + ::protobuf::rt::write_message_field_with_cached_size(5, v, os)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ModelProto { + ModelProto::new() + } + + fn clear(&mut self) { + self.pieces.clear(); + self.trainer_spec.clear(); + self.normalizer_spec.clear(); + self.self_test_data.clear(); + self.denormalizer_spec.clear(); + self.special_fields.clear(); + } + + fn default_instance() -> &'static ModelProto { + static instance: ModelProto = ModelProto { + pieces: ::std::vec::Vec::new(), + trainer_spec: ::protobuf::MessageField::none(), + normalizer_spec: ::protobuf::MessageField::none(), + self_test_data: ::protobuf::MessageField::none(), + denormalizer_spec: ::protobuf::MessageField::none(), + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +/// Nested message and enums of message `ModelProto` +pub mod model_proto { + #[derive(PartialEq,Clone,Default,Debug)] + // @@protoc_insertion_point(message:sentencepiece.ModelProto.SentencePiece) + pub struct SentencePiece { + // message fields + // @@protoc_insertion_point(field:sentencepiece.ModelProto.SentencePiece.piece) + pub piece: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:sentencepiece.ModelProto.SentencePiece.score) + pub score: ::std::option::Option, + // @@protoc_insertion_point(field:sentencepiece.ModelProto.SentencePiece.type) + pub type_: ::std::option::Option<::protobuf::EnumOrUnknown>, + // 
special fields + // @@protoc_insertion_point(special_field:sentencepiece.ModelProto.SentencePiece.special_fields) + pub special_fields: ::protobuf::SpecialFields, + } + + impl<'a> ::std::default::Default for &'a SentencePiece { + fn default() -> &'a SentencePiece { + ::default_instance() + } + } + + impl SentencePiece { + pub fn new() -> SentencePiece { + ::std::default::Default::default() + } + + // optional string piece = 1; + + pub fn piece(&self) -> &str { + match self.piece.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_piece(&mut self) { + self.piece = ::std::option::Option::None; + } + + pub fn has_piece(&self) -> bool { + self.piece.is_some() + } + + // Param is passed by value, moved + pub fn set_piece(&mut self, v: ::std::string::String) { + self.piece = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_piece(&mut self) -> &mut ::std::string::String { + if self.piece.is_none() { + self.piece = ::std::option::Option::Some(::std::string::String::new()); + } + self.piece.as_mut().unwrap() + } + + // Take field + pub fn take_piece(&mut self) -> ::std::string::String { + self.piece.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional float score = 2; + + pub fn score(&self) -> f32 { + self.score.unwrap_or(0.) 
+ } + + pub fn clear_score(&mut self) { + self.score = ::std::option::Option::None; + } + + pub fn has_score(&self) -> bool { + self.score.is_some() + } + + // Param is passed by value, moved + pub fn set_score(&mut self, v: f32) { + self.score = ::std::option::Option::Some(v); + } + + // optional .sentencepiece.ModelProto.SentencePiece.Type type = 3; + + pub fn type_(&self) -> sentence_piece::Type { + match self.type_ { + Some(e) => e.enum_value_or(sentence_piece::Type::NORMAL), + None => sentence_piece::Type::NORMAL, + } + } + + pub fn clear_type_(&mut self) { + self.type_ = ::std::option::Option::None; + } + + pub fn has_type(&self) -> bool { + self.type_.is_some() + } + + // Param is passed by value, moved + pub fn set_type(&mut self, v: sentence_piece::Type) { + self.type_ = ::std::option::Option::Some(::protobuf::EnumOrUnknown::new(v)); + } + } + + impl ::protobuf::Message for SentencePiece { + const NAME: &'static str = "SentencePiece"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? 
{ + match tag { + 10 => { + self.piece = ::std::option::Option::Some(is.read_string()?); + }, + 21 => { + self.score = ::std::option::Option::Some(is.read_float()?); + }, + 24 => { + self.type_ = ::std::option::Option::Some(is.read_enum_or_unknown()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.piece.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + if let Some(v) = self.score { + my_size += 1 + 4; + } + if let Some(v) = self.type_ { + my_size += ::protobuf::rt::int32_size(3, v.value()); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.piece.as_ref() { + os.write_string(1, v)?; + } + if let Some(v) = self.score { + os.write_float(2, v)?; + } + if let Some(v) = self.type_ { + os.write_enum(3, ::protobuf::EnumOrUnknown::value(&v))?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> SentencePiece { + SentencePiece::new() + } + + fn clear(&mut self) { + self.piece = ::std::option::Option::None; + self.score = ::std::option::Option::None; + self.type_ = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static SentencePiece { + static instance: SentencePiece = SentencePiece { + piece: ::std::option::Option::None, + score: 
::std::option::Option::None, + type_: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } + } + + /// Nested message and enums of message `SentencePiece` + pub mod sentence_piece { + #[derive(Clone,Copy,PartialEq,Eq,Debug,Hash)] + // @@protoc_insertion_point(enum:sentencepiece.ModelProto.SentencePiece.Type) + pub enum Type { + // @@protoc_insertion_point(enum_value:sentencepiece.ModelProto.SentencePiece.Type.NORMAL) + NORMAL = 1, + // @@protoc_insertion_point(enum_value:sentencepiece.ModelProto.SentencePiece.Type.UNKNOWN) + UNKNOWN = 2, + // @@protoc_insertion_point(enum_value:sentencepiece.ModelProto.SentencePiece.Type.CONTROL) + CONTROL = 3, + // @@protoc_insertion_point(enum_value:sentencepiece.ModelProto.SentencePiece.Type.USER_DEFINED) + USER_DEFINED = 4, + // @@protoc_insertion_point(enum_value:sentencepiece.ModelProto.SentencePiece.Type.BYTE) + BYTE = 6, + // @@protoc_insertion_point(enum_value:sentencepiece.ModelProto.SentencePiece.Type.UNUSED) + UNUSED = 5, + } + + impl ::protobuf::Enum for Type { + const NAME: &'static str = "Type"; + + fn value(&self) -> i32 { + *self as i32 + } + + fn from_i32(value: i32) -> ::std::option::Option { + match value { + 1 => ::std::option::Option::Some(Type::NORMAL), + 2 => ::std::option::Option::Some(Type::UNKNOWN), + 3 => ::std::option::Option::Some(Type::CONTROL), + 4 => ::std::option::Option::Some(Type::USER_DEFINED), + 6 => ::std::option::Option::Some(Type::BYTE), + 5 => ::std::option::Option::Some(Type::UNUSED), + _ => ::std::option::Option::None + } + } + + const VALUES: &'static [Type] = &[ + Type::NORMAL, + Type::UNKNOWN, + Type::CONTROL, + Type::USER_DEFINED, + Type::BYTE, + Type::UNUSED, + ]; + } + + // Note, `Default` is implemented although default value is not 0 + impl ::std::default::Default for Type { + fn default() -> Self { + Type::NORMAL + } + } + + } +} diff --git a/src/rllama_main.rs b/src/rllama_main.rs new file mode 100644 index 0000000..06d9543 --- 
/dev/null +++ b/src/rllama_main.rs @@ -0,0 +1,102 @@ +use crate::embedding::Embedding; +use crate::token_sampler::TokenSampler; +use crate::tokenizer::{TokenId, Tokenizer}; +use crate::transformer::Transformer; +use crate::unpickler; +use clap::Parser; +use std::io::Read; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Cli { + #[arg(long)] + model_path: String, + #[arg(long)] + tokenizer_path: String, + #[arg(long)] + prompt: String, + + #[arg(long)] + temperature: Option, + #[arg(long)] + top_p: Option, + #[arg(long)] + top_k: Option, +} + +pub fn main() -> Result<(), Box> { + let cli = Cli::parse(); + let model_path = cli.model_path; + let tokenizer_path = cli.tokenizer_path; + let prompt = cli.prompt; + + println!("Starting up. Loading tokenizer from {}...", tokenizer_path); + let tok = Tokenizer::load(tokenizer_path.as_str())?; + println!("Tokenizer loeaded. Loading model from {}...", model_path); + let mut fs = std::fs::File::open(model_path.as_str())?; + let mut bs = Vec::new(); + fs.read_to_end(&mut bs)?; + std::mem::drop(fs); + + // We chop off file name from model_path and append "data/" + let model_data_dir = model_path + .split("/") + .take(model_path.split("/").count() - 1) + .collect::>() + .join("/") + + "/data/"; + let result = unpickler::unpickle(&bs)?; + println!("Loading embeddings from {}...", model_data_dir); + let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?; + + println!("Loading transformer weights from {}...", model_data_dir); + let tr = Transformer::from_unpickled( + &result, + emb, + 4096, + 32, + 32, + 512, + 1e-6, + 32, + 128, + model_data_dir, + )?; + println!("All is loaded. 
Starting inference."); + + let mut toks_id: Vec = tok.tokenize_to_ids(prompt); + let mut prev_pos = 0; + let mut token_sampler = TokenSampler::new().temperature(0.8).top_p(0.9).top_k(50); + + if let Some(temperature) = cli.temperature { + token_sampler = token_sampler.temperature(temperature); + } + if let Some(top_p) = cli.top_p { + token_sampler = token_sampler.top_p(top_p); + } + if let Some(top_k) = cli.top_k { + token_sampler = token_sampler.top_k(top_k as usize); + } + + println!("Temperature: {}", token_sampler.get_temperature()); + println!("Top P: {}", token_sampler.get_top_p()); + println!("Top K: {}", token_sampler.get_top_k()); + + let mut caches = tr.make_caches(); + loop { + let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches); + let highest_pred_idx = token_sampler.sample(&preds); + toks_id.push(highest_pred_idx as TokenId); + prev_pos = toks_id.len() - 1; + + let mut tok_str: String = "".to_string(); + for tok_id in toks_id.iter() { + if *tok_id == 1 { + continue; + } + let tok = tok.id_to_str(*tok_id); + tok_str = tok_str + tok.replace("▁", " ").as_str(); + } + println!("{}", tok_str); + } +} diff --git a/src/tensor.rs b/src/tensor.rs new file mode 100644 index 0000000..5e2106d --- /dev/null +++ b/src/tensor.rs @@ -0,0 +1,1301 @@ +use crate::unpickler; +use crate::unpickler::UnpicklingError; +use half::f16; +use rand::Rng; +use std::alloc::Layout; +use std::arch::x86_64::*; +use std::io::Read; +use std::path::{Path, PathBuf}; +use thiserror::Error; + +#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] +pub struct TensorBuilder { + pub(crate) src_path: PathBuf, + pub(crate) dtype: TensorDType, + pub(crate) stride: i64, + pub(crate) rows: i64, + pub(crate) cols: i64, + pub(crate) nitems: i64, +} + +#[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] +pub enum TensorDType { + Float16, + Float32, +} + +#[derive(Error, Debug)] +pub enum TensorError { + #[error("IO error: {0}")] + IOError(#[from] std::io::Error), + 
#[error("Invalid stride: {0}")] + InvalidStride(i64), +} + +impl TensorDType { + fn bytes_per_item(&self) -> usize { + match self { + Self::Float16 => 2, + Self::Float32 => 4, + } + } +} + +#[derive(Debug)] +pub struct Tensor { + data: *mut u8, + dtype: TensorDType, + layout: Layout, + rows: i64, + cols: i64, + // Every matrix is allocated so that cols are rounded to the next multiple of 32. + // This lets us write AVX2 code without complicated checks. + capacity_cols: i64, +} + +unsafe impl Send for Tensor {} +unsafe impl Sync for Tensor {} + +impl Clone for Tensor { + fn clone(&self) -> Self { + unsafe { + let new_tensor = Tensor::uninitialized(self.rows, self.cols, self.dtype); + std::ptr::copy_nonoverlapping( + self.data, + new_tensor.data, + (self.rows * self.capacity_cols * self.dtype.bytes_per_item() as i64) as usize, + ); + new_tensor + } + } +} + +impl Drop for Tensor { + fn drop(&mut self) { + unsafe { + if self.data != std::ptr::null_mut() { + std::alloc::dealloc(self.data, self.layout); + } + } + } +} + +fn compute_capacity_cols(cols: i64) -> i64 { + if cols % 8 == 0 { + cols + } else { + cols + 8 - cols % 8 + } +} + +#[inline] +fn horizontal_sum(mut ymm: __m256) -> f32 { + unsafe { + let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1); + ymm = _mm256_add_ps(ymm, ymm2); + ymm = _mm256_hadd_ps(ymm, ymm); + ymm = _mm256_hadd_ps(ymm, ymm); + return _mm256_cvtss_f32(ymm); + } +} + +impl Tensor { + pub fn from_unpickled, S: AsRef>( + unpickled: &unpickler::Value, + name: S, + data_dir: P, + ) -> Result { + let data_dir: &Path = data_dir.as_ref(); + let name: &str = name.as_ref(); + let val = unpickled + .get_str_key(name) + .ok_or(UnpicklingError::MissingField(name.to_string()))?; + let val = val + .to_tensor_builder() + .ok_or(UnpicklingError::InvalidTensorData)?; + let val = val.load(data_dir)?; + Ok(val) + } + + pub fn rows(&self) -> i64 { + self.rows + } + + pub fn cols(&self) -> i64 { + self.cols + } + + // Gets a value as f32 from the tensor. 
+    #[inline]
+    pub fn get_f32(&self, row: i64, col: i64) -> f32 {
+        assert!(
+            row >= 0 && col >= 0 && row < self.rows && col < self.cols,
+            "Invalid index: {}, {} Size: {}, {}",
+            row,
+            col,
+            self.rows,
+            self.cols
+        );
+        // Rows are laid out `capacity_cols` items apart, not `cols`.
+        let idx = row * self.capacity_cols + col;
+        match self.dtype {
+            TensorDType::Float16 => {
+                // SAFETY: the assert above keeps (row, col) inside rows x cols, and
+                // cols <= capacity_cols, so idx addresses an item of the allocation.
+                let val: f16 = unsafe { *(self.data.add(idx as usize * 2) as *const f16) };
+                val.to_f32()
+            }
+            TensorDType::Float32 => {
+                // SAFETY: same bounds argument as the Float16 arm, with 4-byte items.
+                let val: f32 = unsafe { *(self.data.add(idx as usize * 4) as *const f32) };
+                val
+            }
+        }
+    }
+
+    // Sets a value from f32. The value is cast into whatever the tensor's dtype is.
+    //
+    // Panics if (row, col) is outside rows x cols. Without this check an out-of-range
+    // index would write through the raw data pointer past the allocation.
+    #[inline]
+    pub fn set_f32(&mut self, row: i64, col: i64, val: f32) {
+        assert!(
+            row >= 0 && col >= 0 && row < self.rows && col < self.cols,
+            "Invalid index: {}, {} Size: {}, {}",
+            row,
+            col,
+            self.rows,
+            self.cols
+        );
+        let idx = row * self.capacity_cols + col;
+        match self.dtype {
+            TensorDType::Float16 => {
+                let val: f16 = f16::from_f32(val);
+                // SAFETY: the assert above keeps idx inside the allocation.
+                unsafe { *(self.data.add(idx as usize * 2) as *mut f16) = val };
+            }
+            TensorDType::Float32 => {
+                // SAFETY: the assert above keeps idx inside the allocation.
+                unsafe { *(self.data.add(idx as usize * 4) as *mut f32) = val };
+            }
+        }
+    }
+
+    // Converts the tensor to two-dimensional Vec.
+    // Meant for debugging and making it easy to print tensors.
+ pub fn to_vec(&self) -> Vec> { + let mut result = Vec::new(); + for row in 0..self.rows { + let mut row_vec = Vec::new(); + for col in 0..self.cols { + let val = self.get_f32(row, col); + row_vec.push(val); + } + result.push(row_vec); + } + result + } + + pub fn empty() -> Self { + Self { + data: std::ptr::null_mut(), + dtype: TensorDType::Float16, + layout: Layout::from_size_align(0, 0).unwrap(), + rows: 0, + cols: 0, + capacity_cols: 0, + } + } + + pub unsafe fn uninitialized(rows: i64, cols: i64, dtype: TensorDType) -> Self { + if rows == 0 || cols == 0 { + let mut tensor = Self::empty(); + tensor.rows = rows; + tensor.cols = cols; + return tensor; + } + // Rouns up cols to 8 + let capacity_cols = compute_capacity_cols(cols); + let nitems = rows * capacity_cols; + let layout = + Layout::from_size_align((nitems as usize) * dtype.bytes_per_item(), 32).unwrap(); + let data = unsafe { std::alloc::alloc(layout) }; + if data == std::ptr::null_mut() { + panic!("Failed to allocate tensor"); + } + // Even though we are uninitialized, we should zero out the extra space between the + // columns. + // Otherwise there might be problems later as other operations assume it is zeroed. + for extra_col in cols..capacity_cols { + for row in 0..rows { + let idx = row * capacity_cols + extra_col; + match dtype { + TensorDType::Float16 => { + let val: f16 = f16::from_f32(0.0); + unsafe { *(data.add(idx as usize * 2) as *mut f16) = val }; + } + TensorDType::Float32 => { + unsafe { *(data.add(idx as usize * 4) as *mut f32) = 0.0 }; + } + } + } + } + + Self { + data, + dtype, + rows, + cols, + capacity_cols, + layout, + } + } + + pub fn full(rows: i64, cols: i64, dtype: TensorDType, value: f32) -> Self { + let mut tensor = unsafe { Tensor::uninitialized(rows, cols, dtype) }; + for row in 0..rows { + for col in 0..cols { + tensor.set_f32(row, col, value); + } + } + tensor + } + + // Runs softmax on row dimension. 
+ pub fn softmax(&self) -> Tensor { + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + for row in 0..self.rows { + let mut sum = 0.0; + for col in 0..self.cols { + let val = self.get_f32(row, col); + sum += val.exp(); + } + for col in 0..self.cols { + let val = self.get_f32(row, col); + result.set_f32(row, col, val.exp() / sum); + } + } + result + } + + pub fn full_triu(rows: i64, cols: i64, start_pos: i64, dtype: TensorDType, value: f32) -> Self { + let mut tensor = unsafe { Tensor::uninitialized(rows, cols, dtype) }; + for row in 0..rows { + for col in 0..cols { + if col >= row + start_pos { + tensor.set_f32(row, col, value); + } else { + tensor.set_f32(row, col, 0.0); + } + } + } + tensor + } + + // Computes mean for each row, so that columns become 1. + pub fn mean_cols(&self) -> Tensor { + let mut result = unsafe { Tensor::uninitialized(self.rows, 1, self.dtype) }; + for row in 0..self.rows { + let mut sum = 0.0; + for col in 0..self.cols { + sum += self.get_f32(row, col); + } + result.set_f32(row, 0, sum / self.cols as f32); + } + result + } + + pub fn mean(&self) -> Tensor { + let mut result = unsafe { Tensor::uninitialized(1, 1, self.dtype) }; + let mut sum = 0.0; + for row in 0..self.rows { + for col in 0..self.cols { + sum += self.get_f32(row, col); + } + } + result.set_f32(0, 0, sum / (self.rows * self.cols) as f32); + result + } + + pub fn pow(&self, power: f32) -> Tensor { + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col); + result.set_f32(row, col, val.powf(power)); + } + } + result + } + + pub fn sqrt(&self) -> Tensor { + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col); + result.set_f32(row, col, val.sqrt()); + } + } + result + } + + pub fn rsqrt(&self) -> 
Tensor { + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col); + result.set_f32(row, col, 1.0 / val.sqrt()); + } + } + result + } + + pub fn add(&self, other: &Tensor) -> Tensor { + if self.rows() != other.rows() || self.cols() != other.cols() { + panic!( + "add: Tensors must have the same shape, left: {}x{} right: {}x{}", + self.rows(), + self.cols(), + other.rows(), + other.cols() + ); + } + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col) + other.get_f32(row, col); + result.set_f32(row, col, val); + } + } + result + } + + pub fn add_scalar(&self, scalar: f32) -> Tensor { + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col) + scalar; + result.set_f32(row, col, val); + } + } + result + } + + pub fn scalar_multiply_f32(&self, scalar: f32) -> Tensor { + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col) * scalar; + result.set_f32(row, col, val); + } + } + result + } + + pub fn scalar_multiply_broadcast(&self, other: &Tensor) -> Tensor { + if other.cols != 1 { + panic!("Invalid scalar broadcast"); + } + if other.rows != self.rows { + panic!("Invalid scalar broadcast"); + } + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + for row in 0..self.rows { + let scalar = other.get_f32(row, 0); + for col in 0..self.cols { + let val = self.get_f32(row, col) * scalar; + result.set_f32(row, col, val); + } + } + result + } + + pub fn scalar_product(&self, other: &Tensor) -> Tensor { + if other.cols != 1 || other.rows != 1 { + panic!("Invalid scalar 
product"); + } + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + let scalar = other.get_f32(0, 0); + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col) * scalar; + result.set_f32(row, col, val); + } + } + result + } + + pub fn hadamard_product_broadcast(&self, other: &Tensor) -> Tensor { + if self.cols != other.cols { + panic!("Invalid hadamard product broadcast"); + } + if other.rows != 1 { + panic!("Invalid hadamard product broadcast"); + } + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col) * other.get_f32(0, col); + result.set_f32(row, col, val); + } + } + result + } + + pub fn hadamard_product(&self, other: &Tensor) -> Tensor { + if self.cols != other.cols || self.rows != other.rows { + panic!( + "Invalid hadamard product: incompatible shapes, {}x{} vs {}x{}", + self.rows, self.cols, other.rows, other.cols + ); + } + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col) * other.get_f32(row, col); + result.set_f32(row, col, val); + } + } + result + } + + pub fn concat(pieces: &[&Tensor]) -> Tensor { + if pieces.len() == 0 { + return Tensor::empty(); + } + let mut total_rows: i64 = 0; + let expected_cols: i64 = pieces[0].cols; + let expected_dtype: TensorDType = pieces[0].dtype; + for piece in pieces { + if piece.cols != expected_cols { + panic!("Invalid tensor concatenation, wrong number of columns"); + } + if piece.dtype != expected_dtype { + panic!("Invalid tensor concatenation, wrong dtype"); + } + total_rows += piece.rows; + } + let mut result = + unsafe { Tensor::uninitialized(total_rows, expected_cols, pieces[0].dtype) }; + let mut row_offset = 0; + for piece in pieces { + for row in 0..piece.rows { + for col in 0..piece.cols { + let 
val = piece.get_f32(row, col); + result.set_f32(row_offset + row, col, val); + } + } + row_offset += piece.rows; + } + result + } + + pub fn silu(&self) -> Tensor { + let mut result = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col); + let val = val / (1.0 + (-val).exp()); + result.set_f32(row, col, val); + } + } + result + } + + pub fn transpose(&self) -> Tensor { + let mut result = unsafe { Tensor::uninitialized(self.cols, self.rows, self.dtype) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col); + result.set_f32(col, row, val); + } + } + result + } + + /// Slow, naive matrix multiplication. + /// + /// This is used as a reference to test correctness of other matrix multiplications. + pub fn matrix_mul_naive(&self, other: &Tensor) -> Tensor { + if self.cols != other.rows { + panic!( + "Invalid matrix multiplication {}x{} vs {}x{}", + self.rows, self.cols, other.rows, other.cols + ); + } + let mut result = unsafe { Tensor::uninitialized(self.rows, other.cols, self.dtype) }; + for row in 0..self.rows { + for col in 0..other.cols { + let mut sum = 0.0; + for i in 0..self.cols { + sum += self.get_f32(row, i) * other.get_f32(i, col); + } + result.set_f32(row, col, sum); + } + } + result + } + + pub fn matrix_mul(&self, other: &Tensor) -> Tensor { + if self.cols != other.rows { + panic!( + "Invalid matrix multiplication {}x{} vs {}x{}", + self.rows, self.cols, other.rows, other.cols + ); + } + if self.rows == 1 { + return self.vector_matrix_mul(other); + } + if other.cols == 1 { + return self.matrix_vector_mul(other); + } + let mut result = unsafe { Tensor::uninitialized(self.rows, other.cols, self.dtype) }; + result.matrix_mul_inplace(self, other); + result + } + + pub fn matrix_mul_transposed(&self, other: &Tensor) -> Tensor { + if self.cols != other.cols { + panic!( + "Invalid matrix transposed multiplication {}x{} 
vs {}x{}", + self.rows, self.cols, other.cols, other.rows + ); + } + if other.rows == 1 { + return self.matrix_vector_mul_transposed(other); + } + let mut result = unsafe { Tensor::uninitialized(self.rows, other.rows, self.dtype) }; + result.matrix_mul_inplace_transposed(self, other); + result + } + + /// Matrix multiplication done in-place + pub fn matrix_mul_inplace(&mut self, src: &Tensor, other: &Tensor) { + if src.cols != other.rows { + panic!( + "Invalid matrix multiplication {}x{} vs {}x{}", + self.rows, self.cols, other.rows, other.cols + ); + } + if src.dtype != other.dtype { + panic!("Invalid matrix multiplication, different dtypes"); + } + if self.rows != src.rows { + panic!("Invalid matrix multiplication, different number of rows"); + } + if self.cols != other.cols { + panic!("Invalid matrix multiplication, different number of cols"); + } + + match src.dtype { + TensorDType::Float32 => { + // not actual cache line size, but this represents 8 floats which is the number we can + // operate with AVX2 + const CACHE_LINE_SIZE: usize = 32; + const ITEMS_PER_CACHE_LINE: usize = CACHE_LINE_SIZE / std::mem::size_of::(); + + let tgt_data: *mut f32 = self.data as *mut f32; + unsafe { + std::ptr::write_bytes( + tgt_data, + 0, + self.rows as usize * self.capacity_cols as usize, + ); + } + let src_data: *const f32 = src.data as *const f32; + let other_data: *const f32 = other.data as *const f32; + + let src_rows: usize = src.rows as usize; + let other_cols: usize = other.cols as usize; + let src_cols: usize = src.cols as usize; + let other_cols_capacity: usize = other.capacity_cols as usize; + let src_cols_capacity: usize = src.capacity_cols as usize; + let self_cols_capacity: usize = self.capacity_cols as usize; + + let mut row: usize = 0; + let mut col: usize; + let mut k: usize; + + unsafe { + while row < src_rows { + col = 0; + while col < other_cols { + k = 0; + while k < src_cols { + for i2 in row..std::cmp::min(row + ITEMS_PER_CACHE_LINE, src_rows) { + let 
i2_self_cols = i2 * self_cols_capacity; + let i2_src_cols = i2 * src_cols_capacity; + for k2 in k..std::cmp::min(k + ITEMS_PER_CACHE_LINE, src_cols) { + let other_value8: __m256 = _mm256_loadu_ps( + other_data.add(k2 * other_cols_capacity + col), + ); + let src_value8_broadcast: __m256 = + _mm256_broadcast_ss(&*src_data.add(i2_src_cols + k2)); + let tgt_value8: __m256 = + _mm256_loadu_ps(tgt_data.add(i2_self_cols + col)); + let result8: __m256 = _mm256_fmadd_ps( + src_value8_broadcast, + other_value8, + tgt_value8, + ); + _mm256_storeu_ps(tgt_data.add(i2_self_cols + col), result8); + } + } + k += ITEMS_PER_CACHE_LINE; + } + col += ITEMS_PER_CACHE_LINE; + } + row += ITEMS_PER_CACHE_LINE; + } + } + } + TensorDType::Float16 => unsafe { + // Even with conversion, float16 is much slower than float32 + const CACHE_LINE_SIZE: usize = 16; + const ITEMS_PER_CACHE_LINE: usize = CACHE_LINE_SIZE / std::mem::size_of::(); + assert!(src.rows as usize % ITEMS_PER_CACHE_LINE == 0); + assert!(src.cols as usize % ITEMS_PER_CACHE_LINE == 0); + assert!(other.cols as usize % ITEMS_PER_CACHE_LINE == 0); + assert!(other.rows as usize % ITEMS_PER_CACHE_LINE == 0); + + let tgt_data: *mut f16 = self.data as *mut f16; + std::ptr::write_bytes(tgt_data, 0, self.rows as usize * self.cols as usize); + let src_data: *const f16 = src.data as *const f16; + let other_data: *const f16 = other.data as *const f16; + + let src_rows: usize = src.rows as usize; + let other_cols: usize = other.cols as usize; + let src_cols: usize = src.cols as usize; + let self_cols: usize = self.cols as usize; + + let mut row: usize = 0; + let mut col: usize; + let mut k: usize; + + while row < src_rows { + col = 0; + while col < other_cols { + k = 0; + while k < src_cols { + for i2 in row..row + ITEMS_PER_CACHE_LINE { + let i2_self_cols = i2 * self_cols; + let i2_src_cols = i2 * src_cols; + for k2 in k..k + ITEMS_PER_CACHE_LINE { + let other_value8: __m256 = _mm256_cvtph_ps(_mm_loadu_si128( + other_data.add(k2 * 
other_cols + col) as *const _, + )); + let src_value8: f16 = *src_data.add(i2_src_cols + k2); + let src_value8_broadcast: __m256 = + _mm256_broadcast_ss(&src_value8.to_f32()); + let tgt_value8: __m256 = _mm256_cvtph_ps(_mm_loadu_si128( + tgt_data.add(i2_self_cols + col) as *const _, + )); + let result8: __m256 = _mm256_fmadd_ps( + src_value8_broadcast, + other_value8, + tgt_value8, + ); + let result8_packed: __m128i = _mm256_cvtps_ph(result8, 0); + _mm_storeu_si128( + tgt_data.add(i2_self_cols + col) as *mut _, + result8_packed, + ); + } + } + k += ITEMS_PER_CACHE_LINE; + } + col += ITEMS_PER_CACHE_LINE; + } + row += ITEMS_PER_CACHE_LINE; + } + }, + } + } + + /// Matrix multiplication done in-place, but the second matrix is transposed. + /// With this, you can avoid using .transpose() on the second matrix. + pub fn matrix_mul_inplace_transposed(&mut self, src: &Tensor, other: &Tensor) { + if src.cols != other.cols { + panic!( + "Invalid matrix multiplication {}x{} vs {}x{}", + self.rows, self.cols, other.rows, other.cols + ); + } + if src.dtype != other.dtype { + panic!("Invalid matrix multiplication, different dtypes"); + } + if self.rows != src.rows { + panic!("Invalid matrix multiplication, different number of rows"); + } + if self.cols != other.rows { + panic!("Invalid matrix multiplication, different number of cols"); + } + + match src.dtype { + TensorDType::Float32 => { + const CACHE_LINE_SIZE: usize = 32; + const ITEMS_PER_CACHE_LINE: usize = CACHE_LINE_SIZE / std::mem::size_of::(); + + let tgt_data: *mut f32 = self.data as *mut f32; + unsafe { + std::ptr::write_bytes( + tgt_data, + 0, + self.rows as usize * self.capacity_cols as usize, + ); + } + let src_data: *const f32 = src.data as *const f32; + let other_data: *const f32 = other.data as *const f32; + + let src_cols: usize = src.cols as usize; + let self_rows: usize = self.rows as usize; + let self_cols: usize = self.cols as usize; + let other_cols_capacity: usize = other.capacity_cols as usize; + let 
src_cols_capacity: usize = src.capacity_cols as usize; + let self_cols_capacity: usize = self.capacity_cols as usize; + + let src_cols_its = if src_cols % ITEMS_PER_CACHE_LINE == 0 { + src_cols / ITEMS_PER_CACHE_LINE + } else { + src_cols / ITEMS_PER_CACHE_LINE + 1 + }; + + unsafe { + for row in 0..self_rows { + let row = row as usize; + for col in 0..self_cols { + let mut target8: __m256 = _mm256_setzero_ps(); + for p in 0..src_cols_its { + let src8: __m256 = _mm256_loadu_ps( + src_data + .add(row * src_cols_capacity + p * ITEMS_PER_CACHE_LINE), + ); + let other8: __m256 = _mm256_loadu_ps( + other_data + .add(col * other_cols_capacity + p * ITEMS_PER_CACHE_LINE), + ); + target8 = _mm256_fmadd_ps(src8, other8, target8); + } + let target: f32 = horizontal_sum(target8); + *tgt_data.add(row * self_cols_capacity + col) = target; + } + } + } + } + TensorDType::Float16 => unimplemented!(), + } + } + + // Computes matrix multiplication assuming that the number of rows on the latter matrix is 1. + // + // AxB @ Cx1 = Ax1 + pub fn matrix_vector_mul(&self, other: &Tensor) -> Tensor { + // TODO: this function is not optimized. + if self.cols != other.rows { + panic!( + "Invalid matrix-vector multiplication {}x{} vs {}x{}", + self.rows, self.cols, other.rows, other.cols + ); + } + assert_eq!(other.cols, 1); + assert_eq!(other.dtype, self.dtype); + assert_eq!(self.dtype, TensorDType::Float32); + + let mut result = unsafe { Tensor::uninitialized(self.rows, 1, self.dtype) }; + for row in 0..self.rows { + let mut sum = 0.0; + for col in 0..self.cols { + sum += self.get_f32(row, col) * other.get_f32(col, 0); + } + result.set_f32(row, 0, sum); + } + result + } + + /// Same as matrix_vector_mul, but right side is assumed to be transposed. 
+ pub fn matrix_vector_mul_transposed(&self, other: &Tensor) -> Tensor { + if self.cols != other.cols { + panic!( + "Invalid matrix-vector transposed multiplication {}x{} vs {}x{}", + self.rows, self.cols, other.rows, other.cols + ); + } + assert_eq!(other.rows, 1); + assert_eq!(other.dtype, self.dtype); + assert_eq!(self.dtype, TensorDType::Float32); + + unsafe { + let mut result = Tensor::uninitialized(self.rows, 1, self.dtype); + let col_its: usize = if self.cols % 8 == 0 { + (self.cols / 8) as usize + } else { + (self.cols / 8 + 1) as usize + }; + let self_data: *const f32 = self.data as *const f32; + let other_data: *const f32 = other.data as *const f32; + for row in 0..self.rows { + let mut sum8: __m256 = _mm256_setzero_ps(); + for col in 0..col_its { + let col = (col * 8) as usize; + let left_side8 = + _mm256_loadu_ps(self_data.add((row * self.capacity_cols) as usize + col)); + let right_side8 = _mm256_loadu_ps(other_data.add(col)); + sum8 = _mm256_fmadd_ps(left_side8, right_side8, sum8); + } + let sum: f32 = horizontal_sum(sum8); + result.set_f32(row, 0, sum); + } + result + } + } + + // Computes matrix multiplication assuming left side has number of rows as 1 + pub fn vector_matrix_mul(&self, other: &Tensor) -> Tensor { + if self.cols != other.rows { + panic!( + "Invalid matrix-vector multiplication {}x{} vs {}x{}", + self.rows, self.cols, other.rows, other.cols + ); + } + assert_eq!(self.rows, 1); + let mut result = unsafe { Tensor::uninitialized(1, other.cols, self.dtype) }; + for col in 0..other.cols { + let mut sum = 0.0; + for row in 0..self.cols { + sum += self.get_f32(0, row) * other.get_f32(row, col); + } + result.set_f32(0, col, sum); + } + result + } + + pub fn random(rows: i64, cols: i64, dtype: TensorDType) -> Self { + let mut result = unsafe { Tensor::uninitialized(rows, cols, dtype) }; + let mut rng = rand::thread_rng(); + for row in 0..rows { + for col in 0..cols { + result.set_f32(row, col, rng.gen_range(-1.0..1.0)); + } + } + result + } + 
+ pub fn eye(sz: i64, dtype: TensorDType) -> Self { + let mut result = unsafe { Tensor::uninitialized(sz, sz, dtype) }; + for row in 0..sz { + for col in 0..sz { + result.set_f32(row, col, if row == col { 1.0 } else { 0.0 }); + } + } + result + } + + pub fn zeros(rows: i64, cols: i64, dtype: TensorDType) -> Self { + if rows == 0 || cols == 0 { + let mut tensor = Self::empty(); + tensor.rows = rows; + tensor.cols = cols; + return tensor; + } + let capacity_cols = compute_capacity_cols(cols); + let nitems = rows * capacity_cols; + let layout = + Layout::from_size_align((nitems as usize) * dtype.bytes_per_item(), 32).unwrap(); + let data = unsafe { std::alloc::alloc_zeroed(layout) }; + if data == std::ptr::null_mut() { + panic!("Failed to allocate tensor"); + } + Self { + data, + dtype, + rows, + cols, + capacity_cols, + layout, + } + } + + pub fn clip_cols(&self, cols: usize) -> Tensor { + if cols == 0 { + return Self::empty(); + } + assert!(cols as i64 <= self.cols); + + let result = unsafe { Tensor::uninitialized(self.rows, cols as i64, self.dtype) }; + for row in 0..self.rows { + unsafe { + std::ptr::copy_nonoverlapping( + self.data.add( + (row * self.capacity_cols * self.dtype.bytes_per_item() as i64) as usize, + ), + result.data.add( + (row * result.capacity_cols * self.dtype.bytes_per_item() as i64) as usize, + ), + cols * self.dtype.bytes_per_item(), + ); + } + } + result + } + + pub fn view(&self, rows: i64, cols: i64) -> Tensor { + if rows * cols != self.rows * self.cols { + panic!("Invalid tensor view"); + } + if rows == self.rows { + return self.clone(); + } + unsafe { + let mut result = Self::zeros(rows, cols, self.dtype); + result.rows = rows; + result.cols = cols; + match self.dtype { + TensorDType::Float16 => { + let mut tgt_row: usize = 0; + let mut tgt_col: usize = 0; + for src_row in 0..self.rows { + for src_col in 0..self.cols { + let idx = (src_row * self.capacity_cols + src_col) as usize; + let v: f16 = *(self.data.add(idx * 2) as *const f16); + 
*(result + .data + .add((tgt_row * result.capacity_cols as usize + tgt_col) * 2) + as *mut f16) = v; + tgt_col += 1; + if tgt_col == cols as usize { + tgt_col = 0; + tgt_row += 1; + } + } + } + } + TensorDType::Float32 => { + let mut tgt_row: usize = 0; + let mut tgt_col: usize = 0; + for src_row in 0..self.rows { + for src_col in 0..self.cols { + let idx = (src_row * self.capacity_cols + src_col) as usize; + let v: f32 = *(self.data.add(idx * 4) as *const f32); + *(result + .data + .add((tgt_row * result.capacity_cols as usize + tgt_col) * 4) + as *mut f32) = v; + tgt_col += 1; + if tgt_col == cols as usize { + tgt_col = 0; + tgt_row += 1; + } + } + } + } + } + result + } + } + + pub fn to_f32(&self) -> Tensor { + if self.dtype == TensorDType::Float32 { + return self.clone(); + } + + let mut result = + unsafe { Tensor::uninitialized(self.rows, self.cols, TensorDType::Float32) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col); + result.set_f32(row, col, val); + } + } + result + } + + pub fn to_f16(&self) -> Tensor { + if self.dtype == TensorDType::Float16 { + return self.clone(); + } + + let mut result = + unsafe { Tensor::uninitialized(self.rows, self.cols, TensorDType::Float16) }; + for row in 0..self.rows { + for col in 0..self.cols { + let val = self.get_f32(row, col); + result.set_f32(row, col, val); + } + } + result + } + + pub fn row(&self, row: i64) -> Tensor { + if row < 0 || row > self.rows { + panic!("Invalid row index"); + } + + let result = unsafe { Tensor::uninitialized(1, self.cols, self.dtype) }; + unsafe { + std::ptr::copy_nonoverlapping( + self.data + .add((row * self.capacity_cols) as usize * self.dtype.bytes_per_item()), + result.data, + self.cols as usize * self.dtype.bytes_per_item(), + ); + } + result + } +} + +impl TensorBuilder { + pub fn load>(&self, data_dir: P) -> Result { + let data_dir: &Path = data_dir.as_ref(); + if self.stride < 1 { + return Err(TensorError::InvalidStride(self.stride)); + 
} + let tensor = unsafe { Tensor::uninitialized(self.rows, self.cols, self.dtype) }; + assert_eq!(self.dtype, TensorDType::Float16); + let path = data_dir.join(&self.src_path); + + let mut f = std::fs::File::open(&path).unwrap(); + let mut cursor: usize = 0; + let mut buf: Vec = vec![0; self.cols as usize * 2]; + for _row in 0..self.rows { + f.read_exact(&mut buf)?; + unsafe { + std::ptr::copy_nonoverlapping(buf.as_ptr(), tensor.data.add(cursor), buf.len()); + } + cursor = cursor + (tensor.capacity_cols as usize * 2); + } + Ok(tensor.to_f32()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_relative_eq; + + #[test] + fn mat_mul_transposed_agrees_with_regular_mat_mul() { + let mut rng = rand::thread_rng(); + for _ in 0..1000 { + let a = rng.gen_range(8..64); + let b = rng.gen_range(8..64); + let r = rng.gen_range(8..64); + + // Make matrixes AxR and RxB + let a = Tensor::random(a, r, TensorDType::Float32); + let b = Tensor::random(r, b, TensorDType::Float32); + let b_transposed = b.transpose(); + + let c = a.matrix_mul(&b); + let c2 = a.matrix_mul_transposed(&b_transposed); + + assert_eq!(c.rows, c2.rows); + assert_eq!(c.cols, c2.cols); + + for row in 0..c.rows { + for col in 0..c.cols { + assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5); + } + } + } + } + + #[test] + fn view_preserves_values() { + fn test_with_type(dtype: TensorDType) { + let mut rng = rand::thread_rng(); + + for _ in 0..1000 { + let mut a: i64 = 0; + let mut b: i64 = 0; + let mut c: i64 = 0; + let mut d: i64 = 0; + loop { + a = rng.gen_range(8..64); + b = rng.gen_range(8..64); + c = rng.gen_range(8..64); + if (a * b) % c != 0 { + continue; + } + d = (a * b) / c; + break; + } + + let tensor_left = Tensor::random(a, b, dtype); + let tensor_right = tensor_left.view(c, d); + + assert_eq!( + tensor_left.cols() * tensor_left.rows(), + tensor_right.cols() * tensor_right.rows() + ); + + let mut cursor: usize = 0; + let mut left_row: usize = 0; + let 
mut left_col: usize = 0; + let mut right_row: usize = 0; + let mut right_col: usize = 0; + + while cursor < tensor_left.cols() as usize * tensor_left.rows() as usize { + let left_value = tensor_left.get_f32(left_row as i64, left_col as i64); + let right_value = tensor_right.get_f32(right_row as i64, right_col as i64); + assert_eq!( + left_value, right_value, + "left: {:?}, right: {:?} dtype {:?}", + tensor_left, tensor_right, dtype + ); + left_col += 1; + if left_col == tensor_left.cols() as usize { + left_col = 0; + left_row += 1; + } + right_col += 1; + if right_col == tensor_right.cols() as usize { + right_col = 0; + right_row += 1; + } + cursor += 1; + } + } + } + test_with_type(TensorDType::Float32); + test_with_type(TensorDType::Float16); + } + + #[test] + fn mat_vector_mul_matches_naive_mat_mul() { + let mut rng = rand::thread_rng(); + for _ in 0..50 { + let r = rng.gen_range(1..100); + let r2 = rng.gen_range(1..100); + + let a = Tensor::random(r, r2, TensorDType::Float32); + let b = Tensor::random(r2, 1, TensorDType::Float32); + + let c = a.matrix_mul_naive(&b); + let c2 = a.matrix_vector_mul(&b); + + assert_eq!(c.rows(), c2.rows()); + assert_eq!(c.cols(), c2.cols()); + + for row in 0..c.rows { + for col in 0..c.cols { + assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5); + } + } + } + } + + #[test] + fn mat_vector_transposed_mul_matches_naive_mat_mul() { + let mut rng = rand::thread_rng(); + for _ in 0..50 { + let r = rng.gen_range(1..100); + let r2 = rng.gen_range(1..100); + + let a = Tensor::random(r, r2, TensorDType::Float32); + let b = Tensor::random(1, r2, TensorDType::Float32); + + let c = a.matrix_mul_naive(&b.transpose()); + let c2 = a.matrix_vector_mul_transposed(&b); + + assert_eq!(c.rows(), c2.rows()); + assert_eq!(c.cols(), c2.cols()); + + for row in 0..c.rows { + for col in 0..c.cols { + assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5); + } + } + } + } + + #[test] + fn 
naive_mat_mul_and_fast_are_same_f32_random_sizes() { + let mut rng = rand::thread_rng(); + for _ in 0..50 { + let left_rows = rng.gen_range(1..100); + let right_cols = rng.gen_range(1..100); + let shared_len = rng.gen_range(1..100); + + let a = Tensor::random(left_rows, shared_len, TensorDType::Float32); + let b = Tensor::random(shared_len, right_cols, TensorDType::Float32); + + let c = a.matrix_mul_naive(&b); + let c2 = a.matrix_mul(&b); + + for row in 0..c.rows { + for col in 0..c.cols { + assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5); + } + } + } + } + + #[test] + fn naive_mat_mul_and_fast_are_same_f32() { + for _ in 0..50 { + let a = Tensor::random(16, 32, TensorDType::Float32); + let b = Tensor::random(32, 16, TensorDType::Float32); + + let c = a.matrix_mul_naive(&b); + let c2 = a.matrix_mul(&b); + + for row in 0..c.rows { + for col in 0..c.cols { + assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5); + } + } + } + } + + #[test] + fn mat_mul_with_itself_is_correct_f32() { + for _ in 0..50 { + let a = Tensor::random(16, 16, TensorDType::Float32); + let c = a.matrix_mul_naive(&a); + let c2 = a.matrix_mul(&a); + + for row in 0..c.rows { + for col in 0..c.cols { + assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-5); + } + } + } + } + + #[test] + fn naive_mat_mul_and_fast_are_same_f16() { + for _ in 0..50 { + let a = Tensor::random(16, 32, TensorDType::Float16); + let b = Tensor::random(32, 16, TensorDType::Float16); + + let c = a.matrix_mul_naive(&b); + let c2 = a.matrix_mul(&b); + + for row in 0..c.rows { + for col in 0..c.cols { + assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-1); + } + } + } + } + + #[test] + fn mat_mul_with_itself_is_correct_f16() { + for _ in 0..50 { + let a = Tensor::random(16, 16, TensorDType::Float16); + let c = a.matrix_mul_naive(&a); + let c2 = a.matrix_mul(&a); + + for row in 0..c.rows { + for col in 0..c.cols { + 
assert_relative_eq!(c.get_f32(row, col), c2.get_f32(row, col), epsilon = 1e-1); + } + } + } + } + + #[test] + fn clip_cols_works() { + let mut rng = rand::thread_rng(); + for _ in 0..1000 { + let rows = rng.gen_range(1..100); + let cols = rng.gen_range(2..100); + let new_cols = rng.gen_range(1..=cols); + + let a = Tensor::random(rows, cols, TensorDType::Float32); + let a_clipped = a.clip_cols(new_cols as usize); + + assert_eq!(a.rows(), a_clipped.rows()); + assert_eq!(a_clipped.cols(), new_cols); + + for row in 0..a_clipped.rows { + for col in 0..a_clipped.cols { + assert_eq!(a.get_f32(row, col), a_clipped.get_f32(row, col)); + } + } + } + } +} diff --git a/src/token_sampler.rs b/src/token_sampler.rs new file mode 100644 index 0000000..36abc0f --- /dev/null +++ b/src/token_sampler.rs @@ -0,0 +1,85 @@ +use crate::tensor::Tensor; +use crate::tokenizer::TokenId; +use rand::Rng; + +pub struct TokenSampler { + temperature: f32, + top_p: f32, + top_k: usize, +} + +impl TokenSampler { + pub fn new() -> Self { + Self { + temperature: 0.8, + top_p: 1.0, + top_k: 1, // same as argmax + } + } + + pub fn get_temperature(&self) -> f32 { + self.temperature + } + + pub fn get_top_p(&self) -> f32 { + self.top_p + } + + pub fn get_top_k(&self) -> usize { + self.top_k + } + + pub fn temperature(self, temperature: f32) -> Self { + Self { + temperature, + ..self + } + } + + pub fn top_p(self, top_p: f32) -> Self { + Self { top_p, ..self } + } + + pub fn top_k(self, top_k: usize) -> Self { + Self { top_k, ..self } + } + + pub fn sample(&self, logits: &Tensor) -> TokenId { + let nrows = logits.rows(); + assert!(logits.cols() == 1); + let mut logits = logits.transpose(); + if self.temperature > 0.0 { + logits = logits.scalar_multiply_f32(1.0 / self.temperature); + logits = logits.softmax(); + } + + let mut logitsf: Vec<(TokenId, f32)> = Vec::with_capacity(nrows as usize); + for i in 0..nrows { + logitsf.push((i as TokenId, logits.get_f32(0, i))); + } + logitsf.sort_unstable_by(|a, b| 
b.1.partial_cmp(&a.1).unwrap()); + logitsf.truncate(self.top_k as usize); + let mut p_accum: f32 = 0.0; + for (idx, v) in logitsf.iter().enumerate() { + p_accum += v.1; + if p_accum >= self.top_p { + logitsf.truncate(idx + 1); + break; + } + } + let mut total_p: f32 = 0.0; + for v in logitsf.iter() { + total_p += v.1; + } + let mut rng = rand::thread_rng(); + let p: f32 = rng.gen_range(0.0..total_p); + p_accum = 0.0; + for v in logitsf.into_iter() { + p_accum += v.1; + if p_accum >= p { + return v.0; + } + } + 0 + } +} diff --git a/src/tokenizer.rs b/src/tokenizer.rs new file mode 100644 index 0000000..5b0f647 --- /dev/null +++ b/src/tokenizer.rs @@ -0,0 +1,156 @@ +use crate::protomodels::sentencepiece_model::model_proto::sentence_piece; +use crate::protomodels::sentencepiece_model::ModelProto; +use protobuf::Message; +use std::collections::BTreeMap; +use std::io::Read; +use std::path::Path; +use thiserror::Error; + +pub type TokenId = i32; + +#[derive(Clone, Debug)] +pub struct Tokenizer { + pieces: BTreeMap, +} + +#[derive(Clone, Debug, Copy, Eq, Ord, PartialEq, PartialOrd)] +pub enum PieceType { + Normal, + Unknown, + Control, + UserDefined, + Byte, + Unused, +} + +#[derive(Clone, Debug)] +pub struct Piece { + _tp: PieceType, + // piece: String this is in the BTreeMap that holds the pieces + _score: f32, + idx: usize, +} + +#[derive(Error, Debug)] +pub enum TokenizerError { + #[error("IO error")] + IoError(#[from] std::io::Error), + #[error("Protobuf error")] + ProtobufError(#[from] protobuf::Error), + #[error("Unknown piece type")] + UnknownPieceType(String), +} + +impl Tokenizer { + pub fn load>(path: P) -> Result { + let mut fs = std::fs::File::open(path)?; + let mut buffer = Vec::new(); + fs.read_to_end(&mut buffer)?; + std::mem::drop(fs); + let model = ModelProto::parse_from_bytes(&buffer)?; + + let mut pieces = BTreeMap::new(); + for (idx, piece) in model.pieces.iter().enumerate() { + let piece_str = piece.piece.clone(); + if piece_str.is_none() { + 
continue; + } + let piece_str = piece_str.unwrap(); + let piece_type = match piece.type_ { + None => sentence_piece::Type::NORMAL, + Some(v) => match v.enum_value() { + Err(_) => return Err(TokenizerError::UnknownPieceType(piece_str)), + Ok(v) => v, + }, + }; + + let score = piece.score.unwrap_or(0.0); + let tp = if piece_type == sentence_piece::Type::NORMAL { + PieceType::Normal + } else if piece_type == sentence_piece::Type::UNKNOWN { + PieceType::Unknown + } else if piece_type == sentence_piece::Type::CONTROL { + PieceType::Control + } else if piece_type == sentence_piece::Type::USER_DEFINED { + PieceType::UserDefined + } else if piece_type == sentence_piece::Type::BYTE { + PieceType::Byte + } else if piece_type == sentence_piece::Type::UNUSED { + PieceType::Unused + } else { + return Err(TokenizerError::UnknownPieceType(piece_str)); + }; + pieces.insert( + piece_str, + Piece { + _tp: tp, + _score: score, + idx, + }, + ); + } + + Ok(Tokenizer { pieces }) + } + + // Gives a string for a token id. + // Panics if the id is out of range. + pub fn id_to_str(&self, id: i32) -> &str { + let id = id as usize; + for (piece_str, piece_info) in self.pieces.iter() { + if piece_info.idx == id { + return piece_str; + } + } + panic!("id out of range"); + } + + // Converts a string to a Vec<&str> + // You may want to use tokenize_to_ids instead. + // + // This will not add start or end tokens; only the string is processed. + // + // I noticed LLaMa code adds an extra space character at the beginning of the string, this + // function does not do that either. 
+ pub fn tokenize_to_pieces>(&self, s: S) -> Vec<&str> { + let mut s: &str = s.as_ref(); + let mut result: Vec<&str> = Vec::new(); + + // Very naive matching + while !s.is_empty() { + let mut best_candidate: &str = ""; + let mut best_candidate_len: usize = 0; + let mut skip_s: &str = ""; + for (piece_str, _piece_info) in self.pieces.iter() { + if s.starts_with(piece_str) && best_candidate_len < piece_str.len() { + best_candidate = piece_str; + best_candidate_len = piece_str.len(); + skip_s = &s[piece_str.len()..]; + } + } + if best_candidate_len == 0 { + // Skip token. + s = s.get(1..).unwrap_or(""); + } else { + result.push(best_candidate); + s = skip_s; + } + } + result + } + + pub fn tokenize_to_ids>(&self, s: S) -> Vec { + let mut s: String = format!("▁{}", s.as_ref()); + // Replace all space characters with a special token. + s = s.replace(" ", "▁"); + + let pieces = self.tokenize_to_pieces(s); + let mut result = Vec::new(); + result.push(1); // start token + for piece in pieces { + let piece_info = self.pieces.get(piece).unwrap(); + result.push(piece_info.idx as i32); + } + result + } +} diff --git a/src/transformer.rs b/src/transformer.rs new file mode 100644 index 0000000..06c69c0 --- /dev/null +++ b/src/transformer.rs @@ -0,0 +1,546 @@ +use crate::embedding::Embedding; +use crate::tensor::{Tensor, TensorDType}; +use crate::tokenizer::TokenId; +use crate::unpickler; +use crate::unpickler::UnpicklingError; +use indicatif::ProgressBar; +use num_complex::Complex; +use rayon::prelude::*; +use std::path::Path; +use std::sync::{Arc, RwLock}; + +type FreqsCis = Vec>>; + +#[allow(dead_code)] +pub struct Transformer { + freqs_cis: FreqsCis, + emb: Embedding, + dim: usize, + n_layers: usize, + n_heads: usize, + n_local_heads: usize, + max_seq_len: usize, + head_dim: usize, + + norm: RMSNorm, + output: Tensor, + + layers: Vec, +} + +pub struct TransformerCaches { + layer_caches: Vec, +} + +pub struct TransformerBlock { + feed_forward: FeedForward, + attn: Attention, + 
ffn_norm: RMSNorm, + attention_norm: RMSNorm, +} + +pub struct AttentionCache { + cache_k: Vec>>, + cache_v: Vec>>, +} + +impl AttentionCache { + fn new(max_seq_len: usize, n_local_heads: usize, head_dim: usize) -> Self { + let mut cache_k = Vec::with_capacity(n_local_heads); + let mut cache_v = Vec::with_capacity(n_local_heads); + for _ in 0..n_local_heads { + cache_k.push(Arc::new(RwLock::new(Tensor::zeros( + head_dim as i64, + max_seq_len as i64, + TensorDType::Float32, + )))); + cache_v.push(Arc::new(RwLock::new(Tensor::zeros( + head_dim as i64, + max_seq_len as i64, + TensorDType::Float32, + )))); + } + AttentionCache { cache_k, cache_v } + } +} + +pub struct RMSNorm { + eps: f64, + weight: Tensor, +} + +pub struct Attention { + wq: Tensor, + wk: Tensor, + wv: Tensor, + wo: Tensor, + n_local_heads: usize, + head_dim: usize, +} + +pub struct FeedForward { + w1: Tensor, + w2: Tensor, + w3: Tensor, +} + +impl Transformer { + pub fn from_unpickled>( + unpickled: &unpickler::Value, + emb: Embedding, + dim: usize, + n_layers: usize, + n_heads: usize, + max_seq_len: usize, + eps: f64, + n_local_heads: usize, + head_dim: usize, + data_dir: P, + ) -> Result { + let data_dir: &Path = data_dir.as_ref(); + + let progress_bar = ProgressBar::new(n_layers as u64); + let layers: Vec = (0..n_layers) + .into_par_iter() + .map(|layer_id| { + let result = TransformerBlock::from_unpickled( + unpickled, + layer_id, + eps, + n_local_heads, + head_dim, + data_dir, + ); + progress_bar.inc(1); + result + }) + .collect::, UnpicklingError>>()?; + std::mem::drop(progress_bar); + + let norm = RMSNorm::from_unpickled(unpickled, format!("norm.weight"), eps, data_dir)?; + let output = + Tensor::from_unpickled(unpickled, format!("output.weight"), data_dir)?.to_f32(); + + Ok(Transformer { + freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len * 2, 10000.0), + emb, + dim, + n_layers, + n_heads, + n_local_heads, + max_seq_len, + head_dim, + + norm, + output, + + layers, + }) + } + + pub fn 
make_caches(&self) -> TransformerCaches { + let mut result = vec![]; + for _ in 0..self.n_layers { + result.push(AttentionCache::new( + self.max_seq_len, + self.n_local_heads, + self.head_dim, + )); + } + TransformerCaches { + layer_caches: result, + } + } + + pub fn forward( + &self, + tokens: &[TokenId], + start_pos: usize, + caches: &mut TransformerCaches, + ) -> Tensor { + assert!(caches.layer_caches.len() == self.n_layers); + let mask: Option = if tokens.len() > 1 { + Some(Tensor::full_triu( + tokens.len() as i64, + tokens.len() as i64, + start_pos as i64 + 1, + TensorDType::Float32, + std::f32::NEG_INFINITY, + )) + } else { + None + }; + let mut embs: Vec<&Tensor> = Vec::with_capacity(tokens.len()); + for token in tokens.iter() { + let emb = self.emb.get_embedding(*token as usize); + embs.push(emb); + } + let mut emb_tensor: Tensor = Tensor::concat(&embs); + std::mem::drop(embs); + + for (idx, layer) in self.layers.iter().enumerate() { + emb_tensor = layer.forward( + &emb_tensor, + start_pos, + &self.freqs_cis, + &mask, + &mut caches.layer_caches[idx], + ); + } + let out = self.norm.forward(&emb_tensor); + let out = out.row(out.rows() - 1); + let prediction = self.output.matrix_mul_transposed(&out); + return prediction; + } +} + +impl TransformerBlock { + pub fn from_unpickled>( + unpickled: &unpickler::Value, + layer_id: usize, + eps: f64, + n_local_heads: usize, + head_dim: usize, + data_dir: P, + ) -> Result { + let data_dir: &Path = data_dir.as_ref(); + let ff = FeedForward::from_unpickled(unpickled, layer_id, data_dir)?; + let attn = + Attention::from_unpickled(unpickled, layer_id, n_local_heads, head_dim, data_dir)?; + let ffn_norm = RMSNorm::from_unpickled( + unpickled, + format!("layers.{}.ffn_norm.weight", layer_id), + eps, + data_dir, + )?; + let attn_norm = RMSNorm::from_unpickled( + unpickled, + format!("layers.{}.attention_norm.weight", layer_id), + eps, + data_dir, + )?; + Ok(Self { + feed_forward: ff, + attn, + ffn_norm, + attention_norm: 
attn_norm, + }) + } + + pub fn forward( + &self, + x: &Tensor, + start_pos: usize, + freqs_cis: &FreqsCis, + mask: &Option, + attention_cache: &mut AttentionCache, + ) -> Tensor { + let attnorm_out = self.attention_norm.forward(x); + let att_out = self + .attn + .forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache); + let h = x.add(&att_out); + let att_out = self.ffn_norm.forward(&h); + let att_out = self.feed_forward.forward(&att_out.transpose()).transpose(); + let att_out = h.add(&att_out); + return att_out; + } +} + +impl RMSNorm { + pub fn from_unpickled>( + unpickled: &unpickler::Value, + name: String, + eps: f64, + data_dir: P, + ) -> Result { + let data_dir: &Path = data_dir.as_ref(); + let weights = Tensor::from_unpickled(unpickled, &name, data_dir)?.to_f32(); + Ok(Self { + eps, + weight: weights, + }) + } + + fn forward(&self, x: &Tensor) -> Tensor { + let inner = x.pow(2.0).mean_cols().add_scalar(self.eps as f32); + let out1 = x.scalar_multiply_broadcast(&inner.rsqrt()); + return out1.hadamard_product_broadcast(&self.weight); + } +} + +impl FeedForward { + pub fn from_unpickled>( + unpickled: &unpickler::Value, + layer_id: usize, + data_dir: P, + ) -> Result { + let data_dir: &Path = data_dir.as_ref(); + + let w1 = Tensor::from_unpickled( + unpickled, + format!("layers.{}.feed_forward.w1.weight", layer_id), + data_dir, + )? + .to_f32(); + let w2 = Tensor::from_unpickled( + unpickled, + format!("layers.{}.feed_forward.w2.weight", layer_id), + data_dir, + )? + .to_f32(); + let w3 = Tensor::from_unpickled( + unpickled, + format!("layers.{}.feed_forward.w3.weight", layer_id), + data_dir, + )? 
+ .to_f32(); + + Ok(Self { w1, w2, w3 }) + } + + pub fn forward(&self, x: &Tensor) -> Tensor { + let x = x.transpose(); + let (w1_out, w3_out) = rayon::join( + || self.w1.matrix_mul_transposed(&x), + || self.w3.matrix_mul_transposed(&x), + ); + let w1_out = w1_out.silu(); + let w1w3_out = w1_out.hadamard_product(&w3_out).transpose(); + let out = self.w2.matrix_mul_transposed(&w1w3_out); + return out; + } +} + +impl Attention { + pub fn from_unpickled>( + unpickled: &unpickler::Value, + layer_id: usize, + n_local_heads: usize, + head_dim: usize, + data_dir: P, + ) -> Result { + let data_dir: &Path = data_dir.as_ref(); + + let wq = Tensor::from_unpickled( + unpickled, + format!("layers.{}.attention.wq.weight", layer_id), + data_dir, + )? + .to_f32(); + let wk = Tensor::from_unpickled( + unpickled, + format!("layers.{}.attention.wk.weight", layer_id), + data_dir, + )? + .to_f32(); + let wv = Tensor::from_unpickled( + unpickled, + format!("layers.{}.attention.wv.weight", layer_id), + data_dir, + )? + .to_f32(); + let wo = Tensor::from_unpickled( + unpickled, + format!("layers.{}.attention.wo.weight", layer_id), + data_dir, + )? 
+ .to_f32(); + + Ok(Self { + wq, + wk, + wv, + wo, + n_local_heads, + head_dim, + }) + } + + fn forward( + &self, + x: &Tensor, + start_pos: usize, + freqs_cis: &FreqsCis, + mask: &Option, + attention_cache: &mut AttentionCache, + ) -> Tensor { + let seq_len = x.rows(); + let xq_out = x.matrix_mul_transposed(&self.wq); + let xk_out = x.matrix_mul_transposed(&self.wk); + let xv_out = x.matrix_mul_transposed(&self.wv); + + let mut xq_views: Vec = Vec::with_capacity(seq_len as usize); + let mut xk_views: Vec = Vec::with_capacity(seq_len as usize); + let mut xv_views: Vec = Vec::with_capacity(seq_len as usize); + + for idx in 0..seq_len { + let xq_row = xq_out + .row(idx) + .view(self.n_local_heads as i64, self.head_dim as i64); + let xk_row = xk_out + .row(idx) + .view(self.n_local_heads as i64, self.head_dim as i64); + let xv_row = xv_out + .row(idx) + .view(self.n_local_heads as i64, self.head_dim as i64); + + let (xq_row, xk_row) = + apply_rotary_emb(&xq_row, &xk_row, freqs_cis, idx as usize, start_pos); + + xq_views.push(xq_row); + xk_views.push(xk_row); + xv_views.push(xv_row); + } + + let output: Vec = (0..self.n_local_heads) + .into_par_iter() + .map(|idx| { + let mut concat_vec: Vec = vec![]; + for idx2 in 0..seq_len { + concat_vec.push(xq_views[idx2 as usize].row(idx as i64)); + } + let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect(); + let xq_row = Tensor::concat(&concat_vec2); + + concat_vec.truncate(0); + for idx2 in 0..seq_len { + concat_vec.push(xk_views[idx2 as usize].row(idx as i64)); + } + let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect(); + let xk_row = Tensor::concat(&concat_vec2).transpose(); + + concat_vec.truncate(0); + for idx2 in 0..seq_len { + concat_vec.push(xv_views[idx2 as usize].row(idx as i64)); + } + let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect(); + let xv_row = Tensor::concat(&concat_vec2); + + let mut cache_k = attention_cache.cache_k[idx as usize].write().unwrap(); + let mut cache_v = 
attention_cache.cache_v[idx as usize].write().unwrap(); + + /* + let m = xq_row + .matrix_mul(&xk_row) + .scalar_multiply_f32(1.0 / (self.head_dim as f32).sqrt()); + //println!("mask size: {} {}", mask.rows(), mask.cols()); + //println!("m size: {} {}", m.rows(), m.cols()); + let m2 = m.add(mask).to_f32().softmax().matrix_mul(&xv_row); + m2 + println!("xk_row size: {} {}", xk_row.rows(), xk_row.cols()); + println!("xv_row size: {} {}", xv_row.rows(), xv_row.cols()); + println!("cache_k size: {} {}", cache_k.rows(), cache_k.cols()); + panic!("stop"); + */ + + for pos in start_pos..start_pos + seq_len as usize { + for dim in 0..self.head_dim { + let k = xk_row.get_f32(dim as i64, (pos - start_pos) as i64); + cache_k.set_f32(dim as i64, pos as i64, k); + let v = xv_row.get_f32((pos - start_pos) as i64, dim as i64); + cache_v.set_f32(dim as i64, pos as i64, v); + } + } + let keys = cache_k.clip_cols((start_pos + seq_len as usize) as usize); + let values = cache_v.clip_cols((start_pos + seq_len as usize) as usize); + + let m = xq_row + .matrix_mul(&keys) + .scalar_multiply_f32(1.0 / (self.head_dim as f32).sqrt()); + let m2 = match mask { + Some(ref mask) => m + .add(mask) + .to_f32() + .softmax() + .matrix_mul_transposed(&values), + None => m.softmax().matrix_mul_transposed(&values), + }; + m2 + }) + .collect(); + + // convert from 32 matrices of size 8x128 to 8 matrices of size 32x128 + // or rather 4096x1 + let output2: Vec = (0..seq_len) + .into_par_iter() + .map(|idx| { + let mut concat_vec: Vec = vec![]; + for idx2 in 0..self.n_local_heads { + concat_vec.push(output[idx2 as usize].row(idx as i64)); + } + let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect(); + let xq_row = Tensor::concat(&concat_vec2).view(1, 4096); + let xq_row = xq_row.matrix_mul_transposed(&self.wo); + xq_row + }) + .collect(); + let output3: Vec<&Tensor> = output2.iter().collect(); + let output2: Tensor = Tensor::concat(&output3); + return output2; + } +} + +fn apply_rotary_emb( + xq: 
&Tensor, + xk: &Tensor, + freqs_cis: &FreqsCis, + seq_idx: usize, + start_pos: usize, +) -> (Tensor, Tensor) { + assert!(xq.cols() % 2 == 0); + assert!(xk.cols() % 2 == 0); + let mut xq_out: Tensor = xq.clone(); + let mut xk_out: Tensor = xk.clone(); + for row in 0..xq.rows() { + for col in 0..xq.cols() / 2 { + let f_real = freqs_cis[seq_idx + start_pos][col as usize].re as f32; + let f_imag = freqs_cis[seq_idx + start_pos][col as usize].im as f32; + let xq_real = xq.get_f32(row, col * 2); + let xq_imag = xq.get_f32(row, col * 2 + 1); + let xk_real = xk.get_f32(row, col * 2); + let xk_imag = xk.get_f32(row, col * 2 + 1); + + // multiply with freqs_cis + let xq_realpart = xq_real * f_real - xq_imag * f_imag; + let xq_imagpart = xq_real * f_imag + xq_imag * f_real; + let xk_realpart = xk_real * f_real - xk_imag * f_imag; + let xk_imagpart = xk_real * f_imag + xk_imag * f_real; + + xq_out.set_f32(row, col * 2, xq_realpart); + xq_out.set_f32(row, col * 2 + 1, xq_imagpart); + xk_out.set_f32(row, col * 2, xk_realpart); + xk_out.set_f32(row, col * 2 + 1, xk_imagpart); + } + } + return (xq_out, xk_out); +} + +fn compute_freqs_cis(dim: usize, end: usize, theta: f64) -> FreqsCis { + let mut freqs = Vec::new(); + for idx in 0..(dim / 2) { + let freq = 1.0 / (theta.powf(idx as f64 * 2.0 / dim as f64)); + freqs.push(freq); + } + + let mut result: Vec> = Vec::new(); + for x in 0..end { + let mut row = Vec::new(); + for y in 0..freqs.len() { + let freq = freqs[y] * (x as f64); + row.push(freq); + } + result.push(row); + } + + let mut resultc: Vec>> = Vec::new(); + for row in result.into_iter() { + let mut rowc = Vec::new(); + for freq in row { + let cis = Complex::from_polar(1.0, freq); + rowc.push(cis); + } + resultc.push(rowc); + } + resultc +} diff --git a/src/unpickler.rs b/src/unpickler.rs new file mode 100644 index 0000000..f4744c2 --- /dev/null +++ b/src/unpickler.rs @@ -0,0 +1,626 @@ +use std::collections::BTreeMap; +use std::path::PathBuf; + +pub struct Unpickler {} + 
+use crate::tensor::{TensorBuilder, TensorDType, TensorError}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum UnpicklingError { + #[error("Unpickling error: {0}")] + UnpicklingError(String), + #[error("UTF-8 decoding error")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("Missing field")] + MissingField(String), + #[error("Tensor conversion operation failed")] + TensorError(#[from] TensorError), + #[error("Data has incorrect format to be converted to a tensor")] + InvalidTensorData, +} + +#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] +pub enum Value { + Mark(usize), + String(String), + Global(String, String), // module name, attribute name + Integer64(i64), + Tuple(Vec), + PersistentId(Box), + Bool(bool), + Reduce(Box, Box), + Dict(BTreeMap), +} + +impl Value { + // Gets a value from a dictionary, assuming Value is a dictionary. + // + // Returns None if the key is not found, or the value is not a dictionary. + pub fn get(&self, key: &Value) -> Option<&Value> { + match self { + Value::Dict(d) => d.get(key), + _ => None, + } + } + + // Same as get() but uses a string as key. 
+ pub fn get_str_key>(&self, key: S) -> Option<&Value> { + self.get(&Value::String(key.as_ref().to_string())) + } + + pub fn get_global(&self) -> Option<(&str, &str)> { + match self { + Value::Global(module_name, attribute_name) => Some((module_name, attribute_name)), + _ => None, + } + } + + pub fn get_str(&self) -> Option<&str> { + match self { + Value::String(s) => Some(s), + _ => None, + } + } + + pub fn get_int64(&self) -> Option { + match self { + Value::Integer64(i) => Some(*i), + _ => None, + } + } + + pub fn get_persistent_id(&self) -> Option<&Value> { + match self { + Value::PersistentId(v) => Some(&v), + _ => None, + } + } + + pub fn get_tuple(&self) -> Option<&[Value]> { + match self { + Value::Tuple(v) => Some(&v), + _ => None, + } + } + + // Assume that the value represents a tensor in PyTorch and return instructions how to actually + // load the values. + pub fn to_tensor_builder(&self) -> Option { + match self { + Value::Reduce(call, args) => match **call { + Value::Global(ref module_name, ref attribute_name) => { + if module_name == "torch._utils" && attribute_name == "_rebuild_tensor_v2" { + match **args { + Value::Tuple(ref args) => self.to_tensor_builder2(&args), + _ => None, + } + } else { + None + } + } + _ => None, + }, + _ => None, + } + } + + fn to_tensor_builder2(&self, args: &[Value]) -> Option { + if args.len() == 6 { + Self::to_tensor_builder2_6items(args) + } else if args.len() == 4 { + Self::to_tensor_builder2_4items(args) + } else { + None + } + } + + fn to_tensor_builder2_4items(args: &[Value]) -> Option { + let storagev: &Value = args[0].get_persistent_id()?; + let storage_args: &[Value] = storagev.get_tuple()?; + let storage_mark: &str = storage_args[0].get_str()?; + if storage_mark != "storage" { + return None; + } + + let (storage_module, storage_type) = storage_args[1].get_global()?; + if storage_module != "torch" { + return None; + } + let dtype: TensorDType = match storage_type { + "HalfStorage" => TensorDType::Float16, + _ 
=> return None, + }; + let storage_filename: &str = storage_args[2].get_str()?; + let nitems: i64 = storage_args[4].get_int64()?; + + let offset: i64 = args[1].get_int64()?; + if offset != 0 { + return None; + } + + let rows: i64 = 1; + let cols: i64 = nitems; + let row_stride: i64 = cols; + if row_stride != cols { + return None; + } + + return Some(TensorBuilder { + src_path: PathBuf::from(storage_filename), + dtype, + stride: row_stride, + rows, + cols, + nitems, + }); + } + + fn to_tensor_builder2_6items(args: &[Value]) -> Option { + let storagev: &Value = args[0].get_persistent_id()?; + let storage_args: &[Value] = storagev.get_tuple()?; + let storage_mark: &str = storage_args[0].get_str()?; + if storage_mark != "storage" { + return None; + } + + let (storage_module, storage_type) = storage_args[1].get_global()?; + if storage_module != "torch" { + return None; + } + let dtype: TensorDType = match storage_type { + "HalfStorage" => TensorDType::Float16, + _ => return None, + }; + let storage_filename: &str = storage_args[2].get_str()?; + let nitems: i64 = storage_args[4].get_int64()?; + + let offset: i64 = args[1].get_int64()?; + if offset != 0 { + return None; + } + + let shape: &[Value] = args[2].get_tuple()?; + let stride: &[Value] = args[3].get_tuple()?; + + if shape.len() != 2 { + return None; + } + if stride.len() != 2 { + return None; + } + + let rows: i64 = shape[0].get_int64()?; + let cols: i64 = shape[1].get_int64()?; + + let row_stride: i64 = stride[0].get_int64()?; + let col_stride: i64 = stride[1].get_int64()?; + + if col_stride != 1 { + return None; + } + if row_stride != cols { + return None; + } + + return Some(TensorBuilder { + src_path: PathBuf::from(storage_filename), + dtype, + stride: row_stride, + rows, + cols, + nitems, + }); + + /* Args should look like this (took random example from debug print) : + 0 PERSISTENT_ID + TUPLE + STRING "storage" + GLOBAL "torch" "HalfStorage" + STRING "0" (filename) + STRING "cpu" + INTEGER 131072000 (number 
of items) + 1 INTEGER 0 + 2 TUPLE + INTEGER 32000 + INTEGER 4096 + 3 TUPLE + INTEGER 4096 + INTEGER 1 + 4 BOOL false (this is about gradient) + 5 REDUCE (no idea why this is here) + GLOBAL "collections" "OrderedDict" + TUPLE + + Sometimes arguments 2 and 3 are missing. + */ + } + + // Print a nice representation of the value to stdout. Used for good old printf debugging. + pub fn debug_print(&self) { + self.debug_print_go(0); + } + + fn debug_print_go(&self, indent: usize) { + if indent > 0 { + print!("{:indent$}", "", indent = indent); + } + match self { + Value::Mark(_) => { + println!("MARK"); + } + Value::String(s) => { + println!("STRING {:?}", s); + } + Value::Global(module_name, attribute_name) => { + println!("GLOBAL {:?} {:?}", module_name, attribute_name); + } + Value::Integer64(i) => { + println!("INTEGER {:?}", i); + } + Value::Tuple(v) => { + println!("TUPLE"); + for i in v { + i.debug_print_go(indent + 2); + } + } + Value::PersistentId(v) => { + println!("PERSISTENT_ID"); + v.debug_print_go(indent + 2); + } + Value::Bool(b) => { + println!("BOOL {:?}", b); + } + Value::Reduce(v1, v2) => { + println!("REDUCE"); + v1.debug_print_go(indent + 2); + v2.debug_print_go(indent + 2); + } + Value::Dict(d) => { + println!("DICT"); + for (k, v) in d { + k.debug_print_go(indent + 2); + v.debug_print_go(indent + 2); + } + } + } + } +} + +pub fn unpickle(bytes: &[u8]) -> Result { + // The LLaMA file is in pickle 2 format, check that header is there + if bytes.len() < 2 { + return Err(UnpicklingError::UnpicklingError( + "Data is too short to be a pickle".to_string(), + )); + } + + if bytes[0] != 128 || bytes[1] != 2 { + return Err(UnpicklingError::UnpicklingError( + "No magic header using Pickle 2 protocol".to_string(), + )); + } + + let mut memo: BTreeMap = BTreeMap::new(); + let mut stack: Vec = vec![]; + + // Decode frames + let mut bytes: &[u8] = &bytes[2..]; + while !bytes.is_empty() { + let frame_opcode = bytes[0]; + if frame_opcode == 125 { + // empty dict + 
stack.push(Value::Dict(BTreeMap::new())); + bytes = &bytes[1..]; + continue; + } + if frame_opcode == 113 { + // binput + if bytes.len() < 2 { + return Err(UnpicklingError::UnpicklingError( + "Unexpected end of data while handling BINPUT".to_string(), + )); + } + if stack.is_empty() { + return Err(UnpicklingError::UnpicklingError( + "Stack is empty while handling BINPUT".to_string(), + )); + } + let key = bytes[1]; + memo.insert(key as u32, stack.last().unwrap().clone()); + bytes = &bytes[2..]; + continue; + } + if frame_opcode == 40 { + // mark + stack.push(Value::Mark(stack.len())); + bytes = &bytes[1..]; + continue; + } + if frame_opcode == 88 { + // binunicode + if bytes.len() < 5 { + return Err(UnpicklingError::UnpicklingError( + "Unexpected end of data while handling BINUNICODE".to_string(), + )); + } + let len = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if bytes.len() < 5 + len as usize { + return Err(UnpicklingError::UnpicklingError( + "Unexpected end of data while handling BINUNICODE".to_string(), + )); + } + let string = std::str::from_utf8(&bytes[5..5 + len as usize])?; + stack.push(Value::String(string.to_string())); + bytes = &bytes[5 + len as usize..]; + continue; + } + if frame_opcode == 99 { + // global + // followed by newline terminated module name and attribute name + bytes = &bytes[1..]; + let mut module_name = String::new(); + while !bytes.is_empty() && bytes[0] != 10 { + module_name.push(bytes[0] as char); + bytes = &bytes[1..]; + if bytes.is_empty() { + return Err(UnpicklingError::UnpicklingError( + "Unexpected end of data while handling GLOBAL".to_string(), + )); + } + } + bytes = &bytes[1..]; + let mut attribute_name = String::new(); + while !bytes.is_empty() && bytes[0] != 10 { + attribute_name.push(bytes[0] as char); + bytes = &bytes[1..]; + if bytes.is_empty() { + return Err(UnpicklingError::UnpicklingError( + "Unexpected end of data while handling GLOBAL".to_string(), + )); + } + } + bytes = &bytes[1..]; + 
stack.push(Value::Global(module_name, attribute_name)); + continue; + } + if frame_opcode == 74 { + // binint + if bytes.len() < 5 { + return Err(UnpicklingError::UnpicklingError( + "Unexpected end of data while handling BININT".to_string(), + )); + } + let value = i32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + stack.push(Value::Integer64(value as i64)); + bytes = &bytes[5..]; + continue; + } + if frame_opcode == 116 { + // tuple + let mut tuple = vec![]; + if stack.is_empty() { + return Err(UnpicklingError::UnpicklingError( + "Stack is empty while handling TUPLE".to_string(), + )); + } + let mut ok = false; + while !stack.is_empty() { + let top = stack.pop().unwrap(); + if let Value::Mark(_mark) = top { + tuple.reverse(); + stack.push(Value::Tuple(tuple)); + ok = true; + break; + } + tuple.push(top); + } + if !ok { + return Err(UnpicklingError::UnpicklingError( + "No mark while handling TUPLE".to_string(), + )); + } + bytes = &bytes[1..]; + continue; + } + if frame_opcode == 81 { + // binpersid + if stack.is_empty() { + return Err(UnpicklingError::UnpicklingError( + "Stack is empty while handling BINPERSID".to_string(), + )); + } + let top = stack.pop().unwrap(); + stack.push(Value::PersistentId(Box::new(top))); + bytes = &bytes[1..]; + continue; + } + if frame_opcode == 75 { + // binint1 + if bytes.len() < 2 { + return Err(UnpicklingError::UnpicklingError( + "Unexpected end of data while handling BININT1".to_string(), + )); + } + let value = bytes[1]; + stack.push(Value::Integer64(value as i64)); + bytes = &bytes[2..]; + continue; + } + if frame_opcode == 77 { + // binint2 + if bytes.len() < 3 { + return Err(UnpicklingError::UnpicklingError( + "Unexpected end of data while handling BININT2".to_string(), + )); + } + let value = i16::from_le_bytes([bytes[1], bytes[2]]); + stack.push(Value::Integer64(value as i64)); + bytes = &bytes[3..]; + continue; + } + if frame_opcode == 134 { + // tuple2 + let mut tuple = vec![]; + if stack.len() < 2 { + return 
Err(UnpicklingError::UnpicklingError( + "Stack does not have enough items while handling TUPLE2".to_string(), + )); + } + tuple.push(stack.pop().unwrap()); + tuple.push(stack.pop().unwrap()); + tuple.reverse(); + stack.push(Value::Tuple(tuple)); + bytes = &bytes[1..]; + continue; + } + if frame_opcode == 137 { + // newfalse + stack.push(Value::Bool(false)); + bytes = &bytes[1..]; + continue; + } + if frame_opcode == 41 { + // empty tuple + stack.push(Value::Tuple(vec![])); + bytes = &bytes[1..]; + continue; + } + if frame_opcode == 82 { + // reduce + if stack.len() < 2 { + return Err(UnpicklingError::UnpicklingError( + "Stack does not have enough items while handling REDUCE".to_string(), + )); + } + let arg_tuple = stack.pop().unwrap(); + let callable = stack.pop().unwrap(); + stack.push(Value::Reduce(Box::new(callable), Box::new(arg_tuple))); + bytes = &bytes[1..]; + continue; + } + if frame_opcode == 104 { + // binget + if bytes.len() < 2 { + return Err(UnpicklingError::UnpicklingError( + "Unexpected end of data while handling BINGET".to_string(), + )); + } + let idx = bytes[1]; + match memo.get(&(idx as u32)) { + None => { + return Err(UnpicklingError::UnpicklingError( + "BINGET index out of range".to_string(), + )); + } + Some(memo_value) => { + stack.push(memo_value.clone()); + } + } + bytes = &bytes[2..]; + continue; + } + if frame_opcode == 133 { + // tuple1 + let mut tuple = vec![]; + if stack.is_empty() { + return Err(UnpicklingError::UnpicklingError( + "Stack is empty while handling TUPLE1".to_string(), + )); + } + tuple.push(stack.pop().unwrap()); + bytes = &bytes[1..]; + continue; + } + if frame_opcode == 114 { + // long binput + if bytes.len() < 5 { + return Err(UnpicklingError::UnpicklingError( + "Unexpected end of data while handling LONG_BINPUT".to_string(), + )); + } + let key = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if stack.is_empty() { + return Err(UnpicklingError::UnpicklingError( + "Stack is empty while handling 
LONG_BINPUT".to_string(), + )); + } + memo.insert(key as u32, stack.last().unwrap().clone()); + bytes = &bytes[5..]; + continue; + } + if frame_opcode == 117 { + // setitems + if stack.is_empty() { + return Err(UnpicklingError::UnpicklingError( + "Stack is empty while handling SETITEMS".to_string(), + )); + } + let mut ok = false; + let mut keyvalues: BTreeMap = BTreeMap::new(); + while !stack.is_empty() { + let value = stack.pop().unwrap(); + if let Value::Mark(_mark) = value { + ok = true; + break; + } + if stack.is_empty() { + return Err(UnpicklingError::UnpicklingError( + "Stack is empty while handling SETITEMS".to_string(), + )); + } + let key = stack.pop().unwrap(); + if let Value::Mark(_mark) = key { + return Err(UnpicklingError::UnpicklingError( + "Unexpected mark while handling SETITEMS".to_string(), + )); + } + keyvalues.insert(key, value); + } + if !ok { + return Err(UnpicklingError::UnpicklingError( + "No mark while handling SETITEMS".to_string(), + )); + } + if stack.is_empty() { + return Err(UnpicklingError::UnpicklingError( + "Stack is empty while handling SETITEMS".to_string(), + )); + } + let mut dict = stack.pop().unwrap(); + match dict { + Value::Dict(ref mut dict) => { + for (key, value) in keyvalues { + dict.insert(key, value); + } + } + _ => { + return Err(UnpicklingError::UnpicklingError( + "SETITEMS on non-dict".to_string(), + )); + } + } + stack.push(dict); + bytes = &bytes[1..]; + continue; + } + if frame_opcode == 46 { + // stop + // bytes = &bytes[1..]; + break; + } + return Err(UnpicklingError::UnpicklingError(format!( + "Unknown opcode: {}", + frame_opcode + ))); + } + + // Stack should have just one item, our final value + if stack.len() != 1 { + return Err(UnpicklingError::UnpicklingError( + "Stack does not have exactly one item after unpickling".to_string(), + )); + } + + Ok(stack.pop().unwrap()) +}