First commit. LLaMA works now. It is not pretty but it does generate text from prompts. Yay.
commit
3b8f904f13
@ -0,0 +1 @@
|
||||
/target
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,39 @@
|
||||
[package]
|
||||
name = "rllama"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "rllama"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
protobuf = "3.2"
|
||||
thiserror = "1.0"
|
||||
half = "2.2"
|
||||
num-complex = "0.4"
|
||||
embedded-profiling = "0.3"
|
||||
rand = "0.8"
|
||||
approx = "0.5"
|
||||
rayon = "1.7"
|
||||
clap = { version = "4.1", features = ["derive"] }
|
||||
indicatif = "0.17"
|
||||
|
||||
# We need protobuf compiler
|
||||
[build-dependencies]
|
||||
protobuf-codegen = "3.2"
|
||||
protobuf-parse = "3.2"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.4"
|
||||
|
||||
[profile.release]
|
||||
debug = true
|
||||
|
||||
[[bench]]
|
||||
path = "src/benches/benchmark.rs"
|
||||
name = "benchmark"
|
||||
harness = false
|
||||
@ -0,0 +1,662 @@
|
||||
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
@ -0,0 +1,208 @@
|
||||
proto/ directory contains a protobuf file from Google's
|
||||
https://github.com/google/sentencepiece repository.
|
||||
|
||||
Here is their license: (note rllama as a whole is AGPL3)
|
||||
-----
|
||||
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@ -0,0 +1,21 @@
|
||||
# AdeonLLaMA
|
||||
|
||||
This is my attempt at making the LLaMA language model working on a pure Rust
|
||||
CPU implementation.
|
||||
|
||||
As of writing of this, it can run LLaMA-7B at around ~1 token per second, using
|
||||
something like 1.5 threads because I haven't yet properly figured out how to
|
||||
multithread this.
|
||||
|
||||
It uses AVX2 intrinsics to speed up itself.
|
||||
|
||||
# How to run
|
||||
|
||||
You will need the LLaMA-7B weights first. Refer to https://github.com/facebookresearch/llama/
|
||||
|
||||
Once you have 7B weights, and the `tokenizer.model` it comes with, you can make
|
||||
it generate tokens:
|
||||
|
||||
```shell
|
||||
cargo run --release -- --tokenizer-model /path/to/tokenizer.model --model-path /path/to/LLaMA/7B
|
||||
```
|
||||
@ -0,0 +1,9 @@
|
||||
fn main() {
|
||||
protobuf_codegen::Codegen::new()
|
||||
.pure()
|
||||
.out_dir("src/protomodels")
|
||||
.include("proto")
|
||||
.input("proto/sentencepiece_model.proto")
|
||||
.run()
|
||||
.unwrap();
|
||||
}
|
||||
@ -0,0 +1,321 @@
|
||||
// Copyright 2016 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
// TODO(taku): Needs to use LITE RUNTIME in OSS release.
|
||||
option optimize_for = LITE_RUNTIME;
|
||||
|
||||
package sentencepiece;
|
||||
|
||||
// TrainerSpec encodes a various parameters for SentencePiece training.
|
||||
// Next id: 53
|
||||
message TrainerSpec {
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// General parameters
|
||||
//
|
||||
// Input corpus files.
|
||||
// Trainer accepts the following two formats:
|
||||
// A) Monolingual: plain text, one sentence per line.
|
||||
// B) Bilingual: TSV, source sentence <tab> target sentence
|
||||
// When bilingual data is passed, shared vocabulary model is built.
|
||||
// Note that the input file must be raw corpus, not a preprocessed corpus.
|
||||
// Trainer only loads the first `input_sentence_size` sentences specified
|
||||
// with this parameter.
|
||||
repeated string input = 1;
|
||||
|
||||
// Input corpus format:
|
||||
// "text": one-sentence-per-line text format (default)
|
||||
// "tsv": sentence <tab> freq
|
||||
optional string input_format = 7;
|
||||
|
||||
// Output model file prefix.
|
||||
// <model_prefix>.model and <model_prefix>.vocab are generated.
|
||||
optional string model_prefix = 2;
|
||||
|
||||
// Model type. only have UNIGRAM now.
|
||||
enum ModelType {
|
||||
UNIGRAM = 1; // Unigram language model with dynamic algorithm
|
||||
BPE = 2; // Byte Pair Encoding
|
||||
WORD = 3; // Delimitered by whitespace.
|
||||
CHAR = 4; // tokenizes into character sequence
|
||||
}
|
||||
optional ModelType model_type = 3 [default = UNIGRAM];
|
||||
|
||||
// Vocabulary size. 8k is the default size.
|
||||
optional int32 vocab_size = 4 [default = 8000];
|
||||
|
||||
// List of the languages this model can accept.
|
||||
// Since the model is language-agnostic, this field is used as a reference.
|
||||
repeated string accept_language = 5;
|
||||
|
||||
// Size of self-test samples, which are encoded in the model file.
|
||||
optional int32 self_test_sample_size = 6 [default = 0];
|
||||
|
||||
// Whether to use DP version of sentencepiece. Use it with TSV input format
|
||||
// (requires precomputed word tab counts to work).
|
||||
optional bool enable_differential_privacy = 50 [default = false];
|
||||
// Set these parameters if you need DP version of sentencepiece.
|
||||
// std of noise to add.
|
||||
optional float differential_privacy_noise_level = 51 [default = 0.0];
|
||||
// Clipping threshold to apply after adding noise. All the words with
|
||||
// frequency less than this value are dropped.
|
||||
optional uint64 differential_privacy_clipping_threshold = 52 [default = 0];
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// Training parameters.
|
||||
//
|
||||
// Uses characters which cover the corpus with the ratio of `chars_coverage`.
|
||||
// This parameter determines the set of basic Alphabet of sentence piece.
|
||||
// 1.0 - `chars_coverage` characters are treated as UNK.
|
||||
// See also required_chars field.
|
||||
optional float character_coverage = 10 [default = 0.9995];
|
||||
|
||||
// Maximum size of sentences the trainer loads from `input` parameter.
|
||||
// Trainer simply loads the `input` files in sequence.
|
||||
// It is better to shuffle the input corpus randomly.
|
||||
optional uint64 input_sentence_size = 11 [default = 0];
|
||||
optional bool shuffle_input_sentence = 19 [default = true];
|
||||
|
||||
// Maximum size of sentences to make seed sentence pieces.
|
||||
// Extended suffix array is constructed to extract frequent
|
||||
// sub-strings from the corpus. This uses 20N working space,
|
||||
// where N is the size of corpus.
|
||||
optional int32 mining_sentence_size = 12 [deprecated = true];
|
||||
|
||||
// Maximum size of sentences to train sentence pieces.
|
||||
optional int32 training_sentence_size = 13 [deprecated = true];
|
||||
|
||||
// The size of seed sentencepieces.
|
||||
// `seed_sentencepiece_size` must be larger than `vocab_size`.
|
||||
optional int32 seed_sentencepiece_size = 14 [default = 1000000];
|
||||
|
||||
// In every EM sub-iterations, keeps top
|
||||
// `shrinking_factor` * `current sentencepieces size` with respect to
|
||||
// the loss of the sentence piece. This value should be smaller than 1.0.
|
||||
optional float shrinking_factor = 15 [default = 0.75];
|
||||
|
||||
// The maximum sentence length in byte. The sentences with the length
|
||||
// larger than `max_sentence_length` is simply ignored.
|
||||
// Longer input tends to bring the following risks:
|
||||
// * Overflow during EM training (unigram language model only)
|
||||
// * Performance drop because of O(n log n) cost in BPE.
|
||||
optional int32 max_sentence_length = 18 [default = 4192];
|
||||
|
||||
// Number of threads in the training.
|
||||
optional int32 num_threads = 16 [default = 16];
|
||||
|
||||
// Number of EM sub iterations.
|
||||
optional int32 num_sub_iterations = 17 [default = 2];
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// SentencePiece parameters which control the shapes of sentence piece.
|
||||
//
|
||||
// Maximum length of sentencepiece.
|
||||
optional int32 max_sentencepiece_length = 20 [default = 16];
|
||||
|
||||
// Uses Unicode script to split sentence pieces.
|
||||
// When `split_by_unicode_script` is true, we do not allow sentence piece to
|
||||
// include multiple Unicode scripts, e.g. "F1" is not a valid piece.
|
||||
// Exception: CJ characters (Hiragana/Katakana/Han) are all handled
|
||||
// as one script type, since Japanese word can consist of multiple scripts.
|
||||
// This exception is always applied regardless of the accept-language
|
||||
// parameter.
|
||||
optional bool split_by_unicode_script = 21 [default = true];
|
||||
|
||||
// When `split_by_number` is true, put a boundary between number and
|
||||
// non-number transition. If we want to treat "F1" is one token, set this flag
|
||||
// to be false.
|
||||
optional bool split_by_number = 23 [default = true];
|
||||
|
||||
// Use a white space to split sentence pieces.
|
||||
// When `split_by_whitespace` is false, we may have the piece containing
|
||||
// a white space in the middle. e.g., "in_the".
|
||||
optional bool split_by_whitespace = 22 [default = true];
|
||||
|
||||
// Adds whitespace symbol (_) as a suffix instead of prefix. e.g., _hello =>
|
||||
// hello_. When `treat_whitespace_as_suffix` is true,
|
||||
// NormalizerSpec::add_dummy_prefix will add the dummy whitespace to the end
|
||||
// of sentence.
|
||||
optional bool treat_whitespace_as_suffix = 24 [default = false];
|
||||
|
||||
// Allows pieces that only contain whitespaces instead of appearing only as
|
||||
// prefix or suffix of other pieces.
|
||||
optional bool allow_whitespace_only_pieces = 26 [default = false];
|
||||
|
||||
// Split all digits (0-9) into separate pieces.
|
||||
optional bool split_digits = 25 [default = false];
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// Vocabulary management
|
||||
//
|
||||
// Defines control symbols used as an indicator to
|
||||
// change the behavior of the decoder. <s> and </s> are pre-defined.
|
||||
// We can use this field to encode various meta information,
|
||||
// including language indicator in multilingual model.
|
||||
// These symbols are not visible to users, but visible to
|
||||
// the decoder. Note that when the input sentence contains control symbols,
|
||||
// they are not treated as one token, but segmented into normal pieces.
|
||||
// Control symbols must be inserted independently from the segmentation.
|
||||
repeated string control_symbols = 30;
|
||||
|
||||
// Defines user defined symbols.
|
||||
// These symbols are added with extremely high score
|
||||
// so they are always treated as one unique symbol in any context.
|
||||
// Typical usage of user_defined_symbols is placeholder for named entities.
|
||||
repeated string user_defined_symbols = 31;
|
||||
|
||||
// Defines required characters. Each UTF8 character in this string is included
|
||||
// in the character set regardless of character_coverage value. Unlike
|
||||
// user_defined_symbols, these characters have scores based on the frequency
|
||||
// on input sentences, and the model can form subwords using characters
|
||||
// in this field.
|
||||
optional string required_chars = 36;
|
||||
|
||||
// Decomposes unknown pieces into UTF-8 bytes.
|
||||
optional bool byte_fallback = 35 [default = false];
|
||||
|
||||
// When creating the vocabulary file, defines whether or not to additionally
|
||||
// output the score for each piece.
|
||||
optional bool vocabulary_output_piece_score = 32 [default = true];
|
||||
|
||||
// `vocab_size` is treated as hard limit. Crash if
|
||||
// the model can not produce the vocab of size `vocab_size`,
|
||||
// When `hard_vocab_limit` is false, vocab_size is treated
|
||||
// as soft limit. Note that when model_type=char,
|
||||
// always assumes hard_vocab_limit = false.
|
||||
optional bool hard_vocab_limit = 33 [default = true];
|
||||
|
||||
// use all symbols for vocab extraction. This flag is valid
|
||||
// if model type is either CHAR or WORD
|
||||
optional bool use_all_vocab = 34 [default = false];
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// Reserved special meta tokens.
|
||||
// * -1 is not used.
|
||||
// * unk_id must not be -1.
|
||||
// Id must starts with 0 and be contigous.
|
||||
optional int32 unk_id = 40 [default = 0]; // <unk>
|
||||
optional int32 bos_id = 41 [default = 1]; // <s>
|
||||
optional int32 eos_id = 42 [default = 2]; // </s>
|
||||
optional int32 pad_id = 43 [default = -1]; // <pad> (padding)
|
||||
optional string unk_piece = 45 [default = "<unk>"];
|
||||
optional string bos_piece = 46 [default = "<s>"];
|
||||
optional string eos_piece = 47 [default = "</s>"];
|
||||
optional string pad_piece = 48 [default = "<pad>"];
|
||||
|
||||
// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
|
||||
// since this character can be useful both for user and
|
||||
// developer. We can easily figure out that <unk> is emitted.
|
||||
optional string unk_surface = 44 [default = " \xE2\x81\x87 "];
|
||||
|
||||
// Increase bit depth to allow unigram model training on large
|
||||
// (>10M sentences) corpora. A Side-effect of enabling this flag
|
||||
// is increased memory usage.
|
||||
optional bool train_extremely_large_corpus = 49 [default = false];
|
||||
|
||||
// Customized extensions: the range of field numbers
|
||||
// are open to third-party extensions.
|
||||
extensions 200 to max;
|
||||
}
|
||||
|
||||
// NormalizerSpec encodes a various parameters for string normalizaiton
|
||||
message NormalizerSpec {
|
||||
// name of normalization rule.
|
||||
optional string name = 1;
|
||||
|
||||
// Pre-compiled normalization rule created by
|
||||
// Builder::GetPrecompiledCharsMap() or Builder::CompileCharsMap() method.
|
||||
// Usually this field is set by Builder::GetNormalizerSpec() method.
|
||||
optional bytes precompiled_charsmap = 2;
|
||||
|
||||
// Adds dummy whitespace at the beginning of text in order to
|
||||
// treat "world" in "world" and "hello world" in the same way.
|
||||
optional bool add_dummy_prefix = 3 [default = true];
|
||||
|
||||
// Removes leading, trailing, and duplicate internal whitespace.
|
||||
optional bool remove_extra_whitespaces = 4 [default = true];
|
||||
|
||||
// Replaces whitespace with meta symbol.
|
||||
// This field must be true to train sentence piece model.
|
||||
optional bool escape_whitespaces = 5 [default = true];
|
||||
|
||||
// Custom normalization rule file in TSV format.
|
||||
// https://github.com/google/sentencepiece/blob/master/doc/normalization.md
|
||||
// This field is only used in SentencePieceTrainer::Train() method, which
|
||||
// compiles the rule into the binary rule stored in `precompiled_charsmap`.
|
||||
optional string normalization_rule_tsv = 6;
|
||||
|
||||
// Customized extensions: the range of field numbers
|
||||
// are open to third-party extensions.
|
||||
extensions 200 to max;
|
||||
}
|
||||
|
||||
// Proto to store samples for self-testing.
|
||||
message SelfTestData {
|
||||
message Sample {
|
||||
optional string input = 1;
|
||||
optional string expected = 2;
|
||||
}
|
||||
repeated Sample samples = 1;
|
||||
|
||||
// Customized extensions: the range of field numbers
|
||||
// are open to third-party extensions.
|
||||
extensions 200 to max;
|
||||
}
|
||||
|
||||
// ModelProto stores model parameters.
|
||||
// SentencePieceProcessor is supposed to be self-contained.
|
||||
// All settings/parameters which may change the behavior must be encoded
|
||||
// in ModelProto.
|
||||
message ModelProto {
|
||||
message SentencePiece {
|
||||
enum Type {
|
||||
NORMAL = 1; // normal symbol
|
||||
UNKNOWN = 2; // unknown symbol. only <unk> for now.
|
||||
CONTROL = 3; // control symbols. </s>, <s>, <2ja> etc.
|
||||
USER_DEFINED = 4; // user defined symbols.
|
||||
// Typical usage of USER_DEFINED symbol
|
||||
// is placeholder.
|
||||
BYTE = 6; // byte symbols. Used when `byte_fallback` is true.
|
||||
UNUSED = 5; // this piece is not used.
|
||||
}
|
||||
optional string piece = 1; // piece must not be empty.
|
||||
optional float score = 2;
|
||||
optional Type type = 3 [default = NORMAL];
|
||||
|
||||
// Customized extensions: the range of field numbers
|
||||
// are open to third-party extensions.
|
||||
extensions 200 to max;
|
||||
}
|
||||
|
||||
// Sentence pieces with scores.
|
||||
repeated SentencePiece pieces = 1;
|
||||
|
||||
// Spec used to generate this model file.
|
||||
optional TrainerSpec trainer_spec = 2;
|
||||
|
||||
// Spec for text normalization.
|
||||
optional NormalizerSpec normalizer_spec = 3;
|
||||
|
||||
// Stores sample input and its expected segmentation to verify the model.
|
||||
optional SelfTestData self_test_data = 4;
|
||||
|
||||
// Spec for text de-normalization.
|
||||
optional NormalizerSpec denormalizer_spec = 5;
|
||||
|
||||
// Customized extensions: the range of field numbers
|
||||
// are open to third-party extensions.
|
||||
extensions 200 to max;
|
||||
}
|
||||
@ -0,0 +1,85 @@
|
||||
extern crate rllama;
|
||||
|
||||
use rllama::tensor::{Tensor, TensorDType};
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
|
||||
pub fn tensor_benchmarks(c: &mut Criterion) {
|
||||
let orig16_1 = Tensor::full(16, 32, TensorDType::Float16, 3.0);
|
||||
let orig16_2 = Tensor::full(32, 512, TensorDType::Float16, -1.33);
|
||||
|
||||
let orig32_1 = Tensor::full(16, 32, TensorDType::Float32, 3.0);
|
||||
let orig32_2 = Tensor::full(32, 512, TensorDType::Float32, -1.33);
|
||||
let orig32_2_transposed = orig32_2.transpose();
|
||||
|
||||
let mut result_16 = Tensor::zeros(16, 512, TensorDType::Float16);
|
||||
let mut result_32 = Tensor::zeros(16, 512, TensorDType::Float32);
|
||||
|
||||
let orig_84096_1 = Tensor::zeros(8, 4096, TensorDType::Float32);
|
||||
let orig_84096_2 = Tensor::zeros(4096, 4096, TensorDType::Float32);
|
||||
let mut result_84096 = Tensor::zeros(8, 4096, TensorDType::Float32);
|
||||
|
||||
c.bench_function(
|
||||
"matrix multiplication 8x4096 @ 4096x4096 f32 in-place",
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
let _ = result_84096
|
||||
.matrix_mul_inplace(black_box(&orig_84096_1), black_box(&orig_84096_2));
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
c.bench_function(
|
||||
"matrix multiplication 8x4096 @ 4096x4096 f32 in-place, transposed",
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
let _ = result_84096.matrix_mul_inplace_transposed(
|
||||
black_box(&orig_84096_1),
|
||||
black_box(&orig_84096_2),
|
||||
);
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
c.bench_function("matrix multiplication f32 not in-place", |b| {
|
||||
b.iter(|| {
|
||||
let _ = black_box(&orig32_1).matrix_mul(black_box(&orig32_2));
|
||||
})
|
||||
});
|
||||
c.bench_function("matrix multiplication f32 naive", |b| {
|
||||
b.iter(|| {
|
||||
let _ = black_box(&orig32_1).matrix_mul_naive(black_box(&orig32_2));
|
||||
})
|
||||
});
|
||||
c.bench_function("matrix multiplication f16 not in-place", |b| {
|
||||
b.iter(|| {
|
||||
let _ = black_box(&orig16_1).matrix_mul(black_box(&orig16_2));
|
||||
})
|
||||
});
|
||||
c.bench_function("matrix multiplication f16 naive", |b| {
|
||||
b.iter(|| {
|
||||
let _ = black_box(&orig16_1).matrix_mul_naive(black_box(&orig16_2));
|
||||
})
|
||||
});
|
||||
c.bench_function("matrix multiplication f16 in-place", |b| {
|
||||
b.iter(|| {
|
||||
let _ = result_16.matrix_mul_inplace(black_box(&orig16_1), black_box(&orig16_2));
|
||||
})
|
||||
});
|
||||
c.bench_function("matrix multiplication f32 in-place", |b| {
|
||||
b.iter(|| {
|
||||
let _ = result_32.matrix_mul_inplace(black_box(&orig32_1), black_box(&orig32_2));
|
||||
})
|
||||
});
|
||||
c.bench_function("matrix multiplication f32 in-place, transposed", |b| {
|
||||
b.iter(|| {
|
||||
let _ = result_32.matrix_mul_inplace_transposed(
|
||||
black_box(&orig32_1),
|
||||
black_box(&orig32_2_transposed),
|
||||
);
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(benches, tensor_benchmarks);
|
||||
criterion_main!(benches);
|
||||
@ -0,0 +1,45 @@
|
||||
use crate::tensor::Tensor;
|
||||
use crate::unpickler;
|
||||
use crate::unpickler::*;
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::Path;
|
||||
|
||||
pub struct Embedding {
|
||||
wgts: BTreeMap<usize, Tensor>,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
pub fn from_unpickled<P: AsRef<Path>>(
|
||||
unpickled: &unpickler::Value,
|
||||
data_dir: P,
|
||||
) -> Result<Self, UnpicklingError> {
|
||||
let data_dir: &Path = data_dir.as_ref();
|
||||
|
||||
let val = match unpickled.get_str_key("tok_embeddings.weight") {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
return Err(UnpicklingError::MissingField(
|
||||
"tok_embeddings.weight".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
let tensor = val
|
||||
.to_tensor_builder()
|
||||
.ok_or(UnpicklingError::InvalidTensorData)?;
|
||||
let tensor = tensor.load(data_dir)?;
|
||||
|
||||
let num_embeddings = tensor.rows();
|
||||
|
||||
let mut table: BTreeMap<usize, Tensor> = BTreeMap::new();
|
||||
for key in 0..num_embeddings {
|
||||
let row = tensor.row(key);
|
||||
table.insert(key as usize, row);
|
||||
}
|
||||
|
||||
Ok(Self { wgts: table })
|
||||
}
|
||||
|
||||
pub fn get_embedding(&self, idx: usize) -> &Tensor {
|
||||
self.wgts.get(&idx).unwrap()
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,10 @@
|
||||
#![feature(stdsimd)]
|
||||
|
||||
pub mod embedding;
|
||||
pub mod protomodels;
|
||||
pub mod rllama_main;
|
||||
pub mod tensor;
|
||||
pub mod token_sampler;
|
||||
pub mod tokenizer;
|
||||
pub mod transformer;
|
||||
pub mod unpickler;
|
||||
@ -0,0 +1,3 @@
|
||||
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
rllama::rllama_main::main()
|
||||
}
|
||||
@ -0,0 +1,3 @@
|
||||
// @generated
|
||||
|
||||
pub mod sentencepiece_model;
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,102 @@
|
||||
use crate::embedding::Embedding;
|
||||
use crate::token_sampler::TokenSampler;
|
||||
use crate::tokenizer::{TokenId, Tokenizer};
|
||||
use crate::transformer::Transformer;
|
||||
use crate::unpickler;
|
||||
use clap::Parser;
|
||||
use std::io::Read;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Cli {
|
||||
#[arg(long)]
|
||||
model_path: String,
|
||||
#[arg(long)]
|
||||
tokenizer_path: String,
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
#[arg(long)]
|
||||
temperature: Option<f32>,
|
||||
#[arg(long)]
|
||||
top_p: Option<f32>,
|
||||
#[arg(long)]
|
||||
top_k: Option<i32>,
|
||||
}
|
||||
|
||||
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let cli = Cli::parse();
|
||||
let model_path = cli.model_path;
|
||||
let tokenizer_path = cli.tokenizer_path;
|
||||
let prompt = cli.prompt;
|
||||
|
||||
println!("Starting up. Loading tokenizer from {}...", tokenizer_path);
|
||||
let tok = Tokenizer::load(tokenizer_path.as_str())?;
|
||||
println!("Tokenizer loeaded. Loading model from {}...", model_path);
|
||||
let mut fs = std::fs::File::open(model_path.as_str())?;
|
||||
let mut bs = Vec::new();
|
||||
fs.read_to_end(&mut bs)?;
|
||||
std::mem::drop(fs);
|
||||
|
||||
// We chop off file name from model_path and append "data/"
|
||||
let model_data_dir = model_path
|
||||
.split("/")
|
||||
.take(model_path.split("/").count() - 1)
|
||||
.collect::<Vec<&str>>()
|
||||
.join("/")
|
||||
+ "/data/";
|
||||
let result = unpickler::unpickle(&bs)?;
|
||||
println!("Loading embeddings from {}...", model_data_dir);
|
||||
let emb = Embedding::from_unpickled(&result, model_data_dir.clone())?;
|
||||
|
||||
println!("Loading transformer weights from {}...", model_data_dir);
|
||||
let tr = Transformer::from_unpickled(
|
||||
&result,
|
||||
emb,
|
||||
4096,
|
||||
32,
|
||||
32,
|
||||
512,
|
||||
1e-6,
|
||||
32,
|
||||
128,
|
||||
model_data_dir,
|
||||
)?;
|
||||
println!("All is loaded. Starting inference.");
|
||||
|
||||
let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt);
|
||||
let mut prev_pos = 0;
|
||||
let mut token_sampler = TokenSampler::new().temperature(0.8).top_p(0.9).top_k(50);
|
||||
|
||||
if let Some(temperature) = cli.temperature {
|
||||
token_sampler = token_sampler.temperature(temperature);
|
||||
}
|
||||
if let Some(top_p) = cli.top_p {
|
||||
token_sampler = token_sampler.top_p(top_p);
|
||||
}
|
||||
if let Some(top_k) = cli.top_k {
|
||||
token_sampler = token_sampler.top_k(top_k as usize);
|
||||
}
|
||||
|
||||
println!("Temperature: {}", token_sampler.get_temperature());
|
||||
println!("Top P: {}", token_sampler.get_top_p());
|
||||
println!("Top K: {}", token_sampler.get_top_k());
|
||||
|
||||
let mut caches = tr.make_caches();
|
||||
loop {
|
||||
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);
|
||||
let highest_pred_idx = token_sampler.sample(&preds);
|
||||
toks_id.push(highest_pred_idx as TokenId);
|
||||
prev_pos = toks_id.len() - 1;
|
||||
|
||||
let mut tok_str: String = "".to_string();
|
||||
for tok_id in toks_id.iter() {
|
||||
if *tok_id == 1 {
|
||||
continue;
|
||||
}
|
||||
let tok = tok.id_to_str(*tok_id);
|
||||
tok_str = tok_str + tok.replace("▁", " ").as_str();
|
||||
}
|
||||
println!("{}", tok_str);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,85 @@
|
||||
use crate::tensor::Tensor;
|
||||
use crate::tokenizer::TokenId;
|
||||
use rand::Rng;
|
||||
|
||||
pub struct TokenSampler {
|
||||
temperature: f32,
|
||||
top_p: f32,
|
||||
top_k: usize,
|
||||
}
|
||||
|
||||
impl TokenSampler {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
temperature: 0.8,
|
||||
top_p: 1.0,
|
||||
top_k: 1, // same as argmax
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_temperature(&self) -> f32 {
|
||||
self.temperature
|
||||
}
|
||||
|
||||
pub fn get_top_p(&self) -> f32 {
|
||||
self.top_p
|
||||
}
|
||||
|
||||
pub fn get_top_k(&self) -> usize {
|
||||
self.top_k
|
||||
}
|
||||
|
||||
pub fn temperature(self, temperature: f32) -> Self {
|
||||
Self {
|
||||
temperature,
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn top_p(self, top_p: f32) -> Self {
|
||||
Self { top_p, ..self }
|
||||
}
|
||||
|
||||
pub fn top_k(self, top_k: usize) -> Self {
|
||||
Self { top_k, ..self }
|
||||
}
|
||||
|
||||
pub fn sample(&self, logits: &Tensor) -> TokenId {
|
||||
let nrows = logits.rows();
|
||||
assert!(logits.cols() == 1);
|
||||
let mut logits = logits.transpose();
|
||||
if self.temperature > 0.0 {
|
||||
logits = logits.scalar_multiply_f32(1.0 / self.temperature);
|
||||
logits = logits.softmax();
|
||||
}
|
||||
|
||||
let mut logitsf: Vec<(TokenId, f32)> = Vec::with_capacity(nrows as usize);
|
||||
for i in 0..nrows {
|
||||
logitsf.push((i as TokenId, logits.get_f32(0, i)));
|
||||
}
|
||||
logitsf.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
logitsf.truncate(self.top_k as usize);
|
||||
let mut p_accum: f32 = 0.0;
|
||||
for (idx, v) in logitsf.iter().enumerate() {
|
||||
p_accum += v.1;
|
||||
if p_accum >= self.top_p {
|
||||
logitsf.truncate(idx + 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
let mut total_p: f32 = 0.0;
|
||||
for v in logitsf.iter() {
|
||||
total_p += v.1;
|
||||
}
|
||||
let mut rng = rand::thread_rng();
|
||||
let p: f32 = rng.gen_range(0.0..total_p);
|
||||
p_accum = 0.0;
|
||||
for v in logitsf.into_iter() {
|
||||
p_accum += v.1;
|
||||
if p_accum >= p {
|
||||
return v.0;
|
||||
}
|
||||
}
|
||||
0
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,156 @@
|
||||
use crate::protomodels::sentencepiece_model::model_proto::sentence_piece;
|
||||
use crate::protomodels::sentencepiece_model::ModelProto;
|
||||
use protobuf::Message;
|
||||
use std::collections::BTreeMap;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use thiserror::Error;
|
||||
|
||||
pub type TokenId = i32;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Tokenizer {
|
||||
pieces: BTreeMap<String, Piece>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, Eq, Ord, PartialEq, PartialOrd)]
|
||||
pub enum PieceType {
|
||||
Normal,
|
||||
Unknown,
|
||||
Control,
|
||||
UserDefined,
|
||||
Byte,
|
||||
Unused,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Piece {
|
||||
_tp: PieceType,
|
||||
// piece: String this is in the BTreeMap that holds the pieces
|
||||
_score: f32,
|
||||
idx: usize,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum TokenizerError {
|
||||
#[error("IO error")]
|
||||
IoError(#[from] std::io::Error),
|
||||
#[error("Protobuf error")]
|
||||
ProtobufError(#[from] protobuf::Error),
|
||||
#[error("Unknown piece type")]
|
||||
UnknownPieceType(String),
|
||||
}
|
||||
|
||||
impl Tokenizer {
|
||||
pub fn load<P: AsRef<Path>>(path: P) -> Result<Tokenizer, TokenizerError> {
|
||||
let mut fs = std::fs::File::open(path)?;
|
||||
let mut buffer = Vec::new();
|
||||
fs.read_to_end(&mut buffer)?;
|
||||
std::mem::drop(fs);
|
||||
let model = ModelProto::parse_from_bytes(&buffer)?;
|
||||
|
||||
let mut pieces = BTreeMap::new();
|
||||
for (idx, piece) in model.pieces.iter().enumerate() {
|
||||
let piece_str = piece.piece.clone();
|
||||
if piece_str.is_none() {
|
||||
continue;
|
||||
}
|
||||
let piece_str = piece_str.unwrap();
|
||||
let piece_type = match piece.type_ {
|
||||
None => sentence_piece::Type::NORMAL,
|
||||
Some(v) => match v.enum_value() {
|
||||
Err(_) => return Err(TokenizerError::UnknownPieceType(piece_str)),
|
||||
Ok(v) => v,
|
||||
},
|
||||
};
|
||||
|
||||
let score = piece.score.unwrap_or(0.0);
|
||||
let tp = if piece_type == sentence_piece::Type::NORMAL {
|
||||
PieceType::Normal
|
||||
} else if piece_type == sentence_piece::Type::UNKNOWN {
|
||||
PieceType::Unknown
|
||||
} else if piece_type == sentence_piece::Type::CONTROL {
|
||||
PieceType::Control
|
||||
} else if piece_type == sentence_piece::Type::USER_DEFINED {
|
||||
PieceType::UserDefined
|
||||
} else if piece_type == sentence_piece::Type::BYTE {
|
||||
PieceType::Byte
|
||||
} else if piece_type == sentence_piece::Type::UNUSED {
|
||||
PieceType::Unused
|
||||
} else {
|
||||
return Err(TokenizerError::UnknownPieceType(piece_str));
|
||||
};
|
||||
pieces.insert(
|
||||
piece_str,
|
||||
Piece {
|
||||
_tp: tp,
|
||||
_score: score,
|
||||
idx,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Tokenizer { pieces })
|
||||
}
|
||||
|
||||
// Gives a string for a token id.
|
||||
// Panics if the id is out of range.
|
||||
pub fn id_to_str(&self, id: i32) -> &str {
|
||||
let id = id as usize;
|
||||
for (piece_str, piece_info) in self.pieces.iter() {
|
||||
if piece_info.idx == id {
|
||||
return piece_str;
|
||||
}
|
||||
}
|
||||
panic!("id out of range");
|
||||
}
|
||||
|
||||
// Converts a string to a Vec<&str>
|
||||
// You may want to use tokenize_to_ids instead.
|
||||
//
|
||||
// This will not add start or end tokens; only the string is processed.
|
||||
//
|
||||
// I noticed LLaMa code adds an extra space character at the beginning of the string, this
|
||||
// function does not do that either.
|
||||
pub fn tokenize_to_pieces<S: AsRef<str>>(&self, s: S) -> Vec<&str> {
|
||||
let mut s: &str = s.as_ref();
|
||||
let mut result: Vec<&str> = Vec::new();
|
||||
|
||||
// Very naive matching
|
||||
while !s.is_empty() {
|
||||
let mut best_candidate: &str = "";
|
||||
let mut best_candidate_len: usize = 0;
|
||||
let mut skip_s: &str = "";
|
||||
for (piece_str, _piece_info) in self.pieces.iter() {
|
||||
if s.starts_with(piece_str) && best_candidate_len < piece_str.len() {
|
||||
best_candidate = piece_str;
|
||||
best_candidate_len = piece_str.len();
|
||||
skip_s = &s[piece_str.len()..];
|
||||
}
|
||||
}
|
||||
if best_candidate_len == 0 {
|
||||
// Skip token.
|
||||
s = s.get(1..).unwrap_or("");
|
||||
} else {
|
||||
result.push(best_candidate);
|
||||
s = skip_s;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn tokenize_to_ids<S: AsRef<str>>(&self, s: S) -> Vec<TokenId> {
|
||||
let mut s: String = format!("▁{}", s.as_ref());
|
||||
// Replace all space characters with a special token.
|
||||
s = s.replace(" ", "▁");
|
||||
|
||||
let pieces = self.tokenize_to_pieces(s);
|
||||
let mut result = Vec::new();
|
||||
result.push(1); // start token
|
||||
for piece in pieces {
|
||||
let piece_info = self.pieces.get(piece).unwrap();
|
||||
result.push(piece_info.idx as i32);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,546 @@
|
||||
use crate::embedding::Embedding;
|
||||
use crate::tensor::{Tensor, TensorDType};
|
||||
use crate::tokenizer::TokenId;
|
||||
use crate::unpickler;
|
||||
use crate::unpickler::UnpicklingError;
|
||||
use indicatif::ProgressBar;
|
||||
use num_complex::Complex;
|
||||
use rayon::prelude::*;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
type FreqsCis = Vec<Vec<Complex<f64>>>;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct Transformer {
|
||||
freqs_cis: FreqsCis,
|
||||
emb: Embedding,
|
||||
dim: usize,
|
||||
n_layers: usize,
|
||||
n_heads: usize,
|
||||
n_local_heads: usize,
|
||||
max_seq_len: usize,
|
||||
head_dim: usize,
|
||||
|
||||
norm: RMSNorm,
|
||||
output: Tensor,
|
||||
|
||||
layers: Vec<TransformerBlock>,
|
||||
}
|
||||
|
||||
pub struct TransformerCaches {
|
||||
layer_caches: Vec<AttentionCache>,
|
||||
}
|
||||
|
||||
pub struct TransformerBlock {
|
||||
feed_forward: FeedForward,
|
||||
attn: Attention,
|
||||
ffn_norm: RMSNorm,
|
||||
attention_norm: RMSNorm,
|
||||
}
|
||||
|
||||
pub struct AttentionCache {
|
||||
cache_k: Vec<Arc<RwLock<Tensor>>>,
|
||||
cache_v: Vec<Arc<RwLock<Tensor>>>,
|
||||
}
|
||||
|
||||
impl AttentionCache {
|
||||
fn new(max_seq_len: usize, n_local_heads: usize, head_dim: usize) -> Self {
|
||||
let mut cache_k = Vec::with_capacity(n_local_heads);
|
||||
let mut cache_v = Vec::with_capacity(n_local_heads);
|
||||
for _ in 0..n_local_heads {
|
||||
cache_k.push(Arc::new(RwLock::new(Tensor::zeros(
|
||||
head_dim as i64,
|
||||
max_seq_len as i64,
|
||||
TensorDType::Float32,
|
||||
))));
|
||||
cache_v.push(Arc::new(RwLock::new(Tensor::zeros(
|
||||
head_dim as i64,
|
||||
max_seq_len as i64,
|
||||
TensorDType::Float32,
|
||||
))));
|
||||
}
|
||||
AttentionCache { cache_k, cache_v }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RMSNorm {
|
||||
eps: f64,
|
||||
weight: Tensor,
|
||||
}
|
||||
|
||||
pub struct Attention {
|
||||
wq: Tensor,
|
||||
wk: Tensor,
|
||||
wv: Tensor,
|
||||
wo: Tensor,
|
||||
n_local_heads: usize,
|
||||
head_dim: usize,
|
||||
}
|
||||
|
||||
pub struct FeedForward {
|
||||
w1: Tensor,
|
||||
w2: Tensor,
|
||||
w3: Tensor,
|
||||
}
|
||||
|
||||
impl Transformer {
|
||||
pub fn from_unpickled<P: AsRef<Path>>(
|
||||
unpickled: &unpickler::Value,
|
||||
emb: Embedding,
|
||||
dim: usize,
|
||||
n_layers: usize,
|
||||
n_heads: usize,
|
||||
max_seq_len: usize,
|
||||
eps: f64,
|
||||
n_local_heads: usize,
|
||||
head_dim: usize,
|
||||
data_dir: P,
|
||||
) -> Result<Transformer, UnpicklingError> {
|
||||
let data_dir: &Path = data_dir.as_ref();
|
||||
|
||||
let progress_bar = ProgressBar::new(n_layers as u64);
|
||||
let layers: Vec<TransformerBlock> = (0..n_layers)
|
||||
.into_par_iter()
|
||||
.map(|layer_id| {
|
||||
let result = TransformerBlock::from_unpickled(
|
||||
unpickled,
|
||||
layer_id,
|
||||
eps,
|
||||
n_local_heads,
|
||||
head_dim,
|
||||
data_dir,
|
||||
);
|
||||
progress_bar.inc(1);
|
||||
result
|
||||
})
|
||||
.collect::<Result<Vec<TransformerBlock>, UnpicklingError>>()?;
|
||||
std::mem::drop(progress_bar);
|
||||
|
||||
let norm = RMSNorm::from_unpickled(unpickled, format!("norm.weight"), eps, data_dir)?;
|
||||
let output =
|
||||
Tensor::from_unpickled(unpickled, format!("output.weight"), data_dir)?.to_f32();
|
||||
|
||||
Ok(Transformer {
|
||||
freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len * 2, 10000.0),
|
||||
emb,
|
||||
dim,
|
||||
n_layers,
|
||||
n_heads,
|
||||
n_local_heads,
|
||||
max_seq_len,
|
||||
head_dim,
|
||||
|
||||
norm,
|
||||
output,
|
||||
|
||||
layers,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn make_caches(&self) -> TransformerCaches {
|
||||
let mut result = vec![];
|
||||
for _ in 0..self.n_layers {
|
||||
result.push(AttentionCache::new(
|
||||
self.max_seq_len,
|
||||
self.n_local_heads,
|
||||
self.head_dim,
|
||||
));
|
||||
}
|
||||
TransformerCaches {
|
||||
layer_caches: result,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
tokens: &[TokenId],
|
||||
start_pos: usize,
|
||||
caches: &mut TransformerCaches,
|
||||
) -> Tensor {
|
||||
assert!(caches.layer_caches.len() == self.n_layers);
|
||||
let mask: Option<Tensor> = if tokens.len() > 1 {
|
||||
Some(Tensor::full_triu(
|
||||
tokens.len() as i64,
|
||||
tokens.len() as i64,
|
||||
start_pos as i64 + 1,
|
||||
TensorDType::Float32,
|
||||
std::f32::NEG_INFINITY,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut embs: Vec<&Tensor> = Vec::with_capacity(tokens.len());
|
||||
for token in tokens.iter() {
|
||||
let emb = self.emb.get_embedding(*token as usize);
|
||||
embs.push(emb);
|
||||
}
|
||||
let mut emb_tensor: Tensor = Tensor::concat(&embs);
|
||||
std::mem::drop(embs);
|
||||
|
||||
for (idx, layer) in self.layers.iter().enumerate() {
|
||||
emb_tensor = layer.forward(
|
||||
&emb_tensor,
|
||||
start_pos,
|
||||
&self.freqs_cis,
|
||||
&mask,
|
||||
&mut caches.layer_caches[idx],
|
||||
);
|
||||
}
|
||||
let out = self.norm.forward(&emb_tensor);
|
||||
let out = out.row(out.rows() - 1);
|
||||
let prediction = self.output.matrix_mul_transposed(&out);
|
||||
return prediction;
|
||||
}
|
||||
}
|
||||
|
||||
impl TransformerBlock {
|
||||
pub fn from_unpickled<P: AsRef<Path>>(
|
||||
unpickled: &unpickler::Value,
|
||||
layer_id: usize,
|
||||
eps: f64,
|
||||
n_local_heads: usize,
|
||||
head_dim: usize,
|
||||
data_dir: P,
|
||||
) -> Result<Self, UnpicklingError> {
|
||||
let data_dir: &Path = data_dir.as_ref();
|
||||
let ff = FeedForward::from_unpickled(unpickled, layer_id, data_dir)?;
|
||||
let attn =
|
||||
Attention::from_unpickled(unpickled, layer_id, n_local_heads, head_dim, data_dir)?;
|
||||
let ffn_norm = RMSNorm::from_unpickled(
|
||||
unpickled,
|
||||
format!("layers.{}.ffn_norm.weight", layer_id),
|
||||
eps,
|
||||
data_dir,
|
||||
)?;
|
||||
let attn_norm = RMSNorm::from_unpickled(
|
||||
unpickled,
|
||||
format!("layers.{}.attention_norm.weight", layer_id),
|
||||
eps,
|
||||
data_dir,
|
||||
)?;
|
||||
Ok(Self {
|
||||
feed_forward: ff,
|
||||
attn,
|
||||
ffn_norm,
|
||||
attention_norm: attn_norm,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
start_pos: usize,
|
||||
freqs_cis: &FreqsCis,
|
||||
mask: &Option<Tensor>,
|
||||
attention_cache: &mut AttentionCache,
|
||||
) -> Tensor {
|
||||
let attnorm_out = self.attention_norm.forward(x);
|
||||
let att_out = self
|
||||
.attn
|
||||
.forward(&attnorm_out, start_pos, freqs_cis, mask, attention_cache);
|
||||
let h = x.add(&att_out);
|
||||
let att_out = self.ffn_norm.forward(&h);
|
||||
let att_out = self.feed_forward.forward(&att_out.transpose()).transpose();
|
||||
let att_out = h.add(&att_out);
|
||||
return att_out;
|
||||
}
|
||||
}
|
||||
|
||||
impl RMSNorm {
|
||||
pub fn from_unpickled<P: AsRef<Path>>(
|
||||
unpickled: &unpickler::Value,
|
||||
name: String,
|
||||
eps: f64,
|
||||
data_dir: P,
|
||||
) -> Result<RMSNorm, UnpicklingError> {
|
||||
let data_dir: &Path = data_dir.as_ref();
|
||||
let weights = Tensor::from_unpickled(unpickled, &name, data_dir)?.to_f32();
|
||||
Ok(Self {
|
||||
eps,
|
||||
weight: weights,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Tensor {
|
||||
let inner = x.pow(2.0).mean_cols().add_scalar(self.eps as f32);
|
||||
let out1 = x.scalar_multiply_broadcast(&inner.rsqrt());
|
||||
return out1.hadamard_product_broadcast(&self.weight);
|
||||
}
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
pub fn from_unpickled<P: AsRef<Path>>(
|
||||
unpickled: &unpickler::Value,
|
||||
layer_id: usize,
|
||||
data_dir: P,
|
||||
) -> Result<FeedForward, UnpicklingError> {
|
||||
let data_dir: &Path = data_dir.as_ref();
|
||||
|
||||
let w1 = Tensor::from_unpickled(
|
||||
unpickled,
|
||||
format!("layers.{}.feed_forward.w1.weight", layer_id),
|
||||
data_dir,
|
||||
)?
|
||||
.to_f32();
|
||||
let w2 = Tensor::from_unpickled(
|
||||
unpickled,
|
||||
format!("layers.{}.feed_forward.w2.weight", layer_id),
|
||||
data_dir,
|
||||
)?
|
||||
.to_f32();
|
||||
let w3 = Tensor::from_unpickled(
|
||||
unpickled,
|
||||
format!("layers.{}.feed_forward.w3.weight", layer_id),
|
||||
data_dir,
|
||||
)?
|
||||
.to_f32();
|
||||
|
||||
Ok(Self { w1, w2, w3 })
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Tensor {
|
||||
let x = x.transpose();
|
||||
let (w1_out, w3_out) = rayon::join(
|
||||
|| self.w1.matrix_mul_transposed(&x),
|
||||
|| self.w3.matrix_mul_transposed(&x),
|
||||
);
|
||||
let w1_out = w1_out.silu();
|
||||
let w1w3_out = w1_out.hadamard_product(&w3_out).transpose();
|
||||
let out = self.w2.matrix_mul_transposed(&w1w3_out);
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
pub fn from_unpickled<P: AsRef<Path>>(
|
||||
unpickled: &unpickler::Value,
|
||||
layer_id: usize,
|
||||
n_local_heads: usize,
|
||||
head_dim: usize,
|
||||
data_dir: P,
|
||||
) -> Result<Attention, UnpicklingError> {
|
||||
let data_dir: &Path = data_dir.as_ref();
|
||||
|
||||
let wq = Tensor::from_unpickled(
|
||||
unpickled,
|
||||
format!("layers.{}.attention.wq.weight", layer_id),
|
||||
data_dir,
|
||||
)?
|
||||
.to_f32();
|
||||
let wk = Tensor::from_unpickled(
|
||||
unpickled,
|
||||
format!("layers.{}.attention.wk.weight", layer_id),
|
||||
data_dir,
|
||||
)?
|
||||
.to_f32();
|
||||
let wv = Tensor::from_unpickled(
|
||||
unpickled,
|
||||
format!("layers.{}.attention.wv.weight", layer_id),
|
||||
data_dir,
|
||||
)?
|
||||
.to_f32();
|
||||
let wo = Tensor::from_unpickled(
|
||||
unpickled,
|
||||
format!("layers.{}.attention.wo.weight", layer_id),
|
||||
data_dir,
|
||||
)?
|
||||
.to_f32();
|
||||
|
||||
Ok(Self {
|
||||
wq,
|
||||
wk,
|
||||
wv,
|
||||
wo,
|
||||
n_local_heads,
|
||||
head_dim,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
start_pos: usize,
|
||||
freqs_cis: &FreqsCis,
|
||||
mask: &Option<Tensor>,
|
||||
attention_cache: &mut AttentionCache,
|
||||
) -> Tensor {
|
||||
let seq_len = x.rows();
|
||||
let xq_out = x.matrix_mul_transposed(&self.wq);
|
||||
let xk_out = x.matrix_mul_transposed(&self.wk);
|
||||
let xv_out = x.matrix_mul_transposed(&self.wv);
|
||||
|
||||
let mut xq_views: Vec<Tensor> = Vec::with_capacity(seq_len as usize);
|
||||
let mut xk_views: Vec<Tensor> = Vec::with_capacity(seq_len as usize);
|
||||
let mut xv_views: Vec<Tensor> = Vec::with_capacity(seq_len as usize);
|
||||
|
||||
for idx in 0..seq_len {
|
||||
let xq_row = xq_out
|
||||
.row(idx)
|
||||
.view(self.n_local_heads as i64, self.head_dim as i64);
|
||||
let xk_row = xk_out
|
||||
.row(idx)
|
||||
.view(self.n_local_heads as i64, self.head_dim as i64);
|
||||
let xv_row = xv_out
|
||||
.row(idx)
|
||||
.view(self.n_local_heads as i64, self.head_dim as i64);
|
||||
|
||||
let (xq_row, xk_row) =
|
||||
apply_rotary_emb(&xq_row, &xk_row, freqs_cis, idx as usize, start_pos);
|
||||
|
||||
xq_views.push(xq_row);
|
||||
xk_views.push(xk_row);
|
||||
xv_views.push(xv_row);
|
||||
}
|
||||
|
||||
let output: Vec<Tensor> = (0..self.n_local_heads)
|
||||
.into_par_iter()
|
||||
.map(|idx| {
|
||||
let mut concat_vec: Vec<Tensor> = vec![];
|
||||
for idx2 in 0..seq_len {
|
||||
concat_vec.push(xq_views[idx2 as usize].row(idx as i64));
|
||||
}
|
||||
let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect();
|
||||
let xq_row = Tensor::concat(&concat_vec2);
|
||||
|
||||
concat_vec.truncate(0);
|
||||
for idx2 in 0..seq_len {
|
||||
concat_vec.push(xk_views[idx2 as usize].row(idx as i64));
|
||||
}
|
||||
let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect();
|
||||
let xk_row = Tensor::concat(&concat_vec2).transpose();
|
||||
|
||||
concat_vec.truncate(0);
|
||||
for idx2 in 0..seq_len {
|
||||
concat_vec.push(xv_views[idx2 as usize].row(idx as i64));
|
||||
}
|
||||
let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect();
|
||||
let xv_row = Tensor::concat(&concat_vec2);
|
||||
|
||||
let mut cache_k = attention_cache.cache_k[idx as usize].write().unwrap();
|
||||
let mut cache_v = attention_cache.cache_v[idx as usize].write().unwrap();
|
||||
|
||||
/*
|
||||
let m = xq_row
|
||||
.matrix_mul(&xk_row)
|
||||
.scalar_multiply_f32(1.0 / (self.head_dim as f32).sqrt());
|
||||
//println!("mask size: {} {}", mask.rows(), mask.cols());
|
||||
//println!("m size: {} {}", m.rows(), m.cols());
|
||||
let m2 = m.add(mask).to_f32().softmax().matrix_mul(&xv_row);
|
||||
m2
|
||||
println!("xk_row size: {} {}", xk_row.rows(), xk_row.cols());
|
||||
println!("xv_row size: {} {}", xv_row.rows(), xv_row.cols());
|
||||
println!("cache_k size: {} {}", cache_k.rows(), cache_k.cols());
|
||||
panic!("stop");
|
||||
*/
|
||||
|
||||
for pos in start_pos..start_pos + seq_len as usize {
|
||||
for dim in 0..self.head_dim {
|
||||
let k = xk_row.get_f32(dim as i64, (pos - start_pos) as i64);
|
||||
cache_k.set_f32(dim as i64, pos as i64, k);
|
||||
let v = xv_row.get_f32((pos - start_pos) as i64, dim as i64);
|
||||
cache_v.set_f32(dim as i64, pos as i64, v);
|
||||
}
|
||||
}
|
||||
let keys = cache_k.clip_cols((start_pos + seq_len as usize) as usize);
|
||||
let values = cache_v.clip_cols((start_pos + seq_len as usize) as usize);
|
||||
|
||||
let m = xq_row
|
||||
.matrix_mul(&keys)
|
||||
.scalar_multiply_f32(1.0 / (self.head_dim as f32).sqrt());
|
||||
let m2 = match mask {
|
||||
Some(ref mask) => m
|
||||
.add(mask)
|
||||
.to_f32()
|
||||
.softmax()
|
||||
.matrix_mul_transposed(&values),
|
||||
None => m.softmax().matrix_mul_transposed(&values),
|
||||
};
|
||||
m2
|
||||
})
|
||||
.collect();
|
||||
|
||||
// convert from 32 matrices of size 8x128 to 8 matrices of size 32x128
|
||||
// or rather 4096x1
|
||||
let output2: Vec<Tensor> = (0..seq_len)
|
||||
.into_par_iter()
|
||||
.map(|idx| {
|
||||
let mut concat_vec: Vec<Tensor> = vec![];
|
||||
for idx2 in 0..self.n_local_heads {
|
||||
concat_vec.push(output[idx2 as usize].row(idx as i64));
|
||||
}
|
||||
let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect();
|
||||
let xq_row = Tensor::concat(&concat_vec2).view(1, 4096);
|
||||
let xq_row = xq_row.matrix_mul_transposed(&self.wo);
|
||||
xq_row
|
||||
})
|
||||
.collect();
|
||||
let output3: Vec<&Tensor> = output2.iter().collect();
|
||||
let output2: Tensor = Tensor::concat(&output3);
|
||||
return output2;
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_rotary_emb(
|
||||
xq: &Tensor,
|
||||
xk: &Tensor,
|
||||
freqs_cis: &FreqsCis,
|
||||
seq_idx: usize,
|
||||
start_pos: usize,
|
||||
) -> (Tensor, Tensor) {
|
||||
assert!(xq.cols() % 2 == 0);
|
||||
assert!(xk.cols() % 2 == 0);
|
||||
let mut xq_out: Tensor = xq.clone();
|
||||
let mut xk_out: Tensor = xk.clone();
|
||||
for row in 0..xq.rows() {
|
||||
for col in 0..xq.cols() / 2 {
|
||||
let f_real = freqs_cis[seq_idx + start_pos][col as usize].re as f32;
|
||||
let f_imag = freqs_cis[seq_idx + start_pos][col as usize].im as f32;
|
||||
let xq_real = xq.get_f32(row, col * 2);
|
||||
let xq_imag = xq.get_f32(row, col * 2 + 1);
|
||||
let xk_real = xk.get_f32(row, col * 2);
|
||||
let xk_imag = xk.get_f32(row, col * 2 + 1);
|
||||
|
||||
// multiply with freqs_cis
|
||||
let xq_realpart = xq_real * f_real - xq_imag * f_imag;
|
||||
let xq_imagpart = xq_real * f_imag + xq_imag * f_real;
|
||||
let xk_realpart = xk_real * f_real - xk_imag * f_imag;
|
||||
let xk_imagpart = xk_real * f_imag + xk_imag * f_real;
|
||||
|
||||
xq_out.set_f32(row, col * 2, xq_realpart);
|
||||
xq_out.set_f32(row, col * 2 + 1, xq_imagpart);
|
||||
xk_out.set_f32(row, col * 2, xk_realpart);
|
||||
xk_out.set_f32(row, col * 2 + 1, xk_imagpart);
|
||||
}
|
||||
}
|
||||
return (xq_out, xk_out);
|
||||
}
|
||||
|
||||
fn compute_freqs_cis(dim: usize, end: usize, theta: f64) -> FreqsCis {
|
||||
let mut freqs = Vec::new();
|
||||
for idx in 0..(dim / 2) {
|
||||
let freq = 1.0 / (theta.powf(idx as f64 * 2.0 / dim as f64));
|
||||
freqs.push(freq);
|
||||
}
|
||||
|
||||
let mut result: Vec<Vec<f64>> = Vec::new();
|
||||
for x in 0..end {
|
||||
let mut row = Vec::new();
|
||||
for y in 0..freqs.len() {
|
||||
let freq = freqs[y] * (x as f64);
|
||||
row.push(freq);
|
||||
}
|
||||
result.push(row);
|
||||
}
|
||||
|
||||
let mut resultc: Vec<Vec<Complex<f64>>> = Vec::new();
|
||||
for row in result.into_iter() {
|
||||
let mut rowc = Vec::new();
|
||||
for freq in row {
|
||||
let cis = Complex::from_polar(1.0, freq);
|
||||
rowc.push(cis);
|
||||
}
|
||||
resultc.push(rowc);
|
||||
}
|
||||
resultc
|
||||
}
|
||||
@ -0,0 +1,626 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub struct Unpickler {}
|
||||
|
||||
use crate::tensor::{TensorBuilder, TensorDType, TensorError};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum UnpicklingError {
|
||||
#[error("Unpickling error: {0}")]
|
||||
UnpicklingError(String),
|
||||
#[error("UTF-8 decoding error")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[error("Missing field")]
|
||||
MissingField(String),
|
||||
#[error("Tensor conversion operation failed")]
|
||||
TensorError(#[from] TensorError),
|
||||
#[error("Data has incorrect format to be converted to a tensor")]
|
||||
InvalidTensorData,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
|
||||
pub enum Value {
|
||||
Mark(usize),
|
||||
String(String),
|
||||
Global(String, String), // module name, attribute name
|
||||
Integer64(i64),
|
||||
Tuple(Vec<Value>),
|
||||
PersistentId(Box<Value>),
|
||||
Bool(bool),
|
||||
Reduce(Box<Value>, Box<Value>),
|
||||
Dict(BTreeMap<Value, Value>),
|
||||
}
|
||||
|
||||
impl Value {
|
||||
// Gets a value from a dictionary, assuming Value is a dictionary.
|
||||
//
|
||||
// Returns None if the key is not found, or the value is not a dictionary.
|
||||
pub fn get(&self, key: &Value) -> Option<&Value> {
|
||||
match self {
|
||||
Value::Dict(d) => d.get(key),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
// Same as get() but uses a string as key.
|
||||
pub fn get_str_key<S: AsRef<str>>(&self, key: S) -> Option<&Value> {
|
||||
self.get(&Value::String(key.as_ref().to_string()))
|
||||
}
|
||||
|
||||
pub fn get_global(&self) -> Option<(&str, &str)> {
|
||||
match self {
|
||||
Value::Global(module_name, attribute_name) => Some((module_name, attribute_name)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_str(&self) -> Option<&str> {
|
||||
match self {
|
||||
Value::String(s) => Some(s),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_int64(&self) -> Option<i64> {
|
||||
match self {
|
||||
Value::Integer64(i) => Some(*i),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_persistent_id(&self) -> Option<&Value> {
|
||||
match self {
|
||||
Value::PersistentId(v) => Some(&v),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_tuple(&self) -> Option<&[Value]> {
|
||||
match self {
|
||||
Value::Tuple(v) => Some(&v),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
// Assume that the value represents a tensor in PyTorch and return instructions how to actually
|
||||
// load the values.
|
||||
pub fn to_tensor_builder(&self) -> Option<TensorBuilder> {
|
||||
match self {
|
||||
Value::Reduce(call, args) => match **call {
|
||||
Value::Global(ref module_name, ref attribute_name) => {
|
||||
if module_name == "torch._utils" && attribute_name == "_rebuild_tensor_v2" {
|
||||
match **args {
|
||||
Value::Tuple(ref args) => self.to_tensor_builder2(&args),
|
||||
_ => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_tensor_builder2(&self, args: &[Value]) -> Option<TensorBuilder> {
|
||||
if args.len() == 6 {
|
||||
Self::to_tensor_builder2_6items(args)
|
||||
} else if args.len() == 4 {
|
||||
Self::to_tensor_builder2_4items(args)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn to_tensor_builder2_4items(args: &[Value]) -> Option<TensorBuilder> {
|
||||
let storagev: &Value = args[0].get_persistent_id()?;
|
||||
let storage_args: &[Value] = storagev.get_tuple()?;
|
||||
let storage_mark: &str = storage_args[0].get_str()?;
|
||||
if storage_mark != "storage" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let (storage_module, storage_type) = storage_args[1].get_global()?;
|
||||
if storage_module != "torch" {
|
||||
return None;
|
||||
}
|
||||
let dtype: TensorDType = match storage_type {
|
||||
"HalfStorage" => TensorDType::Float16,
|
||||
_ => return None,
|
||||
};
|
||||
let storage_filename: &str = storage_args[2].get_str()?;
|
||||
let nitems: i64 = storage_args[4].get_int64()?;
|
||||
|
||||
let offset: i64 = args[1].get_int64()?;
|
||||
if offset != 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let rows: i64 = 1;
|
||||
let cols: i64 = nitems;
|
||||
let row_stride: i64 = cols;
|
||||
if row_stride != cols {
|
||||
return None;
|
||||
}
|
||||
|
||||
return Some(TensorBuilder {
|
||||
src_path: PathBuf::from(storage_filename),
|
||||
dtype,
|
||||
stride: row_stride,
|
||||
rows,
|
||||
cols,
|
||||
nitems,
|
||||
});
|
||||
}
|
||||
|
||||
fn to_tensor_builder2_6items(args: &[Value]) -> Option<TensorBuilder> {
|
||||
let storagev: &Value = args[0].get_persistent_id()?;
|
||||
let storage_args: &[Value] = storagev.get_tuple()?;
|
||||
let storage_mark: &str = storage_args[0].get_str()?;
|
||||
if storage_mark != "storage" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let (storage_module, storage_type) = storage_args[1].get_global()?;
|
||||
if storage_module != "torch" {
|
||||
return None;
|
||||
}
|
||||
let dtype: TensorDType = match storage_type {
|
||||
"HalfStorage" => TensorDType::Float16,
|
||||
_ => return None,
|
||||
};
|
||||
let storage_filename: &str = storage_args[2].get_str()?;
|
||||
let nitems: i64 = storage_args[4].get_int64()?;
|
||||
|
||||
let offset: i64 = args[1].get_int64()?;
|
||||
if offset != 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let shape: &[Value] = args[2].get_tuple()?;
|
||||
let stride: &[Value] = args[3].get_tuple()?;
|
||||
|
||||
if shape.len() != 2 {
|
||||
return None;
|
||||
}
|
||||
if stride.len() != 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let rows: i64 = shape[0].get_int64()?;
|
||||
let cols: i64 = shape[1].get_int64()?;
|
||||
|
||||
let row_stride: i64 = stride[0].get_int64()?;
|
||||
let col_stride: i64 = stride[1].get_int64()?;
|
||||
|
||||
if col_stride != 1 {
|
||||
return None;
|
||||
}
|
||||
if row_stride != cols {
|
||||
return None;
|
||||
}
|
||||
|
||||
return Some(TensorBuilder {
|
||||
src_path: PathBuf::from(storage_filename),
|
||||
dtype,
|
||||
stride: row_stride,
|
||||
rows,
|
||||
cols,
|
||||
nitems,
|
||||
});
|
||||
|
||||
/* Args should look like this (took random example from debug print) :
|
||||
0 PERSISTENT_ID
|
||||
TUPLE
|
||||
STRING "storage"
|
||||
GLOBAL "torch" "HalfStorage"
|
||||
STRING "0" (filename)
|
||||
STRING "cpu"
|
||||
INTEGER 131072000 (number of items)
|
||||
1 INTEGER 0
|
||||
2 TUPLE
|
||||
INTEGER 32000
|
||||
INTEGER 4096
|
||||
3 TUPLE
|
||||
INTEGER 4096
|
||||
INTEGER 1
|
||||
4 BOOL false (this is about gradient)
|
||||
5 REDUCE (no idea why this is here)
|
||||
GLOBAL "collections" "OrderedDict"
|
||||
TUPLE
|
||||
|
||||
Sometimes arguments 2 and 3 are missing.
|
||||
*/
|
||||
}
|
||||
|
||||
// Print a nice representation of the value to stdout. Used for good old printf debugging.
|
||||
pub fn debug_print(&self) {
|
||||
self.debug_print_go(0);
|
||||
}
|
||||
|
||||
fn debug_print_go(&self, indent: usize) {
|
||||
if indent > 0 {
|
||||
print!("{:indent$}", "", indent = indent);
|
||||
}
|
||||
match self {
|
||||
Value::Mark(_) => {
|
||||
println!("MARK");
|
||||
}
|
||||
Value::String(s) => {
|
||||
println!("STRING {:?}", s);
|
||||
}
|
||||
Value::Global(module_name, attribute_name) => {
|
||||
println!("GLOBAL {:?} {:?}", module_name, attribute_name);
|
||||
}
|
||||
Value::Integer64(i) => {
|
||||
println!("INTEGER {:?}", i);
|
||||
}
|
||||
Value::Tuple(v) => {
|
||||
println!("TUPLE");
|
||||
for i in v {
|
||||
i.debug_print_go(indent + 2);
|
||||
}
|
||||
}
|
||||
Value::PersistentId(v) => {
|
||||
println!("PERSISTENT_ID");
|
||||
v.debug_print_go(indent + 2);
|
||||
}
|
||||
Value::Bool(b) => {
|
||||
println!("BOOL {:?}", b);
|
||||
}
|
||||
Value::Reduce(v1, v2) => {
|
||||
println!("REDUCE");
|
||||
v1.debug_print_go(indent + 2);
|
||||
v2.debug_print_go(indent + 2);
|
||||
}
|
||||
Value::Dict(d) => {
|
||||
println!("DICT");
|
||||
for (k, v) in d {
|
||||
k.debug_print_go(indent + 2);
|
||||
v.debug_print_go(indent + 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unpickle(bytes: &[u8]) -> Result<Value, UnpicklingError> {
|
||||
// The LLaMA file is in pickle 2 format, check that header is there
|
||||
if bytes.len() < 2 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Data is too short to be a pickle".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if bytes[0] != 128 || bytes[1] != 2 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"No magic header using Pickle 2 protocol".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut memo: BTreeMap<u32, Value> = BTreeMap::new();
|
||||
let mut stack: Vec<Value> = vec![];
|
||||
|
||||
// Decode frames
|
||||
let mut bytes: &[u8] = &bytes[2..];
|
||||
while !bytes.is_empty() {
|
||||
let frame_opcode = bytes[0];
|
||||
if frame_opcode == 125 {
|
||||
// empty dict
|
||||
stack.push(Value::Dict(BTreeMap::new()));
|
||||
bytes = &bytes[1..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 113 {
|
||||
// binput
|
||||
if bytes.len() < 2 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Unexpected end of data while handling BINPUT".to_string(),
|
||||
));
|
||||
}
|
||||
if stack.is_empty() {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Stack is empty while handling BINPUT".to_string(),
|
||||
));
|
||||
}
|
||||
let key = bytes[1];
|
||||
memo.insert(key as u32, stack.last().unwrap().clone());
|
||||
bytes = &bytes[2..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 40 {
|
||||
// mark
|
||||
stack.push(Value::Mark(stack.len()));
|
||||
bytes = &bytes[1..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 88 {
|
||||
// binunicode
|
||||
if bytes.len() < 5 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Unexpected end of data while handling BINUNICODE".to_string(),
|
||||
));
|
||||
}
|
||||
let len = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
|
||||
if bytes.len() < 5 + len as usize {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Unexpected end of data while handling BINUNICODE".to_string(),
|
||||
));
|
||||
}
|
||||
let string = std::str::from_utf8(&bytes[5..5 + len as usize])?;
|
||||
stack.push(Value::String(string.to_string()));
|
||||
bytes = &bytes[5 + len as usize..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 99 {
|
||||
// global
|
||||
// followed by newline terminated module name and attribute name
|
||||
bytes = &bytes[1..];
|
||||
let mut module_name = String::new();
|
||||
while !bytes.is_empty() && bytes[0] != 10 {
|
||||
module_name.push(bytes[0] as char);
|
||||
bytes = &bytes[1..];
|
||||
if bytes.is_empty() {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Unexpected end of data while handling GLOBAL".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
bytes = &bytes[1..];
|
||||
let mut attribute_name = String::new();
|
||||
while !bytes.is_empty() && bytes[0] != 10 {
|
||||
attribute_name.push(bytes[0] as char);
|
||||
bytes = &bytes[1..];
|
||||
if bytes.is_empty() {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Unexpected end of data while handling GLOBAL".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
bytes = &bytes[1..];
|
||||
stack.push(Value::Global(module_name, attribute_name));
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 74 {
|
||||
// binint
|
||||
if bytes.len() < 5 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Unexpected end of data while handling BININT".to_string(),
|
||||
));
|
||||
}
|
||||
let value = i32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
|
||||
stack.push(Value::Integer64(value as i64));
|
||||
bytes = &bytes[5..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 116 {
|
||||
// tuple
|
||||
let mut tuple = vec![];
|
||||
if stack.is_empty() {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Stack is empty while handling TUPLE".to_string(),
|
||||
));
|
||||
}
|
||||
let mut ok = false;
|
||||
while !stack.is_empty() {
|
||||
let top = stack.pop().unwrap();
|
||||
if let Value::Mark(_mark) = top {
|
||||
tuple.reverse();
|
||||
stack.push(Value::Tuple(tuple));
|
||||
ok = true;
|
||||
break;
|
||||
}
|
||||
tuple.push(top);
|
||||
}
|
||||
if !ok {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"No mark while handling TUPLE".to_string(),
|
||||
));
|
||||
}
|
||||
bytes = &bytes[1..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 81 {
|
||||
// binpersid
|
||||
if stack.is_empty() {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Stack is empty while handling BINPERSID".to_string(),
|
||||
));
|
||||
}
|
||||
let top = stack.pop().unwrap();
|
||||
stack.push(Value::PersistentId(Box::new(top)));
|
||||
bytes = &bytes[1..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 75 {
|
||||
// binint1
|
||||
if bytes.len() < 2 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Unexpected end of data while handling BININT1".to_string(),
|
||||
));
|
||||
}
|
||||
let value = bytes[1];
|
||||
stack.push(Value::Integer64(value as i64));
|
||||
bytes = &bytes[2..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 77 {
|
||||
// binint2
|
||||
if bytes.len() < 3 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Unexpected end of data while handling BININT2".to_string(),
|
||||
));
|
||||
}
|
||||
let value = i16::from_le_bytes([bytes[1], bytes[2]]);
|
||||
stack.push(Value::Integer64(value as i64));
|
||||
bytes = &bytes[3..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 134 {
|
||||
// tuple2
|
||||
let mut tuple = vec![];
|
||||
if stack.len() < 2 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Stack does not have enough items while handling TUPLE2".to_string(),
|
||||
));
|
||||
}
|
||||
tuple.push(stack.pop().unwrap());
|
||||
tuple.push(stack.pop().unwrap());
|
||||
tuple.reverse();
|
||||
stack.push(Value::Tuple(tuple));
|
||||
bytes = &bytes[1..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 137 {
|
||||
// newfalse
|
||||
stack.push(Value::Bool(false));
|
||||
bytes = &bytes[1..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 41 {
|
||||
// empty tuple
|
||||
stack.push(Value::Tuple(vec![]));
|
||||
bytes = &bytes[1..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 82 {
|
||||
// reduce
|
||||
if stack.len() < 2 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Stack does not have enough items while handling REDUCE".to_string(),
|
||||
));
|
||||
}
|
||||
let arg_tuple = stack.pop().unwrap();
|
||||
let callable = stack.pop().unwrap();
|
||||
stack.push(Value::Reduce(Box::new(callable), Box::new(arg_tuple)));
|
||||
bytes = &bytes[1..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 104 {
|
||||
// binget
|
||||
if bytes.len() < 2 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Unexpected end of data while handling BINGET".to_string(),
|
||||
));
|
||||
}
|
||||
let idx = bytes[1];
|
||||
match memo.get(&(idx as u32)) {
|
||||
None => {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"BINGET index out of range".to_string(),
|
||||
));
|
||||
}
|
||||
Some(memo_value) => {
|
||||
stack.push(memo_value.clone());
|
||||
}
|
||||
}
|
||||
bytes = &bytes[2..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 133 {
|
||||
// tuple1
|
||||
let mut tuple = vec![];
|
||||
if stack.is_empty() {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Stack is empty while handling TUPLE1".to_string(),
|
||||
));
|
||||
}
|
||||
tuple.push(stack.pop().unwrap());
|
||||
bytes = &bytes[1..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 114 {
|
||||
// long binput
|
||||
if bytes.len() < 5 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Unexpected end of data while handling LONG_BINPUT".to_string(),
|
||||
));
|
||||
}
|
||||
let key = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
|
||||
if stack.is_empty() {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Stack is empty while handling LONG_BINPUT".to_string(),
|
||||
));
|
||||
}
|
||||
memo.insert(key as u32, stack.last().unwrap().clone());
|
||||
bytes = &bytes[5..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 117 {
|
||||
// setitems
|
||||
if stack.is_empty() {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Stack is empty while handling SETITEMS".to_string(),
|
||||
));
|
||||
}
|
||||
let mut ok = false;
|
||||
let mut keyvalues: BTreeMap<Value, Value> = BTreeMap::new();
|
||||
while !stack.is_empty() {
|
||||
let value = stack.pop().unwrap();
|
||||
if let Value::Mark(_mark) = value {
|
||||
ok = true;
|
||||
break;
|
||||
}
|
||||
if stack.is_empty() {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Stack is empty while handling SETITEMS".to_string(),
|
||||
));
|
||||
}
|
||||
let key = stack.pop().unwrap();
|
||||
if let Value::Mark(_mark) = key {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Unexpected mark while handling SETITEMS".to_string(),
|
||||
));
|
||||
}
|
||||
keyvalues.insert(key, value);
|
||||
}
|
||||
if !ok {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"No mark while handling SETITEMS".to_string(),
|
||||
));
|
||||
}
|
||||
if stack.is_empty() {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Stack is empty while handling SETITEMS".to_string(),
|
||||
));
|
||||
}
|
||||
let mut dict = stack.pop().unwrap();
|
||||
match dict {
|
||||
Value::Dict(ref mut dict) => {
|
||||
for (key, value) in keyvalues {
|
||||
dict.insert(key, value);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"SETITEMS on non-dict".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
stack.push(dict);
|
||||
bytes = &bytes[1..];
|
||||
continue;
|
||||
}
|
||||
if frame_opcode == 46 {
|
||||
// stop
|
||||
// bytes = &bytes[1..];
|
||||
break;
|
||||
}
|
||||
return Err(UnpicklingError::UnpicklingError(format!(
|
||||
"Unknown opcode: {}",
|
||||
frame_opcode
|
||||
)));
|
||||
}
|
||||
|
||||
// Stack should have just one item, our final value
|
||||
if stack.len() != 1 {
|
||||
return Err(UnpicklingError::UnpicklingError(
|
||||
"Stack does not have exactly one item after unpickling".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(stack.pop().unwrap())
|
||||
}
|
||||
Loading…
Reference in New Issue