Compare commits
No commits in common. 'main' and '0.02' have entirely different histories.
LICENSE
@@ -1,201 +1,25 @@
The Apache License, Version 2.0 text (http://www.apache.org/licenses/, terms
and conditions sections 1-9 plus the appendix boilerplate notice pointing to
http://www.apache.org/licenses/LICENSE-2.0) is replaced by the BSD 2-Clause
License:

BSD 2-Clause License

Copyright (c) 2021, PENG Bo
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
(11 binary image files deleted; no pixel diff shown. Before sizes: 161, 67,
410, 359, 69, 90, 66, 55, 143, 649, and 289 KiB.)
cuda/timex_cuda.cu (deleted)
@@ -1,172 +0,0 @@
#include <stdio.h>

// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax and BF and BB are passed by the compiler)

#define F4(A, B) ((float4 *)(A))[(B) >> 2]

template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
                               const F eps, const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int ij = (B * C) / BF;
    const int t = threadIdx.x << 2;

    __shared__ F ww[Tmax];
    __shared__ F kk[Tmax * BF];
    F4(ww, t) = F4(__w, t + T * (i % C));

#pragma unroll
    for (int j = 0; j < BF; j++) {
        F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
    }
    __syncthreads();

    float4 s[BF];
#pragma unroll
    for (int j = 0; j < BF; j++) {
        s[j] = {eps, eps, eps, eps};
    }
    const F *__restrict__ const w = ww + T - t - 4;
    for (int u = 0; u <= t; u++) {
#pragma unroll
        for (int j = 0; j < BF; j++) {
            const F x = kk[u + Tmax * j];
            s[j].x += w[u + 3] * x;
            s[j].y += w[u + 2] * x;
            s[j].z += w[u + 1] * x;
            s[j].w += w[u + 0] * x;
        }
    }
#pragma unroll
    for (int j = 0; j < BF; j++) {
        const F *__restrict__ const k = kk + Tmax * j;
        s[j].y += w[t + 3] * k[t + 1];
        s[j].z += w[t + 2] * k[t + 1];
        s[j].z += w[t + 3] * k[t + 2];
        s[j].w += w[t + 1] * k[t + 1];
        s[j].w += w[t + 2] * k[t + 2];
        s[j].w += w[t + 3] * k[t + 3];
        F4(x, t + T * (i + ij * j)) = s[j];
    }
}

template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
                                  F *__restrict__ const gw, F *__restrict__ const gk,
                                  const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int t = threadIdx.x << 2;

    __shared__ F k[Tmax];
    __shared__ F gg[Tmax];
    F4(k, t) = F4(__k, t + T * i);
    F4(gg, t) = F4(__gwk, t + T * i);
    __syncthreads();

    float4 s = {0, 0, 0, 0};

    const F *__restrict__ const g = gg + T - t - 4;
    for (int u = 0; u <= t; u++) {
        F x = k[u];
        s.x += g[u + 3] * x;
        s.y += g[u + 2] * x;
        s.z += g[u + 1] * x;
        s.w += g[u + 0] * x;
    }
    s.y += g[t + 3] * k[t + 1];
    s.z += g[t + 2] * k[t + 1];
    s.z += g[t + 3] * k[t + 2];
    s.w += g[t + 1] * k[t + 1];
    s.w += g[t + 2] * k[t + 2];
    s.w += g[t + 3] * k[t + 3];
    F4(gw, t + T * i) = s;
}

void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
    dim3 gridDim(1, B * C / BF);
    dim3 blockDim(T >> 2);
    kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
}

template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
                                F *__restrict__ const gw, F *__restrict__ const gk,
                                const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int ij = (B * C) / BB;
    const int t = threadIdx.x << 2;

    __shared__ F w[Tmax];
    __shared__ F kk[Tmax * BB];
    __shared__ F gg[Tmax * BB];
    F4(w, t) = F4(__w, t + T * (i % C));

#pragma unroll
    for (int j = 0; j < BB; j++) {
        F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
        F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
    }
    __syncthreads();

    float4 s[BB];
#pragma unroll
    for (int j = 0; j < BB; j++) {
        s[j] = {0, 0, 0, 0};
    }

    for (int u = 0; u <= t; u++) {
#pragma unroll
        for (int j = 0; j < BB; j++) {
            const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
            F x = kk[u + Tmax * j];
            s[j].x += g[u + 3] * x;
            s[j].y += g[u + 2] * x;
            s[j].z += g[u + 1] * x;
            s[j].w += g[u + 0] * x;
        }
    }
#pragma unroll
    for (int j = 0; j < BB; j++) {
        const F *__restrict__ const k = kk + Tmax * j;
        const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
        s[j].y += g[t + 3] * k[t + 1];
        s[j].z += g[t + 2] * k[t + 1];
        s[j].z += g[t + 3] * k[t + 2];
        s[j].w += g[t + 1] * k[t + 1];
        s[j].w += g[t + 2] * k[t + 2];
        s[j].w += g[t + 3] * k[t + 3];
        F4(gw, t + T * (i + ij * j)) = s[j];
    }

#pragma unroll
    for (int j = 0; j < BB; j++) {
        s[j] = {0, 0, 0, 0};
    }

    for (int u = t + 3; u < T; u++) {
        F x = w[u];
#pragma unroll
        for (int j = 0; j < BB; j++) {
            const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
            s[j].x += g[2 - u] * x;
            s[j].y += g[3 - u] * x;
            s[j].z += g[4 - u] * x;
            s[j].w += g[5 - u] * x;
        }
    }
#pragma unroll
    for (int j = 0; j < BB; j++) {
        const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
        s[j].x += g[2 - t] * w[t + 0];
        s[j].x += g[1 - t] * w[t + 1];
        s[j].x += g[0 - t] * w[t + 2];
        s[j].y += g[2 - t] * w[t + 1];
        s[j].y += g[1 - t] * w[t + 2];
        s[j].z += g[2 - t] * w[t + 2];
        F4(gk, t + T * (i + ij * j)) = s[j];
    }
}

void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
    dim3 gridDim(1, B * C / BB);
    dim3 blockDim(T >> 2);
    kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
}
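
For reference, here is what kernel_forward above computes, written as a slow
pure-PyTorch loop. This is a minimal sketch, not shipped code: it assumes w
has shape (C, T) and k has shape (B, C, T), matching the indexing in the
launch code, and the function name is illustrative.

import torch

def timex_forward_reference(w, k, eps):
    # out[b, c, t] = eps + sum_{u=0..t} w[c, T-1-t+u] * k[b, c, u],
    # i.e. a causal correlation of k against the time-reversed tail of w,
    # which is exactly the sum the s[j].x/y/z/w lanes accumulate above.
    B, C, T = k.shape
    out = torch.full((B, C, T), float(eps), dtype=k.dtype, device=k.device)
    for t in range(T):
        for u in range(t + 1):
            out[:, :, t] += w[:, T - 1 - t + u] * k[:, :, u]
    return out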
cuda/timex_op.cpp (deleted)
@@ -1,21 +0,0 @@
#include <torch/extension.h>

void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);

void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
    cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
}
void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
    cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "timex forward");
    m.def("backward", &backward, "timex backward");
}

TORCH_LIBRARY(timex, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
run.py (deleted)
@@ -1,133 +0,0 @@
# -*- coding:utf-8 -*-
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)

### Step 1: set model ##################################################################################

ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'  # 'RWKV' or 'RWKV-ffnPre'

# your trained model
MODEL_NAME = 'trained-31'
WORD_NAME = 'vocab'  # the .json vocab (generated by train.py)

# ########## Uncomment these to test my 27M params enwik8 model ##########
# MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
# WORD_NAME = 'enwik8-vocab'
# EVAL_DATA = 'enwik8'  # uncomment this for EVAL MODE (no text generation)
# ########################################################################

# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' '  # here we just set it to [space] for simplicity

RUN_DEVICE = 'cpu'  # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False  # True False - show softmax output

### Step 2: set context ################################################################################

context = "\nIn the"  # ==> this is your prompt

NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500

TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9

########################################################################################################

print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

########################################################################################################

if 'EVAL_DATA' in vars() or 'EVAL_DATA' in globals():
    print('Evaluating on ' + EVAL_DATA + ' ...')

    data = open(EVAL_DATA, "r", encoding='utf-8').read()

    loss_table = np.zeros(ctx_len)

    N_SAMPLE = 1000

    for iii in range(N_SAMPLE):
        pos = np.random.randint(0, len(data) - ctx_len - 1)
        context = data[pos:pos + ctx_len + 1]
        ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]

        model.clear()
        for i in range(1, ctx_len + 1):
            x = ctx[:i]
            out = model.run(x)
            prob = F.softmax(torch.tensor(out), dim=-1)
            loss_table[i - 1] += -math.log(prob[ctx[i]])

        print(f'Tested {iii+1} samples: avg_loss over ctx_len =',
              np.mean(loss_table) / (iii + 1))

    exit(0)

########################################################################################################

context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    t_begin = time.time_ns()

    src_len = len(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
    print(('-' * 30) + context, end='')

    model.clear()
    if TRIAL == 0:
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[:i + 1]
            if i == src_len - 1:
                init_state.out = model.run(x)
            else:
                model.run(x)
        model.save(init_state)
    else:
        model.load(init_state)

    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        x = ctx[:i + 1]
        x = x[-ctx_len:]

        if i == src_len:
            out = copy.deepcopy(init_state.out)
        else:
            out = model.run(x)
        if DEBUG_DEBUG:
            print('model', np.array(x), '==>', np.array(
                out), np.max(out), np.min(out))

        char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                       top_p_usual=top_p, top_p_newline=top_p_newline)
        char = char.item()
        print(tokenizer.itos[int(char)], end='', flush=True)
        ctx += [char]
    t_end = time.time_ns()
    print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')
src/model.py (deleted)
@@ -1,349 +0,0 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from torch.utils.cpp_extension import load
import math
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)

########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = 1024  # increase this if your ctx_len > 1024
B_GROUP_FORWARD = 4  # set to 8 for best performance
B_GROUP_BACKWARD = 2  # set to 2 for best performance

timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
                  verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])


class TimeX(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, k, B, C, T, eps):
        ctx.B = B
        ctx.C = C
        ctx.T = T
        assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
        w = w.contiguous()
        k = k.contiguous()
        ctx.save_for_backward(w, k)
        wk = torch.empty((B, C, T), device='cuda',
                         memory_format=torch.contiguous_format)
        timex_cuda.forward(w, k, wk, eps, B, C, T)
        return wk

    @staticmethod
    def backward(ctx, gwk):
        assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
        w, k = ctx.saved_tensors
        gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
                         memory_format=torch.contiguous_format)
        gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
                         memory_format=torch.contiguous_format)
        timex_cuda.backward(w, k, gwk.contiguous(), gw,
                            gk, ctx.B, ctx.C, ctx.T)
        return (gw.sum(dim=0), gk, None, None, None, None)

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

RWKV_K_CLAMP = 60  # e^60 ~ 1e26
RWKV_K_EPS = 1e-16
RWKV_HEAD_QK_DIM = 256


def RWKV_Init(module, config):  # fancy initialization of all lin & emb layers in the module
    for m in module.modules():
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            continue
        with torch.no_grad():
            name = '[unknown weight]'
            for name, parameter in module.named_parameters():  # find the name of the weight
                if id(m.weight) == id(parameter):
                    break

            shape = m.weight.data.shape
            gain = 1.0
            scale = 1.0  # extra scale for gain

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == config.vocab_size and shape[1] == config.n_embd:  # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == config.vocab_size and shape[1] == config.n_embd:  # final projection?
                    scale = 0.5

            if hasattr(m, 'scale_init'):
                scale = m.scale_init

            # print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)

            gain *= scale
            if scale == -999:
                nn.init.eye_(m.weight)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(m.weight)
            elif gain > 0:
                nn.init.orthogonal_(m.weight, gain=gain)
            else:
                nn.init.normal_(m.weight, mean=0.0, std=-scale)


class RWKV_TimeMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        ############# fancy init of time_w curves ###################################
        f1_begin = 3.0
        f1_end = 1.2
        f2_begin = 0.65
        f2_end = 0.4
        with torch.no_grad():  # initial time_w curves for better convergence
            decay_speed = torch.ones(attn_sz, 1)
            first_sa_layer_id = 1
            for h in range(attn_sz):
                f1 = f1_begin + (layer_id-first_sa_layer_id) / \
                    (config.n_layer-1-first_sa_layer_id) * (f1_end - f1_begin)
                f2 = f2_begin + (layer_id-first_sa_layer_id) / \
                    (config.n_layer-1-first_sa_layer_id) * (f2_end - f2_begin)
                if layer_id == first_sa_layer_id:
                    f1 += 0.5
                if layer_id == config.n_layer-2:
                    f2 = 0.4
                if layer_id == config.n_layer-1:
                    f2 = 0.37
                decay_speed[h][0] = math.pow(f2, h / (attn_sz-1) * 7) * f1
        self.time_decay = nn.Parameter(torch.log(decay_speed))  # will use exp(self.time_decay) to ensure time_decay > 0
        self.time_curve = torch.tensor(
            [-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
        self.time_curve = self.time_curve.to('cuda')
        self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3))
        #############################################################################

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        with torch.no_grad():  # init to "shift half of the channels"
            ww = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd // 2):
                ww[0, 0, i] = 0
        self.time_mix = nn.Parameter(ww)

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def forward(self, x):
        B, T, C = x.size()

        x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)

        k = self.key(x).transpose(-1, -2)
        v = self.value(x).transpose(-1, -2)
        r = self.receptance(x)

        # RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later)
        k = torch.clamp(k, max=RWKV_K_CLAMP)
        k = torch.exp(k)
        kv = k * v

        self.time_w = torch.cat(
            [torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
        w = torch.exp(self.time_w)

        wkv = TimeX.apply(w, kv, B, C, T, 0)
        # RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
        wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)

        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
        rwkv = self.output(rwkv)
        return rwkv


class RWKV_ChannelMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # init to "shift half of the channels"
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd // 2):
                x[0, 0, i] = 0
        self.time_mix = nn.Parameter(x)

        hidden_sz = 4 * config.n_embd
        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)

        self.value.scale_init = 0
        self.receptance.scale_init = 0

    def forward(self, x):
        x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)

        k = self.key(x)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(x)) * kv
        return rkv

########################################################################################################
# The GPT Model with our blocks
########################################################################################################


class GPTConfig:
    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)


class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        x = self.ln1(x)
        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(x)  # better in some cases
        else:
            x = x + self.att(x)
        x = self.ln2(x)
        x = x + self.ffn(x)
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.step = 0
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i)
                                      for i in range(config.n_layer)])

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
        self.head_q.scale_init = 0
        self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
        self.head_k.scale_init = 0.1
        self.register_buffer("copy_mask", torch.tril(
            torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        RWKV_Init(self, config)

        logger.info("number of parameters: %e", sum(p.numel()
                    for p in self.parameters()))

    def get_ctx_len(self):
        return self.ctx_len

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, (nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1e-5)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        # separate out all parameters into those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()

        for mn, m in self.named_modules():  # here we disable weight_decay
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                no_decay.add(fpn)

        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(
            inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        optim_groups = [
            {"params": [param_dict[pn]
                        for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]

        optimizer = torch.optim.Adam(
            optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)

        return optimizer

    def forward(self, idx, targets=None):
        self.step += 1
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
        x = self.emb(idx)

        x = self.blocks(x)

        x = self.ln_out(x)

        q = self.head_q(x)[:, :T, :]
        k = self.head_k(x)[:, :T, :]
        c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
        c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

        c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).float()
        x = self.head(x) + c

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))

        return x, loss
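
Reading the time_w construction in RWKV_TimeMix.forward together with the CUDA
kernel, the two TimeX.apply calls compute, per channel (writing $w$ for
self.time_decay and $p$ for self.time_first, with $k_u$ the exponentiated,
clamped key and $\varepsilon = $ RWKV_K_EPS):

$$\mathrm{rwkv}_t \;=\; \sigma(r_t)\cdot\frac{\sum_{u=0}^{t-1} e^{-(t-1-u)\,e^{w}}\,k_u v_u \;+\; e^{p}\,k_t v_t}{\varepsilon \;+\; \sum_{u=0}^{t-1} e^{-(t-1-u)\,e^{w}}\,k_u \;+\; e^{p}\,k_t}$$

The most recent token gets its own learned weight $e^{p}$, and older tokens
decay geometrically at the per-channel rate $e^{-e^{w}}$.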
src/model_run.py (deleted)
@@ -1,143 +0,0 @@
import types
import copy
import torch
from torch.nn import functional as F

RWKV_K_CLAMP = 60
RWKV_K_EPS = 1e-16
RWKV_HEAD_QK_DIM = 256

DEBUG_TIME = False  # True False - show trained time-coeffs


class RWKV_RNN():
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        self.RUN_DEVICE = RUN_DEVICE
        self.model_type = model_type
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        self.w = types.SimpleNamespace()

        w = torch.load(MODEL_NAME + '.pth',
                       map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                w[x] = torch.exp(-torch.exp(w[x]))
            if '.time_first' in x:
                w[x] = torch.exp(w[x])
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())

            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        self.xx = {}
        self.aa = {}
        self.bb = {}
        self.hk = None

    def save(self, target):
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ x)
        k = torch.square(torch.relu(w.key.weight @ x))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ x)

        k = torch.exp(torch.clamp(w.key.weight @ x, max=RWKV_K_CLAMP))
        v = w.value.weight @ x
        kv = k * v

        a = self.aa[name] + w.time_first * kv
        b = self.bb[name] + w.time_first * k
        self.aa[name] = w.time_decay * self.aa[name] + kv
        self.bb[name] = w.time_decay * self.bb[name] + k

        rwkv = r * a / (b + RWKV_K_EPS)

        return w.output.weight @ rwkv

    def run(self, ctx):
        w = self.w
        x = w.emb.weight[ctx[-1]]

        for i in range(self.n_layer):
            x = self.LN(x, w.blocks[i].ln1)
            if i == 0 and self.model_type == 'RWKV-ffnPre':
                x = x + self.FF(x, w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(x, w.blocks[i].att, f'att.{i}')
            x = self.LN(x, w.blocks[i].ln2)
            x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        if self.hk is None:
            self.hk = (w.head_k.weight @ x).unsqueeze(0)
        else:
            self.hk = torch.cat(
                [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
        if self.hk.shape[0] > self.ctx_len:
            self.hk = self.hk[-self.ctx_len:, :]

        q = w.head_q.weight @ x

        x = w.head.weight @ x
        x = x.cpu().numpy().tolist()

        c = (self.hk @ q) / RWKV_HEAD_QK_DIM
        for i in range(len(c)):
            x[ctx[i]] += c[i]

        return x
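
SA() above is the streaming form of the same weighted sum the GPT-mode kernel
computes: the running numerator aa and denominator bb are decayed once per
token instead of re-summing the whole history. A minimal single-channel
sketch of that recurrence (illustrative names, scalar math for clarity;
decay = exp(-exp(time_decay)) and first = exp(time_first), as applied to the
loaded weights in RWKV_RNN.__init__ above):

def wkv_stream(decay, first, ks, vs, eps=1e-16):
    aa, bb, outs = 0.0, 0.0, []
    for k, v in zip(ks, vs):
        a = aa + first * k * v   # numerator including the current token
        b = bb + first * k       # denominator including the current token
        outs.append(a / (b + eps))
        aa = decay * aa + k * v  # roll the decayed state forward
        bb = decay * bb + k
    return outs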
src/trainer.py (deleted)
@@ -1,170 +0,0 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math

# import wandb  # comment this if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')

logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

log_file = open("mylog.txt", "a")


class TrainerConfig:
    max_epochs = 10
    batch_size = 64
    learning_rate = 4e-4
    betas = (0.9, 0.99)
    eps = 1e-8
    grad_norm_clip = 1.0
    lr_decay = True  # linear warmup followed by cosine decay
    warmup_tokens = 0
    final_tokens = 0
    epoch_save_frequency = 0
    epoch_save_path = 'trained-'
    num_workers = 0  # for DataLoader

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)


class Trainer:

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.avg_loss = -1
        self.steps = 0

        if 'wandb' in sys.modules:
            cfg = model.config
            for k in config.__dict__:
                setattr(cfg, k, config.__dict__[k])  # combine cfg
            wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
                       datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)

        self.device = 'cpu'
        if torch.cuda.is_available():  # take over whatever gpus are on the system
            self.device = torch.cuda.current_device()

    def get_run_name(self):
        raw_model = self.model.module if hasattr(
            self.model, "module") else self.model
        cfg = raw_model.config
        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
            cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
        return run_name

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset

            if config.num_workers > 0:
                loader = DataLoader(data, shuffle=False, pin_memory=True,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)
            else:
                loader = DataLoader(data, shuffle=False,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)

            pbar = tqdm(enumerate(loader), total=len(
                loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)

            for it, (x, y) in pbar:
                x = x.to(self.device)  # place data on the correct device
                y = y.to(self.device)

                with torch.set_grad_enabled(is_train):
                    _, loss = model(x, y)  # forward the model

                if is_train:  # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()

                    if config.grad_norm_clip > 0:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), config.grad_norm_clip)

                    optimizer.step()

                    if config.lr_decay:  # decay the learning rate based on our progress
                        # number of tokens processed this step (i.e. label is not -100)
                        self.tokens += (y >= 0).sum()
                        lr_final_factor = config.lr_final / config.learning_rate
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = lr_final_factor + \
                                (1 - lr_final_factor) * float(self.tokens) / \
                                float(config.warmup_tokens)
                            progress = 0
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(
                                max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor /
                                                                     2) * math.cos(math.pi * progress)  # better 1.0 ~ 0.1
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    now_loss = loss.item()  # report progress
                    self.lr = lr

                    if 'wandb' in sys.modules:
                        wandb.log({"loss": now_loss},
                                  step=self.steps * self.config.batch_size)
                    self.steps += 1

                    if self.avg_loss < 0:
                        self.avg_loss = now_loss
                    else:
                        factor = 1 / (it + 1)
                        self.avg_loss = self.avg_loss * \
                            (1.0 - factor) + now_loss * factor
                    pbar.set_description(
                        f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")

        self.tokens = 0  # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train')

            log_file.write(
                f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
            log_file.flush()

            if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
                # DataParallel wrappers keep raw model object in .module
                raw_model = self.model.module if hasattr(
                    self.model, "module") else self.model
                torch.save(raw_model.state_dict(),
                           self.config.epoch_save_path + str(epoch+1) + '.pth')
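
Unrolled, with $f = $ lr_final $/$ learning_rate, the schedule above is:
during warmup the multiplier ramps linearly,
$\mathrm{lr\_mult} = f + (1-f)\cdot\mathrm{tokens}/\mathrm{warmup\_tokens}$;
afterwards it follows a cosine from 1 down to $f$,
$\mathrm{lr\_mult} = \tfrac{1+f}{2} + \tfrac{1-f}{2}\cos(\pi\cdot\mathrm{progress})$,
which matches the "better 1.0 ~ 0.1" comment when $f = 0.1$.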
@@ -1,122 +0,0 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset


class Dataset(Dataset):
    def __init__(self, data, ctx_len, epoch_length_fixed):
        print('building token list...', end=' ')
        unique = sorted(list(set(data)))
        # print()
        # for u in unique:
        #     print(u, end=' ')
        # print('\n\n')

        xx = 0
        xxObj = {}
        for u in unique:
            xxObj[xx] = u
            xx += 1
        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))

        data_size, vocab_size = len(data), len(unique)
        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
        self.stoi = {ch: i for i, ch in enumerate(unique)}
        self.itos = {i: ch for i, ch in enumerate(unique)}
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return self.epoch_length_fixed

    def __getitem__(self, idx):
        # cheat: pick a random spot in dataset
        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
        chunk = self.data[i:i+self.ctx_len+1]
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long,
                         device=torch.device('cuda'))
        y = torch.tensor(dix[1:], dtype=torch.long,
                         device=torch.device('cuda'))
        return x, y


class TOKENIZER():
    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
            self.word_table = json.load(result_file)

        self.vocab_size = len(self.word_table)

        self.stoi = {v: int(k) for k, v in self.word_table.items()}
        self.itos = {int(k): v for k, v in self.word_table.items()}

        self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n'

        return context

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
        # out[self.UNKNOWN_CHAR] = -float('Inf')

        lastChar = int(x[-1])

        probs = F.softmax(torch.tensor(out), dim=-1)

        if self.itos[lastChar] == '\n':
            top_p = top_p_newline
        else:
            top_p = top_p_usual

        sorted_probs, s_index = torch.sort(probs, descending=True)

        # for j in range(30):
        #     pp = sorted_probs[j].item()
        #     if pp < 0.005:
        #         break
        #     ss = self.itos[int(s_index[j])].replace('\n','_')
        #     print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
        # print('')

        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])

        probs[probs < cutoff] = 0
        # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")

        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)

        return torch.multinomial(probs, num_samples=1)[0]


def to_float(x):
    return x.cpu().detach().numpy().flatten()[0].astype(float)


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
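sample_logits above is top-p (nucleus) sampling, with a separate, stricter threshold right after a newline. A minimal sketch of the cutoff mechanics with made-up probabilities:

import torch
import numpy as np

probs = torch.tensor([0.5, 0.3, 0.15, 0.05])  # already softmaxed and sorted (made-up values)
top_p = 0.7
cumulative = torch.cumsum(probs, dim=-1).numpy()
# the first index where cumulative mass exceeds top_p sets the cutoff probability
cutoff = float(probs[np.argmax(cumulative > top_p)])
probs[probs < cutoff] = 0  # keeps {0.5, 0.3}, drops the tail
token = torch.multinomial(probs, num_samples=1)[0]  # multinomial renormalizes the weights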
@@ -1,98 +0,0 @@
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import logging
import datetime
import json
from src.model import GPT, GPTConfig
from src.trainer import Trainer, TrainerConfig
from src.utils import Dataset
import torch
import numpy as np
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

### Step 1: set training data ##########################################################################

datafile = "enwik8"
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'

### Step 2: set model size #############################################################################

ctx_len = 1024  # ===> increase T_MAX in model.py if your ctx_len > 1024
n_layer = 6
n_embd = 512

# 'RWKV' (better for char-level English) or 'RWKV-ffnPre' (better in some cases)
model_type = 'RWKV'

### Step 3: set batch size #############################################################################

# ===> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py
# For example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2
# If you see "CUDA out of memory", reduce it. Use GPU-Z to find the highest value for your VRAM.
batch_size = 12

### Step 4: set learning rate, training mini-epochs #######################################################

lr_init = 6e-4
lr_final = 1e-5
# the mini-epoch is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
n_epoch = 500
# 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, etc.
epoch_save_frequency = 30
epoch_save_path = 'trained-'

epoch_length_fixed = 10000

########################################################################################################

# import src.utils
# src.utils.set_seed(42) # remember to change seed if you load a model

np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)

grad_norm_clip = 1.0
warmup_tokens = 0

betas = (0.9, 0.99)
eps = 4e-9

num_workers = 0

########################################################################################################
# Load data
########################################################################################################

print('loading data... ' + datafile)
train_dataset = Dataset(open(
    datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)

########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':

    model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
                          n_layer=n_layer, n_embd=n_embd)).cuda()

    # # # load a trained model. remember to change random seed
    # m2 = torch.load('trained-61.pth')
    # model.load_state_dict(m2)

    print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
          betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
    tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
                          learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip,
                          warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
    trainer = Trainer(model, train_dataset, None, tconf)

    trainer.train()

    torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
               '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
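Step 3's divisibility requirement comes from the BF/BB grouping compiled into the CUDA kernel. A small sanity check one could add before training; the constants here are assumptions that must match model.py:

batch_size = 12
B_GROUP_FORWARD = 4   # assumption: must match B_GROUP_FORWARD in model.py
B_GROUP_BACKWARD = 2  # assumption: must match B_GROUP_BACKWARD in model.py
assert batch_size % B_GROUP_FORWARD == 0 and batch_size % B_GROUP_BACKWARD == 0, \
    'batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD'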
(two image files deleted: 121 KiB and 321 KiB)
@@ -1,172 +0,0 @@
#include <stdio.h>

// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax and BF and BB are passed by compiler)

#define F4(A, B) ((float4 *)(A))[(B) >> 2]

template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
                               const F eps, const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int ij = (B * C) / BF;
    const int t = threadIdx.x << 2;

    __shared__ F ww[Tmax];
    __shared__ F kk[Tmax * BF];
    F4(ww, t) = F4(__w, t + T * (i % C));

#pragma unroll
    for (int j = 0; j < BF; j++) {
        F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
    }
    __syncthreads();

    float4 s[BF];
#pragma unroll
    for (int j = 0; j < BF; j++) {
        s[j] = {eps, eps, eps, eps};
    }
    const F *__restrict__ const w = ww + T - t - 4;
    for (int u = 0; u <= t; u++) {
#pragma unroll
        for (int j = 0; j < BF; j++) {
            const F x = kk[u + Tmax * j];
            s[j].x += w[u + 3] * x;
            s[j].y += w[u + 2] * x;
            s[j].z += w[u + 1] * x;
            s[j].w += w[u + 0] * x;
        }
    }
#pragma unroll
    for (int j = 0; j < BF; j++) {
        const F *__restrict__ const k = kk + Tmax * j;
        s[j].y += w[t + 3] * k[t + 1];
        s[j].z += w[t + 2] * k[t + 1];
        s[j].z += w[t + 3] * k[t + 2];
        s[j].w += w[t + 1] * k[t + 1];
        s[j].w += w[t + 2] * k[t + 2];
        s[j].w += w[t + 3] * k[t + 3];
        F4(x, t + T * (i + ij * j)) = s[j];
    }
}

template <typename F>
__global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
                                  F *__restrict__ const gw, F *__restrict__ const gk,
                                  const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int t = threadIdx.x << 2;

    __shared__ F k[Tmax];
    __shared__ F gg[Tmax];
    F4(k, t) = F4(__k, t + T * i);
    F4(gg, t) = F4(__gwk, t + T * i);
    __syncthreads();

    float4 s = {0, 0, 0, 0};

    const F *__restrict__ const g = gg + T - t - 4;
    for (int u = 0; u <= t; u++) {
        F x = k[u];
        s.x += g[u + 3] * x;
        s.y += g[u + 2] * x;
        s.z += g[u + 1] * x;
        s.w += g[u + 0] * x;
    }
    s.y += g[t + 3] * k[t + 1];
    s.z += g[t + 2] * k[t + 1];
    s.z += g[t + 3] * k[t + 2];
    s.w += g[t + 1] * k[t + 1];
    s.w += g[t + 2] * k[t + 2];
    s.w += g[t + 3] * k[t + 3];
    F4(gw, t + T * i) = s;
}

void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
    dim3 gridDim(1, B * C / BF);
    dim3 blockDim(T >> 2);
    kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
}

template <typename F>
__global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
                                F *__restrict__ const gw, F *__restrict__ const gk,
                                const int B, const int C, const int T) {
    const int i = blockIdx.y;
    const int ij = (B * C) / BB;
    const int t = threadIdx.x << 2;

    __shared__ F w[Tmax];
    __shared__ F kk[Tmax * BB];
    __shared__ F gg[Tmax * BB];
    F4(w, t) = F4(__w, t + T * (i % C));

#pragma unroll
    for (int j = 0; j < BB; j++) {
        F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
        F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
    }
    __syncthreads();

    float4 s[BB];
#pragma unroll
    for (int j = 0; j < BB; j++) {
        s[j] = {0, 0, 0, 0};
    }

    for (int u = 0; u <= t; u++) {
#pragma unroll
        for (int j = 0; j < BB; j++) {
            const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
            F x = kk[u + Tmax * j];
            s[j].x += g[u + 3] * x;
            s[j].y += g[u + 2] * x;
            s[j].z += g[u + 1] * x;
            s[j].w += g[u + 0] * x;
        }
    }
#pragma unroll
    for (int j = 0; j < BB; j++) {
        const F *__restrict__ const k = kk + Tmax * j;
        const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
        s[j].y += g[t + 3] * k[t + 1];
        s[j].z += g[t + 2] * k[t + 1];
        s[j].z += g[t + 3] * k[t + 2];
        s[j].w += g[t + 1] * k[t + 1];
        s[j].w += g[t + 2] * k[t + 2];
        s[j].w += g[t + 3] * k[t + 3];
        F4(gw, t + T * (i + ij * j)) = s[j];
    }

#pragma unroll
    for (int j = 0; j < BB; j++) {
        s[j] = {0, 0, 0, 0};
    }

    for (int u = t + 3; u < T; u++) {
        F x = w[u];
#pragma unroll
        for (int j = 0; j < BB; j++) {
            const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
            s[j].x += g[2 - u] * x;
            s[j].y += g[3 - u] * x;
            s[j].z += g[4 - u] * x;
            s[j].w += g[5 - u] * x;
        }
    }
#pragma unroll
    for (int j = 0; j < BB; j++) {
        const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
        s[j].x += g[2 - t] * w[t + 0];
        s[j].x += g[1 - t] * w[t + 1];
        s[j].x += g[0 - t] * w[t + 2];
        s[j].y += g[2 - t] * w[t + 1];
        s[j].y += g[1 - t] * w[t + 2];
        s[j].z += g[2 - t] * w[t + 2];
        F4(gk, t + T * (i + ij * j)) = s[j];
    }
}

void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
    dim3 gridDim(1, B * C / BB);
    dim3 blockDim(T >> 2);
    kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
}
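Per (batch, channel) row, kernel_forward computes out[t] = eps + the sum over u <= t of w[T-1-(t-u)] * k[u]; w[T-1] weighs the current token and earlier entries of w weigh older tokens. A slow NumPy reference of that contraction, offered only as a readability aid (function name hypothetical):

import numpy as np

def timex_forward_ref(w, k, eps):
    # w: (C, T) weight curves, k: (C, T) keys for one sample
    C, T = k.shape
    out = np.full((C, T), eps, dtype=np.float32)
    for c in range(C):
        for t in range(T):
            for u in range(t + 1):
                out[c, t] += w[c, T - 1 - (t - u)] * k[c, u]
    return out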
@@ -1,21 +0,0 @@
#include <torch/extension.h>

void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);

void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
    cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
}
void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
    cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "timex forward");
    m.def("backward", &backward, "timex backward");
}

TORCH_LIBRARY(timex, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
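These bindings are JIT-compiled from model.py via torch.utils.cpp_extension.load. A hedged sketch of a standalone call (shapes chosen to satisfy the kernel's divisibility requirements; compiler flags abbreviated from the full list in model.py):

import torch
from torch.utils.cpp_extension import load

T_MAX, BF, BB = 1024, 4, 2
timex = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
             extra_cuda_cflags=[f'-DTmax={T_MAX}', f'-DBF={BF}', f'-DBB={BB}'])

B, C, T = 4, 8, 64  # B divisible by BF and BB; T % 4 == 0 and T <= T_MAX
w = torch.rand(C, T, device='cuda').contiguous()   # one weight curve per channel
k = torch.rand(B, C, T, device='cuda').contiguous()
out = torch.empty(B, C, T, device='cuda')
timex.forward(w, k, out, 0.0, B, C, T)             # out = eps + causal weighted sum of k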
@@ -1,98 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)

### Step 1: set model ##################################################################################

ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'  # 'RWKV' or 'RWKV-ffnPre'

# your trained model
MODEL_NAME = 'trained-1'
WORD_NAME = 'vocab'  # the .json vocab (generated by train.py)

# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' '  # here we just set it to [space] for simplicity

RUN_DEVICE = 'cpu'  # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False  # True False - show softmax output

### Step 2: set context ################################################################################

context = "\nIn the"  # ==> this is your prompt

NUM_TRIALS = 999
LENGTH_PER_TRIAL = 500

TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9

########################################################################################################

print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

########################################################################################################

context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n')

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    t_begin = time.time_ns()

    src_len = len(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
    print(('-' * 30) + context, end='')

    model.clear()
    if TRIAL == 0:
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[:i+1]
            if i == src_len - 1:
                init_state.out = model.run(x)
            else:
                model.run(x)
        model.save(init_state)
    else:
        model.load(init_state)

    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        x = ctx[:i+1]
        x = x[-ctx_len:]

        if i == src_len:
            out = copy.deepcopy(init_state.out)
        else:
            out = model.run(x)
        if DEBUG_DEBUG:
            print('model', np.array(x), '==>', np.array(
                out), np.max(out), np.min(out))

        char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                       top_p_usual=top_p, top_p_newline=top_p_newline)
        char = char.item()
        print(tokenizer.itos[int(char)], end='', flush=True)
        ctx += [char]
    t_end = time.time_ns()
    print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')
@@ -1,363 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from torch.utils.cpp_extension import load
import math
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)

RWKV_K_CLAMP = 60  # e^60 = 1e26
RWKV_K_EPS = 1e-8
RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')

########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = 1024  # increase this if your ctx_len > 1024
B_GROUP_FORWARD = 4  # set to 8 for best performance
B_GROUP_BACKWARD = 2  # set to 2 for best performance (sometimes 8 is faster)

timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
                  verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])


class TimeX(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, k, B, C, T, eps):
        ctx.B = B
        ctx.C = C
        ctx.T = T
        assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
        w = w.contiguous()
        k = k.contiguous()
        ctx.save_for_backward(w, k)
        wk = torch.empty((B, C, T), device='cuda',
                         memory_format=torch.contiguous_format)
        timex_cuda.forward(w, k, wk, eps, B, C, T)
        return wk

    @staticmethod
    def backward(ctx, gwk):
        assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
        w, k = ctx.saved_tensors
        gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
                         memory_format=torch.contiguous_format)
        gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda',
                         memory_format=torch.contiguous_format)
        timex_cuda.backward(w, k, gwk.contiguous(), gw,
                            gk, ctx.B, ctx.C, ctx.T)
        return (gw.sum(dim=0), gk, None, None, None, None)

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################


def RWKV_Init(module, config):  # fancy initialization of all lin & emb layer in the module
    for m in module.modules():
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            continue
        with torch.no_grad():
            name = '[unknown weight]'
            for name, parameter in module.named_parameters():  # find the name of the weight
                if id(m.weight) == id(parameter):
                    break

            shape = m.weight.data.shape
            gain = 1.0
            scale = 1.0  # extra scale for gain

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == config.vocab_size and shape[1] == config.n_embd:  # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == config.vocab_size and shape[1] == config.n_embd:  # final projection?
                    scale = 0.5

            if hasattr(m, 'scale_init'):
                scale = m.scale_init

            # print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)

            gain *= scale
            if scale == -999:
                nn.init.eye_(m.weight)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(m.weight)
            elif gain > 0:
                nn.init.orthogonal_(m.weight, gain=gain)
            else:
                nn.init.normal_(m.weight, mean=0.0, std=-scale)


class RWKV_TimeMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        with torch.no_grad():  # fancy init
            self.time_curve = torch.tensor([-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
            self.time_curve = self.time_curve.to('cuda')

            ratio_0_to_1 = (layer_id / (config.n_layer - 1))  # 0 to 1
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer))  # 1 to ~0

            # fancy time_decay
            decay_speed = torch.ones(attn_sz, 1)
            for h in range(attn_sz):
                decay_speed[h][0] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            # fancy time_first
            zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5).unsqueeze(1)
            self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3) + zigzag)

            # fancy time_mix
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def forward(self, x):
        B, T, C = x.size()  # x = (Batch,Time,Channel)

        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x)  # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk).transpose(-1, -2)
        v = self.value(xv).transpose(-1, -2)
        r = self.receptance(xr)

        # RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later)
        k = torch.clamp(k, max=RWKV_K_CLAMP)  # clamp k to avoid overflow
        k = torch.exp(k)
        kv = k * v

        # Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
        self.time_w = torch.cat(
            [torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
        w = torch.exp(self.time_w)

        # Use W to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero
        wkv = TimeX.apply(w, kv, B, C, T, 0)
        # RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
        wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)

        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
        rwkv = self.output(rwkv)
        return rwkv


class RWKV_ChannelMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer))  # 1 to ~0

            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

        hidden_sz = 4 * config.n_embd
        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)

        self.value.scale_init = 0
        self.receptance.scale_init = 0

    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

########################################################################################################
# The GPT Model with our blocks
########################################################################################################


class GPTConfig:
    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)


class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(config, layer_id+1000)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))  # better in some cases
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.step = 0
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i)
                                      for i in range(config.n_layer)])

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        RWKV_Init(self, config)

        logger.info("number of parameters: %e", sum(p.numel()
                                                    for p in self.parameters()))

    def get_ctx_len(self):
        return self.ctx_len

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, (nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1e-5)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()

        for mn, m in self.named_modules():  # here we disable weight_decay
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                no_decay.add(fpn)

        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(
            inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        optim_groups = [
            {"params": [param_dict[pn]
                        for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]

        optimizer = torch.optim.Adam(
            optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)

        return optimizer

    def forward(self, idx, targets=None):
        self.step += 1
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
        x = self.emb(idx)

        x = self.blocks(x)

        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).float()
            x = self.head(x) + c
        else:
            x = self.head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))

        return x, loss
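A pure-PyTorch equivalent of TimeX.apply exists in this diff's model_run.py (RWKV_GPT), using a depthwise causal conv1d. Isolated here as a sketch, useful when the CUDA kernel is unavailable (function name hypothetical):

import torch.nn as nn
from torch.nn import functional as F

def timex_torch(w, k, eps):
    # w: (C, T) weight curve, k: (B, C, T); mirrors TimeX.apply(w, k, B, C, T, eps)
    B, C, T = k.shape
    w = w[:, -T:].unsqueeze(1)             # (C, 1, T): one depthwise filter per channel
    k = nn.ZeroPad2d((T - 1, 0, 0, 0))(k)  # left-pad the time axis -> causal
    return F.conv1d(k, w, groups=C) + eps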
@@ -1,319 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import types
import copy
import torch
import math
from torch.nn import functional as F
import torch.nn as nn

RWKV_K_CLAMP = 60
RWKV_K_EPS = 1e-8
RWKV_HEAD_QK_DIM = 256
print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')

DEBUG_TIME = False  # True False - show trained time-coeffs

############################################################################################################

RWKV_CFG = types.SimpleNamespace()

class RWKV_ChannelMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))

        hidden_sz = 4 * RWKV_CFG.n_embd
        self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

class RWKV_TimeMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1))
        self.time_curve = torch.tensor([-(RWKV_CFG.ctx_len - 2 - i) for i in range(RWKV_CFG.ctx_len-1)]).unsqueeze(0)
        self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1) * math.log(0.3))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))

        self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

        self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()

        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk).transpose(-1, -2)
        v = self.value(xv).transpose(-1, -2)
        r = self.receptance(xr)

        k = torch.clamp(k, max=RWKV_K_CLAMP)
        k = torch.exp(k)

        kv = k * v

        self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
        w = torch.exp(self.time_w)

        w = w[:, -T:].unsqueeze(1)
        wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
        wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + RWKV_K_EPS

        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)

        rwkv = self.output(rwkv)
        return rwkv

class Block(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
        self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)

        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(layer_id+1000)
        else:
            self.att = RWKV_TimeMix(layer_id)

        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class RWKV_GPT(nn.Module):
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
        global RWKV_CFG
        super().__init__()

        RWKV_CFG.RUN_DEVICE = RUN_DEVICE
        RWKV_CFG.model_type = model_type
        RWKV_CFG.vocab_size = vocab_size
        RWKV_CFG.n_layer = n_layer
        RWKV_CFG.n_embd = n_embd
        RWKV_CFG.ctx_len = ctx_len

        print('\nloading RWKV-GPT', MODEL_NAME)

        self.emb = nn.Embedding(vocab_size, n_embd)

        self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])

        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(ctx_len, ctx_len)))

        self.ctx_len = ctx_len
        self.eval()
        self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
        self.eval()

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float()
            x = self.head(x) + c
        else:
            x = self.head(x)

        return x

############################################################################################################

class RWKV_RNN():
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        self.RUN_DEVICE = RUN_DEVICE
        self.model_type = model_type
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        self.w = types.SimpleNamespace()

        w = torch.load(MODEL_NAME + '.pth',
                       map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                w[x] = torch.exp(-torch.exp(w[x]))
            if '.time_first' in x:
                w[x] = torch.exp(w[x])
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())

            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        self.xx = {}
        self.aa = {}
        self.bb = {}
        self.hk = None

    def save(self, target):
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)

        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)

        k = torch.exp(torch.clamp(w.key.weight @ xk, max=RWKV_K_CLAMP))
        v = w.value.weight @ xv
        kv = k * v

        a = self.aa[name] + w.time_first * kv
        b = self.bb[name] + w.time_first * k
        self.aa[name] = w.time_decay * self.aa[name] + kv
        self.bb[name] = w.time_decay * self.bb[name] + k

        rwkv = r * a / (b + RWKV_K_EPS)

        return w.output.weight @ rwkv

    def run(self, ctx):
        w = self.w
        x = w.emb.weight[ctx[-1]]

        for i in range(self.n_layer):
            if i == 0:
                x = self.LN(x, w.blocks[i].ln0)
            if i == 0 and self.model_type == 'RWKV-ffnPre':
                x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
            x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        if RWKV_HEAD_QK_DIM > 0:
            if self.hk is None:
                self.hk = (w.head_k.weight @ x).unsqueeze(0)
            else:
                self.hk = torch.cat(
                    [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
            if self.hk.shape[0] > self.ctx_len:
                self.hk = self.hk[-self.ctx_len:, :]

            q = w.head_q.weight @ x

            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

            c = (self.hk @ q) / RWKV_HEAD_QK_DIM
            for i in range(len(c)):
                x[ctx[i]] += c[i]
        else:
            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

        return x
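For a single channel, the aa/bb state in SA above reduces to two scalar recurrences: with d = exp(-exp(time_decay)) (precomputed at load time in __init__) and f = exp(time_first), each step emits (aa + f*k*v) / (bb + f*k + eps) and then decays the accumulators. A scalar sketch with made-up inputs, receptance gating omitted:

import math

def sa_step(aa, bb, k, v, d, f, eps=1e-8):
    out = (aa + f * k * v) / (bb + f * k + eps)  # current token weighted by e^time_first
    aa = d * aa + k * v                          # decayed sum of k*v over the past
    bb = d * bb + k                              # decayed sum of k over the past
    return out, aa, bb

aa = bb = 0.0
for k, v in [(1.0, 0.5), (2.0, -0.1)]:
    out, aa, bb = sa_step(aa, bb, k, v, d=math.exp(-math.exp(-1.0)), f=0.3)
    print(out)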
@@ -1,171 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm.auto import tqdm
import numpy as np
import logging
import os
import datetime
import sys
import math

# import wandb  # comment this if you don't have wandb
# print('logging to wandb... (comment it if you don\'t have wandb)')

logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

log_file = open("mylog.txt", "a")


class TrainerConfig:
    max_epochs = 10
    batch_size = 64
    learning_rate = 4e-4
    betas = (0.9, 0.99)
    eps = 1e-8
    grad_norm_clip = 1.0
    lr_decay = True  # linear warmup followed by exponential decay
    warmup_tokens = 0
    final_tokens = 0
    epoch_save_frequency = 0
    epoch_save_path = 'trained-'
    num_workers = 0  # for DataLoader

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)


class Trainer:

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.avg_loss = -1
        self.steps = 0

        if 'wandb' in sys.modules:
            cfg = model.config
            for k in config.__dict__:
                setattr(cfg, k, config.__dict__[k])  # combine cfg
            wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' +
                       datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)

        self.device = 'cpu'
        if torch.cuda.is_available():  # take over whatever gpus are on the system
            self.device = torch.cuda.current_device()

    def get_run_name(self):
        raw_model = self.model.module if hasattr(
            self.model, "module") else self.model
        cfg = raw_model.config
        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
            cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
        return run_name

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset

            if config.num_workers > 0:
                loader = DataLoader(data, shuffle=False, pin_memory=True,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)
            else:
                loader = DataLoader(data, shuffle=False,
                                    batch_size=config.batch_size,
                                    num_workers=config.num_workers)

            pbar = tqdm(enumerate(loader), total=len(
                loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)

            for it, (x, y) in pbar:
                x = x.to(self.device)  # place data on the correct device
                y = y.to(self.device)

                with torch.set_grad_enabled(is_train):
                    _, loss = model(x, y)  # forward the model

                if is_train:  # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()

                    if config.grad_norm_clip > 0:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), config.grad_norm_clip)

                    optimizer.step()

                    if config.lr_decay:  # decay the learning rate based on our progress
                        # number of tokens processed this step (i.e. label is not -100)
                        self.tokens += (y >= 0).sum()
                        lr_final_factor = config.lr_final / config.learning_rate
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = lr_final_factor + \
                                (1 - lr_final_factor) * float(self.tokens) / \
                                float(config.warmup_tokens)
                            progress = 0
                        else:
                            # exponential learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            if progress >= 1:
                                lr_mult = lr_final_factor
                            else:
                                lr_mult = math.exp(math.log(lr_final_factor) * pow(progress, 1))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    now_loss = loss.item()  # report progress
                    self.lr = lr

                    if 'wandb' in sys.modules:
                        wandb.log({"loss": now_loss},
                                  step=self.steps * self.config.batch_size)
                    self.steps += 1

                    if self.avg_loss < 0:
                        self.avg_loss = now_loss
                    else:
                        factor = 1 / (it + 1)
                        self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor
                    pbar.set_description(
                        f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")

        self.tokens = 0  # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train')

            log_file.write(
                f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
            log_file.flush()

            if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
                # DataParallel wrappers keep raw model object in .module
                raw_model = self.model.module if hasattr(
                    self.model, "module") else self.model
                torch.save(raw_model.state_dict(),
                           self.config.epoch_save_path + str(epoch+1) + '.pth')
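The schedule in run_epoch above is a linear warmup followed by exponential decay from learning_rate toward lr_final as the token count approaches final_tokens. Isolated as a sketch (hypothetical helper name):

import math

def lr_multiplier(tokens, warmup_tokens, final_tokens, lr_init, lr_final):
    f = lr_final / lr_init
    if warmup_tokens > 0 and tokens < warmup_tokens:
        return f + (1 - f) * tokens / warmup_tokens   # linear warmup
    progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
    return f if progress >= 1 else math.exp(math.log(f) * progress)

# halfway through training the lr is the geometric mean of lr_init and lr_final:
print(6e-4 * lr_multiplier(50, 0, 100, 6e-4, 1e-5))  # ~7.7e-05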
@ -1,122 +0,0 @@
|
|||||||
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset


class Dataset(Dataset):  # note: shadows the torch.utils.data.Dataset imported above
    def __init__(self, data, ctx_len, epoch_length_fixed):
        print('building token list...', end=' ')
        unique = sorted(list(set(data)))
        # print()
        # for u in unique:
        #     print(u, end=' ')
        # print('\n\n')

        xx = 0
        xxObj = {}
        for u in unique:
            xxObj[xx] = u
            xx += 1
        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))

        data_size, vocab_size = len(data), len(unique)
        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
        self.stoi = {ch: i for i, ch in enumerate(unique)}
        self.itos = {i: ch for i, ch in enumerate(unique)}
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return self.epoch_length_fixed

    def __getitem__(self, idx):
        # cheat: pick a random spot in the dataset
        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
        chunk = self.data[i:i+self.ctx_len+1]
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long,
                         device=torch.device('cuda'))
        y = torch.tensor(dix[1:], dtype=torch.long,
                         device=torch.device('cuda'))
        return x, y
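
    # Note on the pairing above: x is chunk[0:ctx_len] and y is chunk[1:ctx_len+1],
    # so y[t] is the next-token target for x[t] - the standard autoregressive setup.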


class TOKENIZER():
    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
            self.word_table = json.load(result_file)

        self.vocab_size = len(self.word_table)

        self.stoi = {v: int(k) for k, v in self.word_table.items()}
        self.itos = {int(k): v for k, v in self.word_table.items()}

        self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n'

        return context

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
        # out[self.UNKNOWN_CHAR] = -float('Inf')

        lastChar = int(x[-1])

        probs = F.softmax(torch.tensor(out), dim=-1)

        if self.itos[lastChar] == '\n':
            top_p = top_p_newline
        else:
            top_p = top_p_usual

        sorted_probs, s_index = torch.sort(probs, descending=True)

        # for j in range(30):
        #     pp = sorted_probs[j].item()
        #     if pp < 0.005:
        #         break
        #     ss = self.itos[int(s_index[j])].replace('\n','_')
        #     print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
        # print('')

        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])

        probs[probs < cutoff] = 0
        # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")

        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)

        return torch.multinomial(probs, num_samples=1)[0]
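
    # Minimal usage sketch (hypothetical values): `out` is the 1-D logit vector
    # for the next token and `x` is the token ids so far. This is nucleus
    # (top-p) sampling: keep the smallest set of top tokens whose cumulative
    # probability exceeds top_p, with a separate top_p right after a newline:
    #
    #   tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ')
    #   token = tokenizer.sample_logits(out, x, ctx_len=1024, temperature=1.0,
    #                                   top_p_usual=0.7, top_p_newline=0.9)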


def to_float(x):
    return x.cpu().detach().numpy().flatten()[0].astype(float)


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
@@ -1,118 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os

# if False: # True False ---> Set to False if you don't understand it
#     print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
#     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#     import src.utils
#     src.utils.set_seed(42) # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples)

import logging
import datetime
from src.model import GPT, GPTConfig
from src.trainer import Trainer, TrainerConfig
from src.utils import Dataset
import torch
import numpy as np

np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

### Step 1: set training data ##########################################################################

datafile = "../data/enwik8"  # your data
datafile_encoding = 'utf-8'
# datafile_encoding = 'utf-16le'

### Step 2: set model size #############################################################################
# ----> test deeper models (n_layer at least 12) to see the advantage of RWKV-3 over RWKV-2

ctx_len = 1024  # increase T_MAX in model.py if your ctx_len > 1024
n_layer = 6
n_embd = 512

# 'RWKV' (better for English) or 'RWKV-ffnPre' (better in some cases)
model_type = 'RWKV'

# ---> there is a RWKV_HEAD_QK_DIM in model.py and model_run.py
# set it to 256, then it's using my headQK trick (similar to a tiny attention) to improve loss
# set it to 0, then it's a pure RNN (attention-free)

### Step 3: set batch size #############################################################################

# ---> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py
# for example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2
# if you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU.
batch_size = 12

### Step 4: set learning rate, number of mini-epochs ###################################################
# By default we are using exponential LR decay.
#
# Here are my suggestions for training a good model.
# Let's say you will train a L6-D512 model.
# 1) Set lr_init = lr_final = 8e-4. Let it run for some mini-epochs, until the improvement of loss becomes slow.
# 2) Check epoch_save_frequency and make sure the partially-trained model is saved. Ctrl+C to stop the run.
# 3) Set lr_init = 8e-4, lr_final = 1e-5, warmup_tokens = ctx_len * batch_size * 50, betas = (0.9, 0.999).
# 4) Search for "torch.load" here and modify it to load the partially-trained model. Continue the training.
#
# For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4.

lr_init = 8e-4  # we can use larger lr because of preLN
lr_final = 1e-5
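
# A minimal sketch of the exponential decay implied by lr_init / lr_final
# (assuming the schedule is driven by a training-progress fraction in [0, 1];
# src/trainer.py holds the exact warmup/decay logic actually used):
#
#   import math
#   def lr_at(progress, lr_init=8e-4, lr_final=1e-5):
#       return lr_init * math.exp(math.log(lr_final / lr_init) * progress)
#
#   lr_at(0.0)  # 8e-4 at the start of training
#   lr_at(1.0)  # 1e-5 at the end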

# the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens)
n_epoch = 500
epoch_length_fixed = 10000
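# with the defaults above, that is 1024 * 10000 = 10,240,000 tokens per mini-epoch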

# 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ...
epoch_save_frequency = 10
epoch_save_path = 'trained-'

########################################################################################################

grad_norm_clip = 1.0
warmup_tokens = ctx_len * batch_size * 0

betas = (0.9, 0.99)
eps = 4e-9

num_workers = 0

########################################################################################################
# Load data
########################################################################################################

print('loading data... ' + datafile)
train_dataset = Dataset(open(
    datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)

########################################################################################################
# Train model
########################################################################################################
if __name__ == '__main__':

    model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
                          n_layer=n_layer, n_embd=n_embd)).cuda()

    ### ---> load a trained model <---
    # m2 = torch.load('trained-61.pth')
    # model.load_state_dict(m2)

    print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
          betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
    tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
                          learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip,
                          warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
    trainer = Trainer(model, train_dataset, None, tconf)

    trainer.train()

    torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() +
               '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
@@ -1,65 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

# this is for verifying the results of different models and making sure they agree with each other

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
RUN_DEVICE = 'cuda'

import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig

ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'

model_name = 'trained-1'

from src.utils import TOKENIZER
tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ')

########################################################################################################

model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda()
print('loading ' + model_name)
m2 = torch.load(model_name + '.pth', map_location=RUN_DEVICE)
model_train.load_state_dict(m2)

model_rnn = RWKV_RNN(model_name, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
model_gpt = RWKV_GPT(model_name, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda()

########################################################################################################

context = '\nIn a'
ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
print(f'input len {len(ctx)} data {ctx}')

########################################################################################################

print('\nRWKV-GPT output')
out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy()
print(out)

print('\nRWKV-RNN output')
model_rnn.clear()
src_len = len(ctx)
for i in range(src_len):
    x = ctx[:i+1]
    out = model_rnn.run(x)
    if i < 3 or i >= src_len - 3:
        print(torch.tensor(out).detach().cpu().numpy())
        if i == 2:
            print('...')

print('\nRWKV-train output')
ctx += [0] * (ctx_len - src_len)  # pad to ctx_len
ctx = [ctx] * 4  # increase batch size (to make it work with B_GROUP_FORWARD & B_GROUP_BACKWARD)
out = model_train.forward(torch.tensor(ctx).cuda())[0][0][:src_len].detach().cpu().numpy()
print(out, '\n')
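
# A hedged extension (not in the original script): rather than eyeballing the
# printed logits, the outputs could be compared programmatically, e.g.
#
#   np.testing.assert_allclose(out_gpt, out_train, rtol=1e-3, atol=1e-3)
#
# where out_gpt / out_train are the arrays printed above; the tolerances are a
# guess, since the RNN, CUDA-kernel and matmul paths won't agree bit-for-bit.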
(deleted binary image, 70 KiB)
@@ -1,125 +0,0 @@
#include <stdio.h>
#include <assert.h>

#define MIN_VALUE (-1e38)

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    F p = 0, q = 0, o = MIN_VALUE;
    // p and q are running sums divided by exp(o) (to avoid overflows)
    for (int i = 0; i < T; i++) {
        const int ii = i * C;

        F no = max(o, u + k[ii]);
        F A = exp(o - no);
        F B = exp(u + k[ii] - no);
        y[ii] = (A * p + B * v[ii]) / (A * q + B);

        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }
}
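
// What the loop above computes per (batch, channel), written without the
// o / no rescaling (that bookkeeping only keeps the exponentials in range).
// The Python wrapper passes w = -exp(w_param), so w < 0 acts as a decay:
//
//     y[t] = ( sum_{i<t} exp((t-1-i)*w + k[i]) * v[i] + exp(u + k[t]) * v[t] )
//            / ( sum_{i<t} exp((t-1-i)*w + k[i]) + exp(u + k[t]) )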

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const gy = _gy + _offset;

    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    F y[Tmax], z[Tmax], zexp[Tmax];

    F gw = 0, gu = 0;
    F p = 0, q = 0;
    F dpdw = 0, dqdw = 0;
    F o = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        F no = max(o, k[ii] + u);
        F A = exp(o - no);
        F B = exp(k[ii] + u - no);

        F num = A * p + B * v[ii];
        F iden = 1 / (A * q + B);

        y[i] = num * iden;
        z[i] = iden;
        zexp[i] = k[ii] + u - no;

        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
        gu += gy[ii] * (v[ii] - y[i]) * B * iden;

        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        dpdw = A * (p + dpdw);
        dqdw = A * (q + dqdw);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }

    F gp = 0, gq = 0;
    o = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        F A = gy[ii] * z[i] * exp(zexp[i]);
        F B = exp(k[ii] + o);
        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
        gv[ii] = A + B * gp;

        F no = max(w + o, zexp[i] - k[ii] - u);
        A = exp(w + o - no);
        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
        gp = A * gp + B;
        gq = A * gq - B * y[i];
        o = no;
    }

    // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] += gw * _w[_c];
    _gu[_offsetBC] += gu;
}

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
@@ -1,21 +0,0 @@
#include <torch/extension.h>

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("backward", &backward, "wkv backward");
}

TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
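
// Note: src/model.py below JIT-compiles this extension with
// torch.utils.cpp_extension.load(name="wkv", sources=["cuda/wkv_op.cpp",
// "cuda/wkv_cuda.cu"], ...) and then calls wkv_cuda.forward(B, T, C, w, u, k, v, y)
// with y pre-allocated on the CUDA device.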
@@ -1,149 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math, os
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)

########################################################################################################
# Step 1: set model
#
# Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch.
#
# Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models.
########################################################################################################

TOKEN_MODE = 'char'  # char / bpe / pile

n_layer = 6
n_embd = 512
ctx_len = 1024

if TOKEN_MODE == 'char':
    MODEL_NAME = 'trained-500'  # your trained model
    WORD_NAME = 'vocab'  # the .json vocab (generated by train.py)
    # set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it
    UNKNOWN_CHAR = ' '  # here we just set it to ' ' for simplicity

elif TOKEN_MODE == 'bpe':
    MODEL_NAME = 'trained-500'  # your trained model
    WORD_NAME = ['model-vocab.json', 'model-merges.txt']  # [vocab, merge] for your BPE model
    UNKNOWN_CHAR = None

elif TOKEN_MODE == 'pile':
    WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
    UNKNOWN_CHAR = None

    # ---> you can set MODEL_NAME to your fine-tuned model <---

    MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023'
    # MODEL_NAME = 'trained-11'
    n_layer = 12
    n_embd = 768
    ctx_len = 1024

    # MODEL_NAME = 'RWKV-4-Pile-430M-20220808-8066'
    # n_layer = 24
    # n_embd = 1024
    # ctx_len = 1024

    # MODEL_NAME = 'RWKV-4-Pile-1B5-20220903-8040'
    # n_layer = 24
    # n_embd = 2048
    # ctx_len = 1024

os.environ['RWKV_FLOAT_MODE'] = 'fp32'  # 'bf16' / 'fp16' / 'fp32' (note: only using fp32 at this moment)
os.environ['RWKV_RUN_DEVICE'] = 'cpu'  # 'cpu' (already very fast) or 'cuda'
model_type = 'RWKV'  # 'RWKV' or 'RWKV-ffnPre'

########################################################################################################
# Step 2: set prompt & sampling stuff
########################################################################################################

# context = 'A'
# context = "\nIn the"
# context = '\nSugar:'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'

NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333

TEMPERATURE = 1.0
top_p = 0.7
top_p_newline = 0.9  # only used in TOKEN_MODE = char

DEBUG_DEBUG = False  # True False --> show softmax output

########################################################################################################

print(f'Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN
model = RWKV_RNN(MODEL_NAME, os.environ['RWKV_RUN_DEVICE'], model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

########################################################################################################

if tokenizer.charMode:
    context = tokenizer.refine_context(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
    ctx = tokenizer.tokenizer.encode(context)
src_len = len(ctx)
src_ctx = ctx.copy()

print('\nYour prompt has ' + str(src_len) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n')
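
# The loop below implements the speedup hinted at above for repeated trials:
# the first trial feeds the prompt token-by-token through the RNN and then
# snapshots the hidden state (model.save(init_state)); later trials restore
# that snapshot with model.load(init_state) instead of re-processing the prompt.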

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    t_begin = time.time_ns()
    print(('-' * 30) + context, end='')
    ctx = src_ctx.copy()
    model.clear()
    if TRIAL == 0:
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[:i+1]
            if i == src_len - 1:
                init_state.out = model.run(x)
            else:
                model.run(x)
        model.save(init_state)
    else:
        model.load(init_state)

    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        x = ctx[:i+1]
        x = x[-ctx_len:]

        if i == src_len:
            out = copy.deepcopy(init_state.out)
        else:
            out = model.run(x)
        if DEBUG_DEBUG:
            print('model', np.array(x), '==>', np.array(
                out), np.max(out), np.min(out))

        if TOKEN_MODE == 'pile':
            out[0] = -999999999  # disable <|endoftext|>

        char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE,
                                       top_p_usual=top_p, top_p_newline=top_p_newline)
        char = char.item()
        if tokenizer.charMode:
            print(tokenizer.itos[int(char)], end='', flush=True)
        else:
            print(tokenizer.tokenizer.decode(int(char)), end='', flush=True)
        ctx += [char]

    t_end = time.time_ns()
    print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ')
@@ -1,216 +0,0 @@
from lib2to3.pgen2 import token
import os
import torch
import numpy as np
import shutil
import struct
from functools import lru_cache
from itertools import accumulate

def print_rank_0(*message):
    """If distributed is initialized print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(*message, flush=True)
    else:
        print(*message, flush=True)

def _warmup_mmap_file(path):
    pass
    # with open(path, "rb") as stream:
    #     while stream.read(100 * 1024 * 1024):
    #         pass

dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: float,
    7: np.double,
    8: np.uint16,
}
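
# On-disk .idx layout, as parsed by Index.__init__ below: 9-byte magic,
# little-endian u64 version (must be 1), u8 dtype code (a key of the table
# above), u64 len, u64 doc_count, then int32 sizes[len], int64 pointers[len]
# and int64 doc_idx[doc_count] arrays, back to back.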

def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)

def index_file_path(prefix_path):
    return prefix_path + ".idx"

def data_file_path(prefix_path):
    return prefix_path + ".bin"

class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                # Little endian unsigned 64 Bit integer
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                # Little endian unsigned 8 Bit integer
                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print_rank_0(" warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print_rank_0(" reading sizes...")
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            print_rank_0(" reading pointers...")
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )
            print_rank_0(" reading document index...")
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print_rank_0(" warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print_rank_0(" creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        print_rank_0(" creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
            )
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError(
                    "Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
            )
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
        )
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(
            data_file_path(path)
        )
@@ -1,414 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import math, os
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
try:
    from deepspeed.ops.adam import FusedAdam
except:
    pass  # some poor windows users can't install deepspeed

logger = logging.getLogger(__name__)

RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')

class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)
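
# In effect: the loss value passes through unchanged, but backward injects an
# extra gradient of size (1e-4 / (B*T)) * max_logit at each position's argmax
# logit, gently pulling the largest logits toward 0.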

########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = 1024  # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
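# Tmax is baked into the kernel at compile time because kernel_backward keeps
# per-timestep arrays (y, z, zexp) of length Tmax per thread - that is what
# the VRAM warning above refers to.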

class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            w = -torch.exp(w.contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        else:
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            return y
        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            return y.half()
        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            return y.bfloat16()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda').contiguous()
        gu = torch.zeros((B, C), device='cuda').contiguous()
        gk = torch.zeros((B, T, C), device='cuda').contiguous()
        gv = torch.zeros((B, T, C), device='cuda').contiguous()
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        else:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            return (None, None, None, gw, gu, gk, gv)
        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())

def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

def RWKV_Init(model, args):  # fancy initialization of all lin & emb layers in the model
    print("\n[--> first run, init model params (very slow for large models) <--]")
    print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n")

    for mm in model.modules():
        if "RecursiveScriptModule" in str(type(mm)):
            if mm.original_name not in ["Linear"]:
                continue
            ww = None
            for name, param in mm.named_parameters():
                if name == "weight":
                    ww = param
        else:
            m = mm
            if not isinstance(m, (nn.Linear, nn.Embedding)):
                continue
            ww = m.weight
        with torch.no_grad():
            name = "[unknown weight]"
            for name, parameter in model.named_parameters():  # find the name of the weight
                if id(ww) == id(parameter):
                    break

            shape = ww.shape
            gain = 1.0
            scale = 1.0  # extra scale for gain

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == args.vocab_size and shape[1] == args.n_embd:  # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if isinstance(m, nn.Linear):
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == args.vocab_size and shape[1] == args.n_embd:  # final projection?
                    scale = 0.5

            if hasattr(m, "scale_init"):
                scale = m.scale_init

            # print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {name}")

            gain *= scale
            if scale == -999:
                nn.init.eye_(ww)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(ww)
            elif gain > 0:
                nn.init.orthogonal_(ww, gain=gain)
            else:
                nn.init.normal_(ww, mean=0.0, std=-scale)


class RWKV_TimeMix(torch.jit.ScriptModule):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        with torch.no_grad():  # fancy init
            ratio_0_to_1 = (layer_id / (config.n_layer - 1))  # 0 to 1
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer))  # 1 to ~0

            # fancy time_decay
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
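
            # Worked range check: decay_speed runs from -5 (h = 0) up to 3
            # (h = attn_sz - 1). Since the kernel receives w = -exp(time_decay),
            # the per-step decay factor exp(w) spans roughly exp(-exp(-5)) ~ 0.993
            # (long memory) down to exp(-exp(3)) ~ 2e-9 (forgets almost instantly).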

            # fancy time_first
            zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5)
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

            # fancy time_mix
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    @torch.jit.script_method
    def jit_func(self, x):

        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)

        return sr, k, v

    def forward(self, x):
        B, T, C = x.size()  # x = (Batch,Time,Channel)

        sr, k, v = self.jit_func(x)

        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        rwkv = self.output(rwkv)
        return rwkv


class RWKV_ChannelMix(torch.jit.ScriptModule):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer))  # 1 to ~0

            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

        hidden_sz = 4 * config.n_embd
        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)

        self.value.scale_init = 0
        self.receptance.scale_init = 0

    @torch.jit.script_method
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

########################################################################################################
# The GPT Model with our blocks
########################################################################################################


class GPTConfig:
    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)


class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(config, 0)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))  # better in some cases
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.step = 0
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i)
                                      for i in range(config.n_layer)])

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        try:
            if os.environ['RWKV_LOAD_MODEL'] == str(False):
                RWKV_Init(self, config)
        except:
            pass

        logger.info("number of parameters: %e", sum(p.numel()
                                                    for p in self.parameters()))

    def get_ctx_len(self):
        return self.ctx_len

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, (nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1e-5)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        no_decay = set()

        for mn, m in self.named_modules():  # here we disable weight_decay
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                no_decay.add(fpn)

        param_dict = {pn: p for pn, p in self.named_parameters()}
        optim_groups = [
            {"params": [param_dict[pn]
                        for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]

        try:
            optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        except:
            print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n')
            optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)

        return optimizer

    def forward(self, idx, targets=None):
        idx = idx.to(self.emb.weight.device)

        self.step += 1
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if '32' in os.environ['RWKV_FLOAT_MODE']:
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)
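
        # The headQK branch above (active when RWKV_HEAD_QK_DIM > 0) is the
        # "tiny attention" trick: a causal q @ k^T map, masked to the lower
        # triangle, applied to one-hot input tokens - i.e. a learned
        # copy-from-context term added on top of the usual vocabulary head.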

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))

        return L2Wrap.apply(loss, x)
@@ -1,392 +0,0 @@
########################################################################################################
|
|
||||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
|
||||||
########################################################################################################
|
|
||||||
|
|
||||||
import types
|
|
||||||
import copy
|
|
||||||
import torch
|
|
||||||
import math, os
|
|
||||||
from torch.nn import functional as F
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
RWKV_HEAD_QK_DIM = 0
|
|
||||||
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
|
|
||||||
|
|
||||||
DEBUG_TIME = False # True False - show trained time-coeffs
|
|
||||||
|
|
||||||
########################################################################################################
|
|
||||||
# CUDA Kernel
|
|
||||||
########################################################################################################
|
|
||||||
|
|
||||||
if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
|
|
||||||
T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
|
|
||||||
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
|
|
||||||
|
|
||||||
from torch.utils.cpp_extension import load
|
|
||||||
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
|
|
||||||
verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
|
|
||||||
|
|
||||||
class WKV(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, B, T, C, w, u, k, v):
|
|
||||||
ctx.B = B
|
|
||||||
ctx.T = T
|
|
||||||
ctx.C = C
|
|
||||||
assert T <= T_MAX
|
|
||||||
assert B * C % min(C, 1024) == 0
|
|
||||||
if '32' in os.environ['RWKV_FLOAT_MODE']:
|
|
||||||
w = -torch.exp(w.contiguous())
|
|
||||||
u = u.contiguous()
|
|
||||||
k = k.contiguous()
|
|
||||||
v = v.contiguous()
|
|
||||||
else:
|
|
||||||
w = -torch.exp(w.float().contiguous())
|
|
||||||
u = u.float().contiguous()
|
|
||||||
k = k.float().contiguous()
|
|
||||||
v = v.float().contiguous()
|
|
||||||
ctx.save_for_backward(w, u, k, v)
|
|
||||||
y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
|
|
||||||
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
|
||||||
if '32' in os.environ['RWKV_FLOAT_MODE']:
|
|
||||||
return y
|
|
||||||
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
|
|
||||||
return y.half()
|
|
||||||
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
|
|
||||||
return y.bfloat16()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, gy):
|
|
||||||
B = ctx.B
|
|
||||||
T = ctx.T
|
|
||||||
C = ctx.C
|
|
||||||
assert T <= T_MAX
|
|
||||||
assert B * C % min(C, 1024) == 0
|
|
||||||
w, u, k, v = ctx.saved_tensors
|
|
||||||
gw = torch.zeros((B, C), device='cuda').contiguous()
|
|
||||||
gu = torch.zeros((B, C), device='cuda').contiguous()
|
|
||||||
gk = torch.zeros((B, T, C), device='cuda').contiguous()
|
|
||||||
gv = torch.zeros((B, T, C), device='cuda').contiguous()
|
|
||||||
if '32' in os.environ['RWKV_FLOAT_MODE']:
|
|
||||||
wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
|
|
||||||
else:
|
|
||||||
wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
|
|
||||||
gw = torch.sum(gw, dim=0)
|
|
||||||
gu = torch.sum(gu, dim=0)
|
|
||||||
if '32' in os.environ['RWKV_FLOAT_MODE']:
|
|
||||||
return (None, None, None, gw, gu, gk, gv)
|
|
||||||
elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
|
|
||||||
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
|
|
||||||
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
|
|
||||||
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
|
|
||||||
|
|
||||||
def RUN_CUDA(B, T, C, w, u, k, v):
|
|
||||||
return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
|
|
||||||
|
|
||||||
############################################################################################################
|
|
||||||
|
|
||||||
RWKV_CFG = types.SimpleNamespace()
|
|
||||||
|
|
||||||
class RWKV_ChannelMix(nn.Module):
|
|
||||||
def __init__(self, layer_id):
|
|
||||||
super().__init__()
|
|
||||||
self.layer_id = layer_id
|
|
||||||
|
|
||||||
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
|
|
||||||
self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
|
|
||||||
self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
|
|
||||||
|
|
||||||
hidden_sz = 4 * RWKV_CFG.n_embd
|
|
||||||
self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
|
|
||||||
self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
|
|
||||||
self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
xx = self.time_shift(x)
|
|
||||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
|
||||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
|
||||||
|
|
||||||
k = self.key(xk)
|
|
||||||
k = torch.square(torch.relu(k))
|
|
||||||
kv = self.value(k)
|
|
||||||
|
|
||||||
rkv = torch.sigmoid(self.receptance(xr)) * kv
|
|
||||||
return rkv
|
|
||||||
class RWKV_TimeMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd))
        self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))

        self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

        self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()

        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)

        rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)

        rwkv = self.output(rwkv)
        return rwkv

class Block(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
        self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)

        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(layer_id + 1000)
        else:
            self.att = RWKV_TimeMix(layer_id)

        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class RWKV_GPT(nn.Module):
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
        global RWKV_CFG
        super().__init__()

        RWKV_CFG.RUN_DEVICE = RUN_DEVICE
        RWKV_CFG.model_type = model_type
        RWKV_CFG.vocab_size = vocab_size
        RWKV_CFG.n_layer = n_layer
        RWKV_CFG.n_embd = n_embd
        RWKV_CFG.ctx_len = ctx_len

        print('\nloading RWKV-GPT', MODEL_NAME)

        self.emb = nn.Embedding(vocab_size, n_embd)

        self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])

        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(ctx_len, ctx_len)))

        self.ctx_len = ctx_len
        self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
        self.eval()

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if '32' in os.environ['RWKV_FLOAT_MODE']:
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x

############################################################################################################

class RWKV_RNN():  # this is running in FP32 at this moment
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        self.RUN_DEVICE = RUN_DEVICE
        self.model_type = model_type
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        self.w = types.SimpleNamespace()

        w = torch.load(MODEL_NAME + '.pth',
                       map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            w[x] = w[x].float()
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                w[x] = -torch.exp(w[x])
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())

            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        self.xx = {}
        self.aa = {}
        self.bb = {}
        self.pp = {}
        self.hk = None

    def save(self, target):
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.pp = copy.deepcopy(self.pp)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.pp = copy.deepcopy(target.pp)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30

        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)

        k = w.key.weight @ xk
        v = w.value.weight @ xv

        pp = self.pp[name]
        aa = self.aa[name]
        bb = self.bb[name]
        ww = w.time_first + k
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        ww = pp + w.time_decay
        p = torch.maximum(ww, k)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k - p)
        self.aa[name] = e1 * aa + e2 * v
        self.bb[name] = e1 * bb + e2
        self.pp[name] = p

        rwkv = r * a / b

        return w.output.weight @ rwkv
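
    # A restatement (not original code) of what SA() computes. With u = time_first
    # and per-channel decay w = time_decay (already -exp()'d at load time), and
    # unscaled state (A, B), the WKV recurrence is:
    #
    #   wkv_t = (A_{t-1} + exp(u + k_t) * v_t) / (B_{t-1} + exp(u + k_t))
    #   A_t   = exp(w) * A_{t-1} + exp(k_t) * v_t
    #   B_t   = exp(w) * B_{t-1} + exp(k_t)
    #
    # The code stores (aa, bb) = (A, B) / exp(pp) and shifts every exponent by
    # p = max(...) before exp(), so all exponents stay <= 0 and fp32 never overflows.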
    def run(self, ctx):
        w = self.w
        x = w.emb.weight[ctx[-1]]

        for i in range(self.n_layer):
            if i == 0:
                x = self.LN(x, w.blocks[i].ln0)
            if i == 0 and self.model_type == 'RWKV-ffnPre':
                x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
            x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        if RWKV_HEAD_QK_DIM > 0:
            if self.hk is None:
                self.hk = (w.head_k.weight @ x).unsqueeze(0)
            else:
                self.hk = torch.cat(
                    [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
            if self.hk.shape[0] > self.ctx_len:
                self.hk = self.hk[-self.ctx_len:, :]

            q = w.head_q.weight @ x

            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

            c = (self.hk @ q) / RWKV_HEAD_QK_DIM
            for i in range(len(c)):
                x[ctx[i]] += c[i]
        else:
            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

        return x
@@ -1,187 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os
NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
USE_WANDB = (int(os.environ['USE_WANDB']) == 1)

from torch.utils.data.dataloader import DataLoader
import torch
from tqdm.auto import tqdm
import logging
import datetime
import math
from pytorch_lightning.lite import LightningLite
import gc

logger = logging.getLogger(__name__)

torch.backends.cudnn.benchmark = True
if os.environ['RWKV_FLOAT_MODE'] == 'fp32':
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False
else:
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

class TrainerConfig:
    batch_size = 64
    learning_rate = 4e-4
    betas = (0.9, 0.99)
    eps = 1e-8
    grad_norm_clip = 1.0
    warmup_tokens = 0
    final_tokens = 0
    epoch_save_frequency = 0
    epoch_save_path = 'trained-'
    num_workers = 0  # for DataLoader

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

from src.model import GPT, GPTConfig

class Trainer(LightningLite):

    def get_run_name(self):
        raw_model = self.model.module if hasattr(
            self.model, "module") else self.model
        cfg = raw_model.config
        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \
            cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
        return run_name

    def run(self, m_cfg, train_dataset, test_dataset, config):
        self.cuda_id = int(str(self.device).strip('cuda:'))
        print('[0]')
        model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=m_cfg.model_type,
                              n_layer=m_cfg.n_layer, n_embd=m_cfg.n_embd))
        print('[1]')
        with torch.no_grad():
            if m_cfg.LOAD_MODEL:
                print('loading', m_cfg.MODEL_NAME)
                m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location='cpu')
                model.load_state_dict(m2)
                del m2
        model.to(self.device)

        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.avg_loss = -1
        self.EPOCH_BEGIN = m_cfg.EPOCH_BEGIN

        self.steps = self.EPOCH_BEGIN * (len(self.train_dataset) // (config.batch_size // NUM_GPUS))

        if self.cuda_id == 0:
            log_file = open("mylog.txt", "a")
            if USE_WANDB:
                print('logging to wandb... (comment it if you don\'t have wandb)')
                import wandb  # comment this if you don't have wandb
                cfg = model.config
                for k in config.__dict__:
                    setattr(cfg, k, config.__dict__[k])  # combine cfg
                wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False)

        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)
        model, optimizer = self.setup(model, optimizer)
        print('[3]')

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            data.idx_begin = self.steps * config.batch_size + 1
            data.cuda_id = self.cuda_id

            if config.num_workers > 0:
                loader = DataLoader(data, shuffle=False, pin_memory=True,
                                    batch_size=config.batch_size // NUM_GPUS,
                                    num_workers=config.num_workers)
            else:
                loader = DataLoader(data, shuffle=False,
                                    batch_size=config.batch_size // NUM_GPUS,
                                    num_workers=config.num_workers)

            pbar = tqdm(enumerate(loader), total=len(
                loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)
            loader = self.setup_dataloaders(loader)
            gc.collect()
            torch.cuda.empty_cache()

            for it, (x, y) in pbar:
                with torch.set_grad_enabled(is_train):
                    loss = model(x, y)  # forward the model

                if os.environ['RWKV_DEEPSPEED'] == '0':
                    all_loss = [loss.clone()]
                else:
                    all_loss = [loss.clone() for _ in range(NUM_GPUS)]
                    torch.distributed.all_gather(all_loss, loss)

                if is_train:  # backprop and update the parameters
                    model.zero_grad()
                    self.backward(loss)

                    # deepspeed will handle gradient_clipping

                    optimizer.step()

                    # decay the learning rate based on our progress
                    self.tokens += (y >= 0).sum()  # number of tokens processed this step (i.e. label is not -100)
                    lr_final_factor = config.lr_final / config.learning_rate
                    if self.tokens < config.warmup_tokens:
                        # linear warmup
                        lr_mult = lr_final_factor + \
                            (1 - lr_final_factor) * float(self.tokens) / \
                            float(config.warmup_tokens)
                        progress = 0
                    else:
                        # exponential learning rate decay
                        progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                        if progress >= 1:
                            lr_mult = lr_final_factor
                        else:
                            lr_mult = math.exp(math.log(lr_final_factor) * pow(progress, 1))
                    lr = config.learning_rate * lr_mult

                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr

                    self.lr = lr
                    self.steps += 1

                    now_loss = 0
                    for gg in range(NUM_GPUS):
                        now_loss += all_loss[gg].item()
                    now_loss = now_loss / NUM_GPUS  # report progress
                    if USE_WANDB and self.cuda_id == 0:
                        wandb.log({"loss": now_loss}, step=self.steps)

                    if self.avg_loss < 0:
                        self.avg_loss = now_loss
                    else:
                        factor = 1 / (it + 1)
                        self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor

                    pbar.set_description(f"miniE {epoch+1+self.EPOCH_BEGIN} s {self.steps} prog {progress*100.0:.2f}% : ppl {math.exp(self.avg_loss):.6f} loss {self.avg_loss:.6f} lr {lr:e}")

        self.tokens = 0  # counter used for learning rate decay
        for epoch in range(99999999):

            run_epoch('train')
            if math.isnan(self.avg_loss):
                exit(0)

            if self.cuda_id == 0:
                log_file.write(f'{epoch+1+self.EPOCH_BEGIN} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} {epoch+1} \n')
                log_file.flush()

                if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
                    raw_model = self.model.module if hasattr(self.model, "module") else self.model
                    torch.save(raw_model.state_dict(), self.config.epoch_save_path + str(epoch+1+self.EPOCH_BEGIN) + '.pth')
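
# A standalone sketch (not part of the original trainer) of the LR schedule
# implemented inside run_epoch() above: linear warmup from lr_final up to
# learning_rate, then exponential decay back towards lr_final.
def lr_schedule_sketch(tokens, learning_rate, lr_final, warmup_tokens, final_tokens):
    lr_final_factor = lr_final / learning_rate
    if tokens < warmup_tokens:
        lr_mult = lr_final_factor + (1 - lr_final_factor) * tokens / warmup_tokens
    else:
        progress = min(1.0, (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens))
        lr_mult = math.exp(math.log(lr_final_factor) * progress)
    return learning_rate * lr_mult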
@@ -1,153 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os
try:
    NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
except:
    NUM_GPUS = 1

import json
import random
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset

class Dataset(Dataset):
    def __init__(self, data, ctx_len, epoch_length_fixed):
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.data = data

        if 'MMapIndexedDataset' in str(type(self.data)):
            self.vocab_size = int(os.environ['VOCAB_SIZE'])
            print('current vocab size =', self.vocab_size, "(make sure it's correct)")
            self.data_size = len(self.data._bin_buffer) // 2
            print(f'data has {self.data_size} tokens.')
        elif 'numpy' in str(type(self.data)):
            self.vocab_size = int(os.environ['VOCAB_SIZE'])
            print('current vocab size =', self.vocab_size, "(make sure it's correct)")
            self.data_size = len(self.data)
            print(f'data has {self.data_size} tokens.')
        else:
            print('building token list...', end=' ')
            unique = sorted(list(set(data)))
            self.vocab_size = len(unique)
            # print()
            # for u in unique:
            #     print(u, end=' ')
            # print('\n\n')

            xx = 0
            xxObj = {}
            for u in unique:
                xxObj[xx] = u
                xx += 1
            with open('vocab.json', "w", encoding="utf-16") as vocab_file:
                vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
            self.data_size = len(self.data)
            print('data has %d tokens, %d unique.' % (self.data_size, self.vocab_size))
            self.stoi = {ch: i for i, ch in enumerate(unique)}
            self.itos = {i: ch for i, ch in enumerate(unique)}

    def __len__(self):
        return self.epoch_length_fixed // NUM_GPUS

    def __getitem__(self, idx):
        #
        # we are cheating: pick a random spot in dataset
        #
        i = np.random.randint(0, self.data_size - (self.ctx_len + 1))
        if 'MMapIndexedDataset' in str(type(self.data)):
            dix = self.data.get(idx=0, offset=i, length=self.ctx_len + 1).astype(int)
        elif 'numpy' in str(type(self.data)):
            dix = self.data[i:i+self.ctx_len+1]
        else:
            dix = [self.stoi[s] for s in self.data[i:i+self.ctx_len+1]]

        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y


class TOKENIZER():
    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        if 'list' in str(type(WORD_NAME)):
            self.charMode = False
            if WORD_NAME[0] == WORD_NAME[1]:
                from transformers import PreTrainedTokenizerFast
                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
            else:
                from transformers import GPT2TokenizerFast
                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
            self.vocab_size = len(self.tokenizer)
        else:
            self.charMode = True
            with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)

            self.vocab_size = len(self.word_table)

            self.stoi = {v: int(k) for k, v in self.word_table.items()}
            self.itos = {int(k): v for k, v in self.word_table.items()}

            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n'
        return context

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
        # out[self.UNKNOWN_CHAR] = -float('Inf')

        lastChar = int(x[-1])

        probs = F.softmax(torch.tensor(out), dim=-1)

        if self.charMode:
            if self.itos[lastChar] == '\n':
                top_p = top_p_newline
            else:
                top_p = top_p_usual
        else:
            top_p = top_p_usual

        sorted_probs, s_index = torch.sort(probs, descending=True)

        # for j in range(30):
        #     pp = sorted_probs[j].item()
        #     if pp < 0.005:
        #         break
        #     ss = self.itos[int(s_index[j])].replace('\n', '_')
        #     print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
        # print('')

        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])

        probs[probs < cutoff] = 0
        # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")

        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)

        return torch.multinomial(probs, num_samples=1)[0]
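
    # A hedged usage sketch (variable names are illustrative, not from the
    # original file): `out` is the logits list returned by RWKV_RNN.run() and
    # `ctx` the token ids generated so far.
    #
    #   token = tokenizer.sample_logits(out, ctx, ctx_len, temperature=1.0,
    #                                   top_p_usual=0.8, top_p_newline=0.9)
    #   ctx += [int(token)]
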
def to_float(x):
    return x.cpu().detach().numpy().flatten()[0].astype(float)


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
@@ -1,280 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os
import logging, types
from src.utils import Dataset
import torch
import numpy as np
from src.binidx import MMapIndexedDataset

np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)

# if False:  # True False ---> Set to False if you don't understand it
#     print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
#     import src.utils
#     src.utils.set_seed(42)  # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples)

########################################################################################################
# Step 1: set training data & cfg
########################################################################################################

EXPRESS_PILE_MODE = False  # True: express mode for fine-tuning a pile model // False: usual training

EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023'
EXPRESS_PILE_MODEL_TYPE = 'RWKV-4-Pile-169M'
# EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-430M-20220808-8066'
# EXPRESS_PILE_MODEL_TYPE = 'RWKV-4-Pile-430M'
# EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-1B5-20220903-8040'
# EXPRESS_PILE_MODEL_TYPE = 'RWKV-4-Pile-1B5'

########################################################################################################

datafile = "../data/enwik8"  # your data
datafile_encoding = 'utf-8'  # 'utf-8' / 'utf-16le' / 'numpy' (for fine-tuning pile models) / 'binidx' (the Megatron-LM 'binidx' format)

# datafile = 'my-gpt_seq_document'
# datafile_encoding = 'binidx'

if EXPRESS_PILE_MODE:
    datafile = 'train.npy'  # use 'prepare-data.py' in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/tree/main/RWKV-v3 to tokenize .txt into .npy
    datafile_encoding = 'numpy'

#
# set VOCAB_SIZE = 0 (auto-compute) if you are training a char-level LM from scratch
# set VOCAB_SIZE = 50277 for fine-tuning pile models
# set VOCAB_SIZE = your_vocab_size for 'binidx' data
#
os.environ['VOCAB_SIZE'] = '0'
if EXPRESS_PILE_MODE:
    os.environ['VOCAB_SIZE'] = '50277'

#
# Currently it is slow to initialize a new model, so I suggest this procedure for multi-GPU training:
# 1) set RWKV_NUM_GPUS = '1', let it run for 1 miniEpoch, and it will save a trained-1.pth
# 2) set RWKV_NUM_GPUS = '8' (or your #GPU), batch_size = single_gpu_batchsz * RWKV_NUM_GPUS,
#    EPOCH_BEGIN = 1, LOAD_MODEL = True, and it will load 'trained-1.pth' and continue the training from it
#
os.environ['RWKV_NUM_GPUS'] = '1'  # num of GPUs to use

#
# 'bf16' (fast & stable)
# 'fp16' (fast & will overflow after training a large model for very long. can be solved in the future)
# 'tf32' (decent speed & stable)
# 'fp32' (!!!very slow!!! only for verification)
os.environ['RWKV_FLOAT_MODE'] = 'bf16'

os.environ['RWKV_DEEPSPEED'] = '1'  # Use DeepSpeed? 0 = False, 1 = True

if int(os.environ['RWKV_NUM_GPUS']) == 1:  # Usually you don't need DeepSpeed for 1 GPU training.
    os.environ['RWKV_DEEPSPEED'] = '0'  # However, DeepSpeed sometimes saves VRAM even on 1 GPU, so it is worth trying.

os.environ['USE_WANDB'] = '0'  # wandb logging. 0 = False, 1 = True

########################################################################################################
# Step 2: set model details
########################################################################################################

EPOCH_BEGIN = 0  # begins with miniEpoch = EPOCH_BEGIN
LOAD_MODEL = False  # shall we load the #EPOCH_BEGIN model and continue the training from it?

n_layer = 6
n_embd = 512
ctx_len = 1024  # increase T_MAX in src/model.py if your ctx_len is longer

model_type = 'RWKV'  # 'RWKV' or 'RWKV-ffnPre' (sometimes better)

# There is also a RWKV_HEAD_QK_DIM in model.py and model_run.py:
# set it to 256 to use my headQK trick (a tiny attention) to improve loss,
# or set it to 0 for a pure RNN (attention-free).

if EXPRESS_PILE_MODE:
    LOAD_MODEL = True
    if EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-169M':
        n_layer = 12
        n_embd = 768
        ctx_len = 1024
    elif EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-430M':
        n_layer = 24
        n_embd = 1024
        ctx_len = 1024
    elif EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-1B5':
        n_layer = 24
        n_embd = 2048
        ctx_len = 1024

########################################################################################################
# Step 3: set batch size & learning rate etc.
########################################################################################################

# If you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU.
batch_size = 12 * int(os.environ['RWKV_NUM_GPUS'])
assert (batch_size % int(os.environ['RWKV_NUM_GPUS']) == 0)

# By default we are using exponential LR decay.
# Here are my suggestions for training.
# Let's say you are training a L6-D512 model.
# 1) Set lr_init = lr_final = 8e-4. Let it run for some mini-epochs, until you feel like reducing LR.
# 2) Check epoch_save_frequency and make sure the partially-trained model is saved. Ctrl+C to stop the run.
# 3) Set lr_init = 8e-4, lr_final = 1e-5, betas = (0.9, 0.999).
# 4) Set EPOCH_BEGIN & LOAD_MODEL to load the partially-trained model. Continue the training.
#
# For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4.

lr_init = 8e-4
lr_final = 1e-5

# the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens)
n_epoch = 500
epoch_length_fixed = (10000 // batch_size) * batch_size  # feel free to increase it if you have lots of GPU
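
# Worked example (an illustration, not original code): with the defaults above on
# 1 GPU, batch_size = 12 and epoch_length_fixed = (10000 // 12) * 12 = 9996, so
# one mini-epoch covers ctx_len * epoch_length_fixed = 1024 * 9996 = 10,235,904
# tokens (about 10.2M).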

# epoch_save_frequency 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ...
epoch_save_frequency = 10
epoch_save_path = 'trained-'

if EXPRESS_PILE_MODE:
    if EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-169M':
        lr_init = 2e-5
    else:
        lr_init = 1e-5
    lr_final = 1e-5
    n_epoch = 100000

### misc stuffs ########################################################################################

NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])

if LOAD_MODEL and EPOCH_BEGIN > 0:  # we are not saving gradients, so let's have some warmup if we load a model
    warmup_tokens = 50 * ctx_len * batch_size // NUM_GPUS
else:
    warmup_tokens = 0

betas = (0.9, 0.99)  # set betas = (0.9, 0.999) if your model has been trained for a while
eps = 1e-8

num_workers = 1  # DataLoader workers. I have only tested num_workers = 1.

os.environ['RWKV_LOAD_MODEL'] = str(LOAD_MODEL)
MODEL_NAME = epoch_save_path + str(EPOCH_BEGIN)

if EXPRESS_PILE_MODE:
    betas = (0.9, 0.999)
    MODEL_NAME = EXPRESS_PILE_MODEL_NAME

torch.backends.cudnn.benchmark = True
if os.environ['RWKV_FLOAT_MODE'] == 'fp32':
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False
else:
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

########################################################################################################
# Load data
########################################################################################################

print(f'loading {datafile_encoding} data... ' + datafile)
if datafile_encoding == 'binidx':
    train_dataset = Dataset(MMapIndexedDataset(datafile), ctx_len, epoch_length_fixed)
elif datafile_encoding == 'numpy':
    train_dataset = Dataset(np.load(datafile).astype('int'), ctx_len, epoch_length_fixed)
else:
    train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)

########################################################################################################
# Train model
########################################################################################################

if __name__ == '__main__':
    from src.trainer import Trainer, TrainerConfig

    print('\nmodel', model_type, os.environ['RWKV_FLOAT_MODE'], 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
          betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, '\n')

    tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
                          learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
                          warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
    m_cfg = types.SimpleNamespace()
    m_cfg.model_type = model_type
    m_cfg.n_layer = n_layer
    m_cfg.n_embd = n_embd
    m_cfg.EPOCH_BEGIN = EPOCH_BEGIN
    m_cfg.LOAD_MODEL = LOAD_MODEL
    m_cfg.MODEL_NAME = MODEL_NAME

    if os.environ['RWKV_DEEPSPEED'] == '0':
        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16)
        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16')
        elif '32' in os.environ['RWKV_FLOAT_MODE']:
            trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32)
    else:
        from pytorch_lightning.strategies import DeepSpeedStrategy

        DEEPSPEED_CFG = {
            "zero_allow_untested_optimizer": True,
            "zero_optimization": {
                "stage": 2,
                "contiguous_gradients": True,
                "overlap_comm": True,
                "allgather_partitions": True,
                "reduce_scatter": True,
                "allgather_bucket_size": 200000000,
                "reduce_bucket_size": 200000000,
                "sub_group_size": 1000000000000
            },
            "activation_checkpointing": {
                "partition_activations": False,
                "cpu_checkpointing": False,
                "contiguous_memory_optimization": False,
                "synchronize_checkpoint_boundary": False
            },
            "aio": {
                "block_size": 1048576,
                "queue_depth": 8,
                "single_submit": False,
                "overlap_events": True,
                "thread_count": 1
            },
            "gradient_clipping": 1.0,
            "gradient_accumulation_steps": 1,
        }
        if NUM_GPUS == 1:
            DEEPSPEED_CFG['zero_optimization'] = {
                "stage": 1,  # saves some VRAM
                "contiguous_gradients": False,
                "overlap_comm": False,
                "allgather_partitions": False,
                "reduce_scatter": False,
                "allgather_bucket_size": 200000000,
                "reduce_bucket_size": 200000000,
                "sub_group_size": 1000000000000
            }

        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            DEEPSPEED_CFG["fp16"] = {
                "fp16": True,
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 12,
                "loss_scale_window": 1000,
                "hysteresis": 2,
                "min_loss_scale": 1
            }
            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16)

        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            DEEPSPEED_CFG["bf16"] = {
                "enabled": True
            }
            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')

        elif '32' in os.environ['RWKV_FLOAT_MODE']:
            trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32)

        print(trainer._strategy.config)

    trainer.run(m_cfg, train_dataset, None, tconf)
@@ -1,90 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

# this is for verifying the results of different models and making sure they agree with each other

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['RWKV_FLOAT_MODE'] = 'bf16'  # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
os.environ['RWKV_RUN_DEVICE'] = 'cuda'
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']

import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig

TOKEN_MODE = 'pile'  # char / pile

if TOKEN_MODE == 'char':
    MODEL_NAME = 'trained-1'
    WORD_NAME = 'vocab'  # the .json vocab (generated by train.py)
    ctx_len = 1024
    n_layer = 6
    n_embd = 512
    UNKNOWN_CHAR = ' '  # here we just set it to [space] for simplicity
elif TOKEN_MODE == 'pile':
    WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
    MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023'
    ctx_len = 1024
    n_layer = 12
    n_embd = 768
    UNKNOWN_CHAR = None

model_type = 'RWKV'

from src.utils import TOKENIZER
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == 'pile':
    tokenizer.vocab_size = 50277

########################################################################################################

model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda()

if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
    model_train = model_train.bfloat16()

print('loading ' + MODEL_NAME)
m2 = torch.load(MODEL_NAME + '.pth', map_location=RUN_DEVICE)
model_train.load_state_dict(m2)

model_rnn = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
model_gpt = RWKV_GPT(MODEL_NAME, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda()

########################################################################################################

# context = '\nIn a'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'

if TOKEN_MODE == 'char':
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
elif TOKEN_MODE == 'pile':
    ctx = tokenizer.tokenizer.encode(context)
print(f'input len {len(ctx)} data {ctx}')

########################################################################################################

print('\nRWKV-GPT output')
out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy()
print(out)

print('\nRWKV-RNN output')
model_rnn.clear()
src_len = len(ctx)
for i in range(src_len):
    x = ctx[:i+1]
    out = model_rnn.run(x)
    if i < 3 or i >= src_len - 3:
        print(torch.tensor(out).detach().cpu().numpy())
    if i == 2:
        print('...')

print('\nRWKV-train output')
out = model_train.forward(torch.tensor([ctx]).cuda())[0][0].detach().cpu().float().numpy()
print(out, '\n')
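
# A sketch (assumed variable names, not in the original script) of how the three
# outputs above could be compared numerically; each model's final-token logits
# would need to be kept around, e.g. as `out_gpt`, `out_rnn`, `out_train`:
#
#   print('max |GPT - RNN|  :', np.max(np.abs(out_gpt[-1] - np.array(out_rnn))))
#   print('max |GPT - train|:', np.max(np.abs(out_gpt - out_train)))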
@@ -1,133 +0,0 @@
#include <stdio.h>
#include <assert.h>

#define MIN_VALUE (-1e38)

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    F aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                                const F *__restrict__ const _y, const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const y = _y + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    F q[Tmax], r[Tmax];
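    // Note (an assumption, not in the original source): Tmax is not defined in
    // this file; the build is expected to supply it as a compile-time constant
    // (e.g. nvcc -DTmax=<ctx_len>) so these per-thread scratch arrays have a
    // fixed size known at compile time.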
    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        const F qq = gy[ii] / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;
        r[i] = ww - p;

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = gw * _w[_c];  // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = gu;

    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];
        const F qq = q[i];
        const F rr = r[i];

        F e1 = qq * exp(rr);
        F e2 = exp(kk + pp);
        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
        gv[ii] = e1 + e2 * aa;

        const F ww = w + pp;
        const F www = rr - u - kk;
        const F p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    dim3 threadsPerBlock( min(C, 32) );  // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
    dim3 threadsPerBlock( min(C, 32) );  // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
@@ -1,132 +0,0 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#define MIN_VALUE (-1e38)
typedef at::BFloat16 bf16;

__global__ void kernel_forward(const int B, const int T, const int C,
                               const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
                               bf16 *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    bf16 *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    float aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);

        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}

__global__ void kernel_backward(const int B, const int T, const int C,
                                const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
                                const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
                                bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    const bf16 *__restrict__ const y = _y + _offset;
    const bf16 *__restrict__ const gy = _gy + _offset;
    bf16 *__restrict__ const gk = _gk + _offset;
    bf16 *__restrict__ const gv = _gv + _offset;

    float q[Tmax], r[Tmax];

    float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        const float yy = float(y[ii]);

        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        const float qq = float(gy[ii]) / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;
        r[i] = ww - p;

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = bf16(gw * _w[_c]);  // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = bf16(gu);

    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        const float yy = float(y[ii]);
        const float qq = q[i];
        const float rr = r[i];

        float e1 = qq * exp(rr);
        float e2 = exp(kk + pp);
        gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
        gv[ii] = bf16(e1 + e2 * aa);

        const float ww = w + pp;
        const float www = rr - u - kk;
        const float p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}

void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
    dim3 threadsPerBlock( min(C, 32) );  // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
    dim3 threadsPerBlock( min(C, 32) );  // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
@@ -1,21 +0,0 @@
#include <torch/extension.h>

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("backward", &backward, "wkv backward");
}

TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
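
// An assumed usage note (not part of the original file): this op is typically
// compiled and loaded at runtime via PyTorch's JIT extension loader, e.g. from
// Python:
//
//   from torch.utils.cpp_extension import load
//   wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
//                   extra_cuda_cflags=["-O3", "-DTmax=1024"])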
@@ -1,25 +0,0 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
              torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
                  gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("backward", &backward, "wkv backward");
}

TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
@@ -1,165 +0,0 @@
########################################################################################################
|
|
||||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
|
||||||
########################################################################################################
|
|
||||||
|
|
||||||
import torch, types, os
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
import torchvision as vision
|
|
||||||
import torchvision.transforms as transforms
|
|
||||||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
|
||||||
print(f'loading...')
|
|
||||||
|
|
||||||
########################################################################################################
|
|
||||||
|
|
||||||
model_prefix = 'test/image_trained/out-v7c_d8_256-224-13bit-OB32x0.5-201'
|
|
||||||
input_img = 'test/img_ae_test/test0.png'
|
|
||||||
|
|
||||||
########################################################################################################
|
|
||||||
|
|
||||||
class ToBinary(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, x):
|
|
||||||
return torch.floor(x + 0.5) # no need for noise when we have plenty of data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
return grad_output.clone() # pass-through
|
|
||||||
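ToBinary is a straight-through estimator: the forward pass hard-quantizes to {0, 1}, while the backward pass pretends the op was the identity. A minimal standalone sketch of that behavior (illustrative values):

import torch

x = torch.tensor([0.2, 0.49, 0.51, 0.9], requires_grad=True)
b = torch.floor(x + 0.5)          # forward: hard 0/1 values, no useful gradient
# ToBinary.apply(x) gives the same forward result but a pass-through gradient;
# the classic straight-through trick below has the same effect:
y = x + (b - x).detach()
y.sum().backward()
print(b)        # tensor([0., 0., 1., 1.])
print(x.grad)   # tensor([1., 1., 1., 1.]) -- gradient passed through unchanged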
class R_ENCODER(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        dd = 8
        self.Bxx = nn.BatchNorm2d(dd*64)

        self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)

        self.B00 = nn.BatchNorm2d(dd*4)
        self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)

        self.B10 = nn.BatchNorm2d(dd*16)
        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)

        self.B20 = nn.BatchNorm2d(dd*64)
        self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)

        self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)

    def forward(self, img):
        ACT = F.mish

        x = self.CIN(img)
        xx = self.Bxx(F.pixel_unshuffle(x, 8))
        x = x + self.Cx1(ACT(self.Cx0(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
        x = x + self.C03(ACT(self.C02(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
        x = x + self.C13(ACT(self.C12(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
        x = x + self.C23(ACT(self.C22(x)))

        x = self.COUT(x + xx)
        return torch.sigmoid(x)


class R_DECODER(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        dd = 8
        self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)

        self.B00 = nn.BatchNorm2d(dd*64)
        self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)

        self.B10 = nn.BatchNorm2d(dd*16)
        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)

        self.B20 = nn.BatchNorm2d(dd*4)
        self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)

        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
        self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)

    def forward(self, code):
        ACT = F.mish
        x = self.CIN(code)

        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
        x = x + self.C03(ACT(self.C02(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
        x = x + self.C13(ACT(self.C12(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
        x = x + self.C23(ACT(self.C22(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.Cx1(ACT(self.Cx0(x)))
        x = self.COUT(x)

        return torch.sigmoid(x)
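The channel widths above follow directly from pixel_unshuffle: each factor-2 unshuffle quarters the spatial area and multiplies channels by 4, so three stages take dd channels to dd*64, matching the Bxx shortcut that unshuffles by 8 in one step. A quick shape check (illustrative sizes):

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 224, 224)                              # dd = 8 channels
print(F.pixel_unshuffle(x, 2).shape)                         # (1,  32, 112, 112) -> dd*4
print(F.pixel_unshuffle(F.pixel_unshuffle(x, 2), 2).shape)   # (1, 128,  56,  56) -> dd*16
print(F.pixel_unshuffle(x, 8).shape)                         # (1, 512,  28,  28) -> dd*64, the xx shortcut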
########################################################################################################

print(f'building model...')
args = types.SimpleNamespace()
args.my_img_bit = 13
encoder = R_ENCODER(args).eval().cuda()
decoder = R_DECODER(args).eval().cuda()

zpow = torch.tensor([2**i for i in range(0,13)]).reshape(13,1,1).cuda().long()

encoder.load_state_dict(torch.load(f'{model_prefix}-E.pth'))
decoder.load_state_dict(torch.load(f'{model_prefix}-D.pth'))

########################################################################################################

print(f'test image...')
img_transform = transforms.Compose([
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Resize((224, 224))
])

with torch.no_grad():
    img = img_transform(Image.open(input_img)).unsqueeze(0).cuda()
    z = encoder(img)
    z = ToBinary.apply(z)

    zz = torch.sum(z.squeeze().long() * zpow, dim=0)
    print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n')

    out = decoder(z)
    vision.utils.save_image(out, f"{input_img.split('.')[0]}-out-13bit.jpg")
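zz packs the 13 binary planes of each spatial position into one integer via powers of two; the planes can be recovered by masking. A sketch of the inverse (unpack_code is a hypothetical helper, using the same bit convention as zpow above):

import torch

def unpack_code(zz, bits=13):
    # Inverse of: zz = torch.sum(z.squeeze().long() * zpow, dim=0)
    planes = [(zz >> i) & 1 for i in range(bits)]   # bit i corresponds to 2**i in zpow
    return torch.stack(planes, dim=0).float()       # shape (bits, H/8, W/8)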
@@ -1,269 +0,0 @@
from lib2to3.pgen2 import token
import os
import torch
import numpy as np
import shutil
import struct
from functools import lru_cache
from itertools import accumulate

def print_rank_0(*message):
    pass
    # """If distributed is initialized, print only on rank 0."""
    # if torch.distributed.is_initialized():
    #     if torch.distributed.get_rank() == 0:
    #         print(*message, flush=True)
    # else:
    #     print(*message, flush=True)

def _warmup_mmap_file(path):
    pass
    # with open(path, "rb") as stream:
    #     while stream.read(100 * 1024 * 1024):
    #         pass

dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: float,
    7: np.double,
    8: np.uint16,
}

def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)

def index_file_path(prefix_path):
    return prefix_path + ".idx"

def data_file_path(prefix_path):
    return prefix_path + ".bin"
class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, "wb")

                    # Write magic string so we can check the file format when opening it again.
                    self._file.write(cls._HDR_MAGIC)
                    # Write version number
                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", 1))
                    # Little endian unsigned 8 Bit integer
                    self._file.write(struct.pack("<B", code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)

                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", len(sizes)))
                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order="C"))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                # Little endian unsigned 64 Bit integer
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                # Little endian unsigned 8 Bit integer
                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print_rank_0(" warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print_rank_0(" reading sizes...")
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            print_rank_0(" reading pointers...")
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )
            print_rank_0(" reading document index...")
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len
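The Index layout written above is: a 9-byte magic string, a <Q version, a <B dtype code, a <Q entry count and a <Q document count, followed by the sizes (int32), pointers (int64) and doc_idx (int64) arrays. A minimal standalone reader of just the header, following that write order (read_idx_header is an illustrative helper, not part of the repo):

import struct

def read_idx_header(path):
    # Follows the write order in Index.writer above.
    with open(path, "rb") as f:
        assert f.read(9) == b"MMIDIDX\x00\x00"          # magic
        (version,) = struct.unpack("<Q", f.read(8))      # always 1 here
        (dtype_code,) = struct.unpack("<B", f.read(1))   # key into the dtypes table
        (length,) = struct.unpack("<Q", f.read(8))       # number of sequences
        (doc_count,) = struct.unpack("<Q", f.read(8))    # number of documents
    return version, dtype_code, length, doc_count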
    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        # _do_init requires skip_warmup; skip it when restoring from a pickle
        self._do_init(state, skip_warmup=True)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print_rank_0(" warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print_rank_0(" creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        print_rank_0(" creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
            )
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError(
                    "Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
            )
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
        )
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(
            data_file_path(path)
        )
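Usage mirrors what dataset.py below does: open a .bin/.idx pair by prefix and slice tokens out of it. A short sketch (the prefix path is illustrative):

# Illustrative usage of MMapIndexedDataset; the prefix names a .bin/.idx pair.
ds = MMapIndexedDataset("data/my_corpus_text_document")
print(len(ds))                                   # number of sequences in the index
chunk = ds.get(idx=0, offset=100, length=512)    # 512 tokens starting at position 100
print(chunk.dtype, chunk.shape)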
@@ -1,240 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import json, math, random, os, sys
import numpy as np
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime


class MyDataset(Dataset):
    def __init__(self, args):
        self.args = args

        if args.data_type == "binidx":
            self.vocab_size = args.vocab_size
            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")

            if args.my_pile_version == 1:
                self.data = MMapIndexedDataset(args.data_file)
                self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
                rank_zero_info(f"Data has {self.data_size} tokens.")
            else:
                data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
                data_list = [i.strip().split(' ') for i in data_list]
                self.data = []
                self.data_size = int(data_list[-1][-1])
                rank_zero_info(f"Data has {self.data_size} chunks.")
                for d in data_list:
                    data = MMapIndexedDataset(d[0])
                    data_size = len(data._bin_buffer) // data._index._dtype_size
                    assert (data_size - args.ctx_len) == int(d[1])
                    self.data += [[int(d[-1]), int(d[1]), data]]
                # rank_zero_info(self.data)

            if args.my_qa_mask > 0:
                self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
                # self.data_pile = MMapIndexedDataset('/fsx/pile_deduped/pile_0.87_deduped_text_document')
                self.data_pile_size = len(self.data_pile._bin_buffer) // self.data_pile._index._dtype_size  # measure with the pile's own dtype size

            if args.my_pile_stage > 0:
                # assert self.data_size == 332115325534 and self.vocab_size == 50277
                self.samples_per_epoch = args.epoch_steps * args.real_bsz
                assert self.samples_per_epoch == 40320
                rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
                dataset_slot = self.data_size // args.ctx_len
                if args.my_pile_stage != 4:
                    assert MaybeIsPrime(args.magic_prime)
                    assert args.magic_prime % 3 == 2
                    assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
        elif args.data_type == "numpy":
            self.data = np.load(args.data_file).astype("int")
            self.vocab_size = args.vocab_size
            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
            self.data_size = len(self.data)
            rank_zero_info(f"Data has {self.data_size} tokens.")
        elif args.data_type == "uint16":
            self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
            self.vocab_size = args.vocab_size
            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
            self.data_size = self.data.shape[0]
            rank_zero_info(f"Data has {self.data_size} samples.")
        elif args.data_type == "wds_img":
            self.vocab_size = -1
            self.data_size = -1
            self.data = None
            self.error_count = 0
        else:
            if args.data_type == "dummy":
                rank_zero_info("Building dummy data...")
                self.data = ""
                for i in range(100000):
                    aa = (i) % 10000
                    bb = (i * i) % 10000
                    cc = aa + bb
                    self.data += f".{aa}+{bb}={cc}."
            else:
                self.data = open(args.data_file, "r", encoding=args.data_type).read()
            rank_zero_info("Building token list...")
            unique = sorted(list(set(self.data)))
            self.vocab_size = len(unique)
            # rank_zero_info()
            # for u in unique:
            #     print(u, end=' ')
            # rank_zero_info('\n\n')
            xx = 0
            xxObj = {}
            for u in unique:
                xxObj[xx] = u
                xx += 1
            with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
                vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
            self.data_size = len(self.data)
            rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
            self.stoi = {ch: i for i, ch in enumerate(unique)}
            self.itos = {i: ch for i, ch in enumerate(unique)}
    def __len__(self):
        return self.args.epoch_steps * self.args.micro_bsz

    def __getitem__(self, idx):
        args = self.args
        rank = self.global_rank
        epoch = self.real_epoch
        world_size = self.world_size
        # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")

        if args.data_type == "wds_img":
            def init_wds(self, bias=0):
                def identity(x):
                    return x
                import webdataset as wds
                import torchvision.transforms as transforms
                # img_transform = transforms.Compose(
                #     [transforms.CenterCrop(256)]
                # )
                img_transform = transforms.Compose([
                    transforms.CenterCrop(512),
                    transforms.Resize((args.my_img_size))
                ])
                self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
                for pp in self.data_raw.pipeline:
                    if 'Resampled' in str(pp):
                        pp.deterministic = True
                        def worker_seed():
                            return rank*100000+epoch+bias*1e9
                        pp.worker_seed = worker_seed
                self.data = iter(self.data_raw)
                # print(f"WebDataset loaded for rank {rank} epoch {epoch}")
            if self.data is None:
                init_wds(self)
            trial = 0
            while trial < 10:
                try:
                    dd = next(self.data)  # jpg, json, txt
                    break
                except Exception:
                    print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
                    self.error_count += 1
                    init_wds(self, self.error_count)
                    trial += 1
            # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
            # with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
            #     tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
            return dd[0], dd[2]
        else:
            if args.data_type == "uint16":
                i = np.random.randint(0, self.data_size-1)
                dix = self.data[i]
                x = torch.tensor(dix[:-1], dtype=torch.long)
                y = torch.tensor(dix[1:], dtype=torch.long)
            else:
                ctx_len = args.ctx_len
                req_len = ctx_len + 1
                magic_prime = args.magic_prime
                data = self.data

                if args.my_pile_stage > 0 and args.my_pile_stage != 4:
                    ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank

                    if args.my_qa_mask > 0:
                        ii_orig = ii
                        if ii % 2 == 0:
                            ii = -1
                            data = self.data_pile
                        else:
                            ii = ii // 2
                    if ii < 0:
                        i = np.random.randint(0, self.data_pile_size - req_len)
                    else:
                        factor = (math.sqrt(5) - 1) / 2
                        factor = int(magic_prime * factor)
                        i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
                        i = i + args.my_pile_shift
                        # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
                elif args.my_pile_stage == 4:
                    # cheat: pick a random spot in dataset
                    if args.my_pile_version == 1:
                        i = np.random.randint(0, self.data_size - req_len)
                    else:
                        i = np.random.randint(0, self.data_size)
                else:
                    # cheat: pick a random spot in dataset
                    i = np.random.randint(0, self.data_size - req_len)

                if args.data_type == "binidx":
                    if args.my_pile_version == 1:
                        dix = data.get(idx=0, offset=i, length=req_len).astype(int)
                    else:
                        # self.data : cutoff, chunk_count, data
                        for j in range(len(data)):
                            if i < data[j][0]:
                                ii = i
                                i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1]
                                dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
                                # print(ii, j, i)
                                break
                elif args.data_type == "numpy":
                    dix = data[i : i + req_len]
                else:
                    dix = [self.stoi[s] for s in data[i : i + req_len]]

                if args.my_qa_mask == 1:
                    if data is self.data_pile:
                        z = [1] * ctx_len
                    else:
                        z = [0] * ctx_len
                        z_sum = 0
                        isGood = False
                        for i in range(3, ctx_len):
                            if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
                                isGood = True
                            if dix[i] == 0:
                                isGood = False
                            if isGood:
                                z[i] = 1
                                z_sum += 1
                        if z_sum == 0:
                            z = [1] * ctx_len
                            i = np.random.randint(0, self.data_pile_size - req_len)
                            dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
                    z = torch.tensor(z, dtype=torch.bfloat16)

                x = torch.tensor(dix[:-1], dtype=torch.long)
                y = torch.tensor(dix[1:], dtype=torch.long)

                # if ii_orig < 50:
                #     # if rank == 1:
                #     print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
                # else:
                #     exit(0)

                if args.my_qa_mask == 1:
                    return x, y, z

            return x, y
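The Pile sampling above walks the dataset with i = ((factor * ii^3) mod p) * ctx_len. Because magic_prime p satisfies p % 3 == 2 (asserted in __init__), gcd(3, p-1) = 1, so cubing is a bijection mod p and every chunk slot is visited exactly once per pass. A quick check on a small prime with the same property:

# Verify the cubic map is a permutation for a small prime p with p % 3 == 2.
p = 17                       # 17 % 3 == 2, like args.magic_prime
factor = int(p * 0.618)      # golden-ratio factor, as in __getitem__ above
hits = {(factor * ii * ii * ii) % p for ii in range(p)}
assert hits == set(range(p))   # each slot sampled exactly once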
@@ -1,610 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os, math, gc, importlib
import torch
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
if importlib.util.find_spec('deepspeed'):
    import deepspeed
    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam

try:
    print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
except KeyError:
    os.environ["RWKV_MY_TESTING"] = ''

def __nop(ob):
    return ob


MyModule = nn.Module
MyFunction = __nop
if os.environ["RWKV_JIT_ON"] == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method


########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = int(os.environ["RWKV_T_MAX"])  # TAKES LOTS OF VRAM!
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

from torch.utils.cpp_extension import load
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
    wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
    class WKV(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
            ctx.B = B
            ctx.T = T
            ctx.C = C
            assert T <= T_MAX
            assert B * C % min(C, 32) == 0
            w = -torch.exp(w.float().contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
            y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
            wkv_cuda.forward(B, T, C, w, u, k, v, y)
            ctx.save_for_backward(w, u, k, v, y)
            return y
        @staticmethod
        def backward(ctx, gy):
            B = ctx.B
            T = ctx.T
            C = ctx.C
            assert T <= T_MAX
            assert B * C % min(C, 32) == 0
            w, u, k, v, y = ctx.saved_tensors
            gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
            gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
            gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
            gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
            wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            return (None, None, None, gw, gu, gk, gv)
else:
    wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
    class WKV(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
            ctx.B = B
            ctx.T = T
            ctx.C = C
            assert T <= T_MAX
            assert B * C % min(C, 32) == 0
            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                w = -torch.exp(w.contiguous())
                u = u.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            else:
                w = -torch.exp(w.float().contiguous())
                u = u.float().contiguous()
                k = k.float().contiguous()
                v = v.float().contiguous()
            y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
            wkv_cuda.forward(B, T, C, w, u, k, v, y)
            ctx.save_for_backward(w, u, k, v, y)
            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                return y
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                return y.half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                return y.bfloat16()
        @staticmethod
        def backward(ctx, gy):
            B = ctx.B
            T = ctx.T
            C = ctx.C
            assert T <= T_MAX
            assert B * C % min(C, 32) == 0
            w, u, k, v, y = ctx.saved_tensors
            gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
            gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
            gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
            gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
            else:
                wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                return (None, None, None, gw, gu, gk, gv)
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())


def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w, u, k, v)
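For intuition, here is a plain-Python sketch of the WKV recurrence the kernel computes, following the commonly stated RWKV-4 formula. This omits the max-subtraction trick the real CUDA code uses for numerical stability, and w here is already the negative decay, as prepared by -torch.exp above; it is a reference sketch, not the kernel:

import torch

def wkv_reference(w, u, k, v):
    # w, u: (C,) with w < 0; k, v: (T, C). Returns y: (T, C).
    T, C = k.shape
    a = torch.zeros(C)   # running weighted sum of values
    b = torch.zeros(C)   # running weight normalizer
    y = torch.empty(T, C)
    for t in range(T):
        # current token gets the one-time "time_first" bonus u
        y[t] = (a + torch.exp(u + k[t]) * v[t]) / (b + torch.exp(u + k[t]))
        # decay the state, then absorb the current token
        a = torch.exp(w) * a + torch.exp(k[t]) * v[t]
        b = torch.exp(w) * b + torch.exp(k[t])
    return y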
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################


class RWKV_TimeMix(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.ctx_len = args.ctx_len
        self.n_embd = args.n_embd

        with torch.no_grad():  # fancy init
            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd

            # fancy time_decay
            decay_speed = torch.ones(args.dim_att)
            for h in range(args.dim_att):
                decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            # fancy time_first
            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
            self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)

            # fancy time_mix
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)

        if 'a' in os.environ["RWKV_MY_TESTING"]:
            self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
            d_qkv = args.n_embd // 16
            self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
            self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
            self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
            self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
            with torch.no_grad():
                self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
                self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
                self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)

    if 'a' not in os.environ["RWKV_MY_TESTING"]:
        @MyFunction
        def jit_func(self, x):
            xx = self.time_shift(x)  # Mix x with the previous timestep to produce xk, xv, xr
            xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
            xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
            xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
            k = self.key(xk)
            v = self.value(xv)
            r = self.receptance(xr)
            sr = torch.sigmoid(r)
            return sr, k, v

        def forward(self, x):
            B, T, C = x.size()  # x = (Batch,Time,Channel)
            sr, k, v = self.jit_func(x)
            rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
            return self.output(rwkv)

    if 'a' in os.environ["RWKV_MY_TESTING"]:
        @MyFunction
        def QKV(self, q, k, v):
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.att_mask == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            x = att @ v
            return x

        @MyFunction
        def jit_funcQKV(self, x):
            xx = self.time_shift(x)  # Mix x with the previous timestep to produce xk, xv, xr
            xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
            xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
            xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
            xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
            xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
            xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
            k = self.key(xk)
            v = self.value(xv)
            r = self.receptance(xr)
            sr = torch.sigmoid(r)
            qq = self.qq(xqq)
            kk = self.kk(xkk)
            vv = self.vv(xvv)
            return sr, k, v, qq, kk, vv

        def forward(self, x):
            B, T, C = x.size()  # x = (Batch,Time,Channel)
            sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
            rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
            rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
            return rwkv
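time_shift = nn.ZeroPad2d((0, 0, 1, -1)) shifts the sequence one step toward the past on a (B, T, C) tensor: it pads one zero row at the top of the T axis and crops one at the bottom, so row t holds the token at t-1. A quick equivalence check (illustrative tensor):

import torch
import torch.nn as nn

x = torch.arange(12.0).view(1, 4, 3)          # (B=1, T=4, C=3)
shift = nn.ZeroPad2d((0, 0, 1, -1))           # pad 1 at top of T, crop 1 at bottom
xx = shift(x)
assert torch.equal(xx[0, 0], torch.zeros(3))  # first step sees zeros
assert torch.equal(xx[0, 1:], x[0, :-1])      # step t sees token t-1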
########################################################################################################

class RWKV_ChannelMix(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))

        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)
        return torch.sigmoid(self.receptance(xr)) * kv


class MishGLU(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)

            x = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                x[0, 0, i] = i / args.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        a = self.aa(xa)
        b = self.bb(xb)
        return self.value(a * F.mish(b))
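Both FFN variants gate their output with a sigmoid "receptance": RWKV_ChannelMix computes ffn(x) = sigmoid(R(x_r)) * V(relu(K(x_k))^2) through a squared-ReLU bottleneck, while MishGLU swaps that for a Mish-gated GLU. A standalone shape sketch of the squared-ReLU path (illustrative sizes; token-shift mixing omitted):

import torch
import torch.nn as nn

n_embd, dim_ffn = 8, 32
K = nn.Linear(n_embd, dim_ffn, bias=False)
V = nn.Linear(dim_ffn, n_embd, bias=False)
R = nn.Linear(n_embd, n_embd, bias=False)

x = torch.randn(2, 5, n_embd)                         # (B, T, C)
out = torch.sigmoid(R(x)) * V(torch.relu(K(x)) ** 2)  # squared-ReLU bottleneck, sigmoid gate
print(out.shape)                                      # torch.Size([2, 5, 8])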
########################################################################################################
# The RWKV Model with our blocks
########################################################################################################


class Block(nn.Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(args.n_embd)
            if args.my_pos_emb > 0:
                self.pos_emb_x = nn.Parameter(torch.zeros((1, args.my_pos_emb, args.n_embd)))
                self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb, 1, args.n_embd)))

        if self.layer_id == 0 and self.args.pre_ffn > 0:
            self.ffnPre = RWKV_ChannelMix(args, 0)
        else:
            self.att = RWKV_TimeMix(args, layer_id)

        if 'g' in os.environ["RWKV_MY_TESTING"]:
            self.ffn = MishGLU(args, layer_id)
        else:
            self.ffn = RWKV_ChannelMix(args, layer_id)

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            self.tiny_ln = nn.LayerNorm(args.n_embd)
            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
            self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

    def forward(self, x, x_emb=None):
        args = self.args
        B, T, C = x.size()
        if self.layer_id == 0:
            x = self.ln0(x)
            if args.my_pos_emb > 0:
                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1, :]
                x = x + pos_emb

        if self.layer_id == 0 and args.pre_ffn > 0:
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            xx = self.tiny_ln(x)
            q = self.tiny_q(xx)[:, :T, :]
            k = self.tiny_k(xx)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
            x = x + c @ self.tiny_v(x_emb)
        return x
class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)
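L2Wrap leaves the loss value untouched but injects an extra gradient on the logits: only the arg-max logit of each position receives maxx * factor, which nudges the largest logit toward 0 (the gradient of 0.5 * factor * maxx^2). A tiny check of that behavior, run in the context of this file so L2Wrap is in scope (illustrative tensor):

import torch

y = torch.tensor([[[1.0, 3.0, 2.0]]], requires_grad=True)  # (B=1, T=1, vocab=3)
loss = y.sum() * 0.0                  # dummy loss so only L2Wrap's extra gy flows
out = L2Wrap.apply(loss, y)
out.backward()
# factor = 1e-4 / (1 * 1); only the max logit (3.0) gets a gradient of 3.0 * 1e-4
print(y.grad)   # tensor([[[0.0000e+00, 3.0000e-04, 0.0000e+00]]])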
class RWKV(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        if not hasattr(args, 'dim_att'):
            args.dim_att = args.n_embd
        if not hasattr(args, 'dim_ffn'):
            args.dim_ffn = args.n_embd * 4
        if not hasattr(args, 'tiny_att_layer'):
            args.tiny_att_layer = -1
        if not hasattr(args, 'tiny_att_dim'):
            args.tiny_att_dim = -1

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        if args.head_qk > 0:
            self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
    def configure_optimizers(self):
        args = self.args
        if args.layerwise_lr > 0:
            lr_1x = set()
            lr_2x = set()
            lr_3x = set()
            for n, p in self.named_parameters():
                if "time_mix" in n:
                    if args.my_pile_stage == 2:
                        lr_2x.add(n)
                    else:
                        lr_1x.add(n)
                elif "time_decay" in n:
                    if args.my_pile_stage == 2:
                        lr_3x.add(n)
                    else:
                        lr_2x.add(n)
                elif "time_first" in n:
                    lr_3x.add(n)
                else:
                    lr_1x.add(n)
            lr_1x = sorted(list(lr_1x))
            lr_2x = sorted(list(lr_2x))
            lr_3x = sorted(list(lr_3x))
            # print('1x', lr_1x)
            # print('2x', lr_2x)
            # print('3x', lr_3x)
            param_dict = {n: p for n, p in self.named_parameters()}
            if args.my_pile_stage == 2:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 2e-3 / args.lr_init},
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 3e-3 / args.lr_init},
                ]
            else:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
                ]
        else:
            optim_groups = [
                {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
            ]

        if self.deepspeed_offload:
            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)

    @property
    def deepspeed_offload(self) -> bool:
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            cfg = strategy.config["zero_optimization"]
            return cfg.get("offload_optimizer") or cfg.get("offload_param")
        return False
    def forward(self, idx):
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)
        x_emb = x

        if args.tiny_att_dim > 0:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
                else:
                    x = block(x, x_emb)
        else:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x)
                else:
                    x = block(x)

        x = self.ln_out(x)

        if args.head_qk > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x
    def training_step(self, batch, batch_idx):
        args = self.args
        if args.my_qa_mask != 1:
            idx, targets = batch
            logits = self(idx)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            idx, targets, mask = batch
            mask = mask.view(-1)
            sum_mask = torch.sum(mask).item()
            # if sum_mask == 0:
            #     return torch.tensor([0.0], requires_grad=True)

            logits = self(idx)
            if sum_mask == mask.shape[0]:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
                # print('rank', self.global_rank, 'loss', loss.item())
            else:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
                # loss_raw = loss
                loss = torch.sum(loss * mask) / sum_mask

                # torch.set_printoptions(threshold=10000)
                # if True: #self.global_rank == 1:
                #     tmp = ''
                #     sss = 0
                #     ccc = 0
                #     for i in range(mask.shape[0]):
                #         if mask[i] > 0:
                #             tmp += str(idx.view(-1)[i].item()) + ','
                #             sss += loss_raw.view(-1)[i].float().item()
                #             ccc += 1
                #     print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)

        return L2Wrap.apply(loss, logits)

    def training_step_end(self, batch_parts):
        all = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = all
    def generate_init_weight(self):
        print(
            f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        m = {}
        for n in self.state_dict():
            p = self.state_dict()[n]
            shape = p.shape

            gain = 1.0
            scale = 1.0
            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
                m[n] = p
            else:
                if n == "emb.weight":
                    scale = -1 * self.args.lr_init
                else:
                    if shape[0] > shape[1]:
                        gain = math.sqrt(shape[0] / shape[1])
                    for kk in [".att.key.", ".att.receptance.", ".att.output.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
                        if kk in n:
                            scale = 0
                    if n == "head.weight":
                        scale = 0.5
                    if "head_k." in n:
                        scale = 0.1
                    if "head_q." in n:
                        scale = 0

                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")

                if self.args.accelerator.upper() == "GPU":
                    m[n] = torch.empty((shape[0], shape[1]), device="cuda")
                else:
                    m[n] = torch.empty((shape[0], shape[1]))

                if scale == 0:
                    nn.init.zeros_(m[n])
                elif scale < 0:
                    nn.init.uniform_(m[n], a=scale, b=-scale)
                else:
                    nn.init.orthogonal_(m[n], gain=gain * scale)

            m[n] = m[n].cpu()
            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
                m[n] = m[n].half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                m[n] = m[n].bfloat16()

            # if n == "emb.weight":
            #     print(m[n])

        gc.collect()
        torch.cuda.empty_cache()
        return m
@@ -1,446 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import os, math, gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as vision
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from pytorch_msssim import MS_SSIM

def __nop(ob):
    return ob
MyModule = torch.jit.ScriptModule
# MyFunction = __nop
MyFunction = torch.jit.script_method

import clip
from transformers import CLIPModel

class L2pooling(nn.Module):
    def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
        super(L2pooling, self).__init__()
        self.padding = (filter_size - 2) // 2
        self.stride = stride
        self.channels = channels
        a = np.hanning(filter_size)[1:-1]
        g = torch.Tensor(a[:, None] * a[None, :])
        g = g / torch.sum(g)
        self.register_buffer(
            "filter", g[None, None, :, :].repeat((self.channels, 1, 1, 1))
        )

    def forward(self, input):
        input = input**2
        out = F.conv2d(
            input,
            self.filter,
            stride=self.stride,
            padding=self.padding,
            groups=input.shape[1],
        )
        return (out + 1e-12).sqrt()
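L2pooling downsamples by taking a windowed L2 norm: square the activations, blur with a normalized Hanning window applied depthwise (groups=channels), then take the square root. A quick shape and positivity check, run in the context of this file so L2pooling is in scope (illustrative sizes):

import torch

pool = L2pooling(channels=64)
x = torch.randn(1, 64, 56, 56)
y = pool(x)
print(y.shape)           # torch.Size([1, 64, 28, 28]) -- stride-2 downsample
assert (y >= 0).all()    # L2 energy is non-negative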
class DISTS(torch.nn.Module):
|
|
||||||
def __init__(self, load_weights=True):
|
|
||||||
super(DISTS, self).__init__()
|
|
||||||
vgg_pretrained_features = vision.models.vgg16(
|
|
||||||
weights="VGG16_Weights.IMAGENET1K_V1"
|
|
||||||
).features
|
|
||||||
self.stage1 = torch.nn.Sequential()
|
|
||||||
self.stage2 = torch.nn.Sequential()
|
|
||||||
self.stage3 = torch.nn.Sequential()
|
|
||||||
self.stage4 = torch.nn.Sequential()
|
|
||||||
self.stage5 = torch.nn.Sequential()
|
|
||||||
for x in range(0, 4):
|
|
||||||
self.stage1.add_module(str(x), vgg_pretrained_features[x])
|
|
||||||
self.stage2.add_module(str(4), L2pooling(channels=64))
|
|
||||||
for x in range(5, 9):
|
|
||||||
self.stage2.add_module(str(x), vgg_pretrained_features[x])
|
|
||||||
self.stage3.add_module(str(9), L2pooling(channels=128))
|
|
||||||
for x in range(10, 16):
|
|
||||||
self.stage3.add_module(str(x), vgg_pretrained_features[x])
|
|
||||||
self.stage4.add_module(str(16), L2pooling(channels=256))
|
|
||||||
for x in range(17, 23):
|
|
||||||
self.stage4.add_module(str(x), vgg_pretrained_features[x])
|
|
||||||
self.stage5.add_module(str(23), L2pooling(channels=512))
|
|
||||||
for x in range(24, 30):
|
|
||||||
self.stage5.add_module(str(x), vgg_pretrained_features[x])
|
|
||||||
|
|
||||||
self.register_buffer(
|
|
||||||
"mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.chns = [3, 64, 128, 256, 512, 512]
|
|
||||||
self.register_buffer(
|
|
||||||
"alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))
|
|
||||||
)
|
|
||||||
self.register_buffer("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
|
|
||||||
self.alpha.data.normal_(0.1, 0.01)
|
|
||||||
self.beta.data.normal_(0.1, 0.01)
|
|
||||||
weights = torch.load("test/DISTS_weights.pt")
|
|
||||||
self.alpha.data = weights["alpha"]
|
|
||||||
self.beta.data = weights["beta"]
|
|
||||||
|
|
||||||
for param in self.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
def forward_once(self, x):
|
|
||||||
h = (x - self.mean) / self.std
|
|
||||||
h = self.stage1(h)
|
|
||||||
h_relu1_2 = h
|
|
||||||
h = self.stage2(h)
|
|
||||||
h_relu2_2 = h
|
|
||||||
h = self.stage3(h)
|
|
||||||
h_relu3_3 = h
|
|
||||||
h = self.stage4(h)
|
|
||||||
h_relu4_3 = h
|
|
||||||
h = self.stage5(h)
|
|
||||||
h_relu5_3 = h
|
|
||||||
return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
|
|
||||||
|
|
||||||
def forward(self, x, y, require_grad=False, batch_average=False):
|
|
||||||
if require_grad:
|
|
||||||
feats0 = self.forward_once(x)
|
|
||||||
feats1 = self.forward_once(y)
|
|
||||||
else:
|
|
||||||
with torch.no_grad():
|
|
||||||
feats0 = self.forward_once(x)
|
|
||||||
feats1 = self.forward_once(y)
|
|
||||||
dist1 = 0
|
|
||||||
dist2 = 0
|
|
||||||
c1 = 1e-6
|
|
||||||
c2 = 1e-6
|
|
||||||
w_sum = self.alpha.sum() + self.beta.sum()
|
|
||||||
alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
|
|
||||||
beta = torch.split(self.beta / w_sum, self.chns, dim=1)
|
|
||||||
|
|
||||||
for k in range(len(self.chns)):
|
|
||||||
x_mean = feats0[k].mean([2, 3], keepdim=True)
|
|
||||||
y_mean = feats1[k].mean([2, 3], keepdim=True)
|
|
||||||
S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
|
|
||||||
dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)
|
|
||||||
|
|
||||||
x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
|
|
||||||
y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
|
|
||||||
xy_cov = (feats0[k] * feats1[k]).mean(
|
|
||||||
[2, 3], keepdim=True
|
|
||||||
) - x_mean * y_mean
|
|
||||||
S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
|
|
||||||
dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)
|
|
||||||
|
|
||||||
score = 1 - (dist1 + dist2).squeeze()
|
|
||||||
|
|
||||||
if batch_average:
|
|
||||||
return score.mean()
|
|
||||||
else:
|
|
||||||
return score
|
|
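
# DISTS combines two SSIM-style terms per VGG stage: S1 compares spatial means
# (a texture/luminance term) and S2 compares variances and covariance (a
# structure term). alpha/beta are per-channel weights loaded from
# DISTS_weights.pt and frozen; the score is 1 - (weighted similarity), so
# lower means more similar.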

class ToBinary(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):  # , noise_scale):
        # if noise_scale > 0:
        #     noise_min = 0.5 - noise_scale / 2
        #     noise_max = 0.5 + noise_scale / 2
        #     return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
        # else:
        return torch.floor(x + 0.5)  # no need for noise when we have plenty of data

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()  # , None
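
# ToBinary is a straight-through estimator: the forward pass rounds the
# encoder output to {0, 1}, while the backward pass copies the gradient
# through unchanged, as if the rounding were the identity. A minimal sanity
# check (illustrative only, not part of the model):
#
#   z = torch.rand(2, 8, 4, 4, requires_grad=True)
#   b = ToBinary.apply(z)      # values are exactly 0. or 1.
#   b.sum().backward()         # gradient flows: z.grad is all ones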

########################################################################################################


class R_ENCODER(MyModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        dd = 8
        self.Bxx = nn.BatchNorm2d(dd*64)

        self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)

        self.B00 = nn.BatchNorm2d(dd*4)
        self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)

        self.B10 = nn.BatchNorm2d(dd*16)
        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)

        self.B20 = nn.BatchNorm2d(dd*64)
        self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
        # self.B21 = nn.BatchNorm2d(dd*64)
        # self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        # self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
        # self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        # self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)

        self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)

    @MyFunction
    def forward(self, img):
        ACT = F.mish

        x = self.CIN(img)
        xx = self.Bxx(F.pixel_unshuffle(x, 8))
        x = x + self.Cx1(ACT(self.Cx0(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
        x = x + self.C03(ACT(self.C02(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
        x = x + self.C13(ACT(self.C12(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
        x = x + self.C23(ACT(self.C22(x)))
        # x = x + self.C25(ACT(self.C24(ACT(self.B21(x)))))
        # x = x + self.C27(ACT(self.C26(x)))

        x = self.COUT(x + xx)
        return torch.sigmoid(x)
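
# Each F.pixel_unshuffle(x, 2) halves the spatial resolution and multiplies
# the channel count by 4, so three of them take dd channels to dd*64 at 1/8
# resolution; the Bxx branch pixel-unshuffles by 8 in one step to form a
# matching skip connection. COUT then maps to args.my_img_bit channels, which
# ToBinary quantizes into the discrete latent code.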

########################################################################################################


class R_DECODER(MyModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        dd = 8
        self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)

        self.B00 = nn.BatchNorm2d(dd*64)
        self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
        # self.B01 = nn.BatchNorm2d(dd*64)
        # self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        # self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
        # self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
        # self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)

        self.B10 = nn.BatchNorm2d(dd*16)
        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)

        self.B20 = nn.BatchNorm2d(dd*4)
        self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)

        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
        self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)

    @MyFunction
    def forward(self, code):
        ACT = F.mish
        x = self.CIN(code)

        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
        x = x + self.C03(ACT(self.C02(x)))
        # x = x + self.C05(ACT(self.C04(ACT(self.B01(x)))))
        # x = x + self.C07(ACT(self.C06(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
        x = x + self.C13(ACT(self.C12(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
        x = x + self.C23(ACT(self.C22(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.Cx1(ACT(self.Cx0(x)))
        x = self.COUT(x)

        return torch.sigmoid(x)

########################################################################################################


def cosine_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return 1 - torch.einsum('ij,ij->i', [x, y])
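
# For unit-normalized rows, einsum('ij,ij->i') is the per-sample dot product,
# so this returns 1 - cosine_similarity(x_i, y_i) for each row i: 0 when the
# CLIP embeddings of the source and reconstruction point the same way, up to
# 2 when they are opposite.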


class RWKV_IMG(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.encoder = R_ENCODER(args)
        self.decoder = R_DECODER(args)

        self.clip_model = None
        clip_name = args.my_img_clip
        if clip_name == 'B32':
            clip_name = 'ViT-B/32'
        elif clip_name == 'B16':
            clip_name = 'ViT-B/16'
        elif clip_name == 'L14':
            clip_name = 'ViT-L/14'
        elif clip_name == 'OB32':
            clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
            self.clip_model = CLIPModel.from_pretrained(clip_name)
            self.clip_model.encode_image = self.clip_model.get_image_features
        if self.clip_model is None:
            self.clip_model, _ = clip.load(clip_name, jit=True)
        self.register_buffer(
            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
        )

        for n, p in self.named_parameters():
            if 'clip_model' in n:
                p.requires_grad = False

        self.loss_dists = DISTS()
        # self.loss_ssim = MS_SSIM(data_range=1, size_average=True, channel=3)

    def configure_optimizers(self):
        args = self.args
        optim_groups = [
            {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
        ]
        if self.deepspeed_offload:
            return DeepSpeedCPUAdam(
                optim_groups,
                lr=self.args.lr_init,
                betas=self.args.betas,
                eps=self.args.adam_eps,
                bias_correction=True,
                adamw_mode=False,
                weight_decay=0,
                amsgrad=False,
            )
        return FusedAdam(
            optim_groups,
            lr=self.args.lr_init,
            betas=self.args.betas,
            eps=self.args.adam_eps,
            bias_correction=True,
            adam_w_mode=False,
            weight_decay=0,
            amsgrad=False,
        )
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
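
    # When ZeRO offload moves optimizer state to the CPU, the GPU-side
    # FusedAdam cannot be used, so DeepSpeedCPUAdam is selected instead;
    # both are configured as plain Adam (adamw_mode/adam_w_mode=False,
    # weight_decay=0).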

    @property
    def deepspeed_offload(self) -> bool:
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            config = strategy.config["zero_optimization"]
            return config.get("offload_optimizer") or config.get("offload_param")
        return False

    def forward(self, img):
        z = self.encoder(img)
        z = ToBinary.apply(z)  # , self.args.my_img_noise_scale)
        out = self.decoder(z)
        return out
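
    # Shape flow (illustrative): a [B, 3, H, W] image becomes a binary code z
    # of shape [B, my_img_bit, H/8, W/8] (sigmoid -> ToBinary), which the
    # decoder expands back to a [B, 3, H, W] reconstruction in [0, 1].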

    def training_step(self, batch, batch_idx):
        args = self.args
        img, txt = batch
        out = self(img)
        if self.trainer.is_global_zero:
            if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
                img_dir = f"test/image_model/{args.run_name}"
                if not os.path.exists(img_dir):
                    os.makedirs(img_dir)
                vision.utils.save_image(
                    img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"  # , padding=0
                )
                vision.utils.save_image(
                    out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"  # , padding=0
                )

        # loss_ssim = 1 - self.loss_ssim(out, img)
        loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)

        iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
        ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
        loss_clip = torch.mean(cosine_loss(iii, ooo))

        if args.my_img_l1_scale > 0:
            loss_l1 = F.l1_loss(out, img)
            return loss_dists + loss_clip * args.my_img_clip_scale + loss_l1 * args.my_img_l1_scale
        else:
            return loss_dists + loss_clip * args.my_img_clip_scale

    def training_step_end(self, batch_parts):
        all = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = all

    def generate_init_weight(self):
        print(
            f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        m = {}
        for n in self.state_dict():
            scale = 1
            p = self.state_dict()[n]
            shape = p.shape
            ss = n.split('.')

            # if ss[0] in ['encoder', 'decoder']:
            #     if ss[2] == 'bias':
            #         scale = 0
            #     # elif n == 'encoder.CIN.weight':
            #     #     nn.init.dirac_(p)
            #     else:
            #         try:
            #             if ss[1][0] == 'C' and (int(ss[1][2]) % 2 == 1):
            #                 scale = 0
            #         except:
            #             pass
            # m[n] = p * scale

            m[n] = p

            m[n] = m[n].cpu()
            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
                m[n] = m[n].half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                m[n] = m[n].bfloat16()

        gc.collect()
        torch.cuda.empty_cache()
        return m
@ -1,237 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import types
import torch
import math, os, gc
from torch.nn import functional as F
import torch.nn as nn
from typing import List, Dict

MyModule = nn.Module
def __nop(ob):
    return ob
MyFunction = __nop

# # try torchdynamo
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"])  # !!!BUGGY!!! wrong output

# try torch jit --> faster for fp32, slower for fp16 (why?)
if os.environ["RWKV_JIT_ON"] == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method

RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n')

DEBUG_TIME = False  # True False - show trained time-coeffs

RWKV_RESCALE_LAYER = 6  # set x=x/2 every X layer

############################################################################################################

class RWKV_RNN(MyModule):
    def __init__(self, args):
        super().__init__()

        self.args = args
        self.FLOAT_MODE = args.FLOAT_MODE
        self.RUN_DEVICE = args.RUN_DEVICE

        with torch.no_grad():
            w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
            # refine weights and send to correct device
            keys = list(w.keys())
            if 'pos_emb_x' in keys:
                w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
            keys = list(w.keys())
            print_need_newline = False
            for x in keys:
                block_id = 0
                if 'blocks.' in x:
                    block_id = int(x.split('.')[1])
                if 'att.output.weight' in x:
                    w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
                if 'ffn.value.weight' in x:
                    w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))

                if '.time_' in x:
                    w[x] = w[x].squeeze()
                    if DEBUG_TIME:
                        print(x, w[x].numpy())
                if '.time_decay' in x:
                    w[x] = w[x].float()
                    w[x] = -torch.exp(w[x])
                elif '.time_first' in x:
                    w[x] = w[x].float()
                else:
                    if self.FLOAT_MODE == "fp32":
                        w[x] = w[x].float()
                    elif self.FLOAT_MODE == "bf16":
                        w[x] = w[x].bfloat16()
                    elif self.FLOAT_MODE == "fp16":
                        w[x] = w[x].half()

                w[x].requires_grad = False
                if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
                    w[x] = w[x].cuda()

                if ('blocks.' not in x) or ('blocks.0.' in x):
                    if print_need_newline:
                        print('\n', end='')
                        print_need_newline = False
                    print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device)
                else:
                    print_need_newline = True
                    print('.', end='', flush=True)

            # store weights in self.w
            keys = list(w.keys())
            self.w = types.SimpleNamespace()
            for x in keys:
                xx = x.split('.')
                here = self.w
                for i in range(len(xx)):
                    if xx[i].isdigit():
                        ii = int(xx[i])
                        if ii not in here:
                            here[ii] = types.SimpleNamespace()
                        here = here[ii]
                    else:
                        if i == len(xx) - 1:
                            setattr(here, xx[i], w[x])
                        elif not hasattr(here, xx[i]):
                            if xx[i+1].isdigit():
                                setattr(here, xx[i], {})
                            else:
                                setattr(here, xx[i], types.SimpleNamespace())
                        here = getattr(here, xx[i])

        self.eval()
        gc.collect()
        torch.cuda.empty_cache()

    def LN(self, x, w):
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp

    @MyFunction
    def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        if self.FLOAT_MODE == "bf16":
            xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
            xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
            state[5*i+0] = x.float()
        elif self.FLOAT_MODE == "fp16":
            xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k)
            xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r)
            state[5*i+0] = x.float()
        else:
            xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
            xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
            state[5*i+0] = x

        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk))
        kv = vw @ k

        return r * kv
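
    # FF is the RWKV channel-mixing block in recurrent form: xk/xr blend the
    # current input with the previous token's input (token shift, kept in
    # state[5*i+0]), then the output is a sigmoid(R) gate applied to a
    # squared-ReLU feed-forward: r * (vw @ relu(kw @ xk)**2).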

    @MyFunction
    def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        if self.FLOAT_MODE == "bf16":
            xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
            xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
            xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
            state[5*i+1] = x.float()
        elif self.FLOAT_MODE == "fp16":
            xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k)
            xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v)
            xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r)
            state[5*i+1] = x.float()
        else:
            xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
            xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
            xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
            state[5*i+1] = x

        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv

        if '16' in self.FLOAT_MODE:
            kk = k.float()
            vv = v.float()
        else:
            kk = k
            vv = v
        aa = state[5*i+2]
        bb = state[5*i+3]
        pp = state[5*i+4]
        ww = time_first + kk
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        a = e1 * aa + e2 * vv
        b = e1 * bb + e2
        ww = pp + time_decay
        p = torch.maximum(ww, kk)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kk - p)
        state[5*i+2] = e1 * aa + e2 * vv
        state[5*i+3] = e1 * bb + e2
        state[5*i+4] = p
        if self.FLOAT_MODE == "bf16":
            wkv = (a / b).type(torch.bfloat16)
        elif self.FLOAT_MODE == "fp16":
            wkv = (a / b).half()
        else:
            wkv = a / b

        return ow @ (r * wkv)
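
    # SA is the WKV attention in recurrent form. Conceptually the state keeps
    # running sums a = sum_j exp(k_j + decay terms) * v_j and b = sum_j of the
    # same exponentials, with wkv = a / b. To keep exp() from overflowing, the
    # state also tracks pp, the largest exponent seen so far, and every update
    # is rescaled by exp(old_max - new_max) (the e1/e2 factors), i.e. a
    # streaming log-sum-exp.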

    def forward(self, ctx, state, preprocess_only = False):
        with torch.no_grad():
            w = self.w
            args = self.args

            x = w.emb.weight[ctx[-1]]
            if self.RUN_DEVICE == 'cuda':
                x = x.cuda()
            try:
                pos_emb = w.pos_emb[len(ctx)-1]
                x = x + pos_emb
            except:
                pass

            if state is None:
                state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
                for i in range(args.n_layer):
                    state[5*i+4] -= 1e30

            for i in range(args.n_layer):
                if i == 0:
                    x = self.LN(x, w.blocks[i].ln0)

                ww = w.blocks[i].att
                x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
                    ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
                    ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)

                ww = w.blocks[i].ffn
                x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
                    ww.time_mix_k, ww.time_mix_r,
                    ww.key.weight, ww.value.weight, ww.receptance.weight)

                if (i+1) % RWKV_RESCALE_LAYER == 0:
                    x = x / 2

            if preprocess_only:
                return state

            x = self.LN(x, w.ln_out)
            x = w.head.weight @ x

            return x.float(), state
@ -1,190 +0,0 @@
import os, math, time, datetime, subprocess
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

def my_save(dd, ff):
    if '14b-run1' not in ff:
        torch.save(dd, ff)
    else:
        fn = ff.split('/')[-1]
        fff = '/dev/shm/' + fn
        torch.save(dd, fff)
        subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)

class train_callback(pl.Callback):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        args = self.args
        # if args.cuda_cleanup > 0:
        #     torch.cuda.empty_cache()
        real_step = trainer.global_step + args.epoch_begin * args.epoch_steps

        # LR schedule
        w_step = args.warmup_steps
        if args.lr_final == args.lr_init or args.epoch_count == 0:
            lr = args.lr_init
        else:
            decay_step = real_step - args.my_pile_edecay * args.epoch_steps
            decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
            progress = (decay_step - w_step + 1) / (decay_total - w_step)
            progress = min(1, max(0, progress))

            if args.lr_final == 0 or args.lr_init == 0:  # linear decay
                lr = args.lr_init + (args.lr_final - args.lr_init) * progress
            else:  # exp decay
                lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))

        if trainer.global_step < w_step:
            lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
        # if trainer.is_global_zero:
        #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)

        for param_group in trainer.optimizers[0].param_groups:
            if args.layerwise_lr > 0:
                param_group["lr"] = lr * param_group["my_lr_scale"]
                # print(param_group["lr"], param_group["my_lr_scale"])
            else:
                param_group["lr"] = lr

        trainer.my_lr = lr
        # rank_zero_info(f"{real_step} {lr}")
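
        # Worked example of the exp decay (illustrative numbers): with
        # lr_init=6e-4, lr_final=1e-5 and progress=0.5, the LR is
        # 6e-4 * exp(log(1e-5 / 6e-4) * 0.5) ~= 7.7e-5, a geometric
        # interpolation between lr_init and lr_final. During the first
        # warmup_steps the result is additionally scaled from 0.2 up to 1.0.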

        if trainer.global_step == 0:
            if trainer.is_global_zero:  # logging
                trainer.my_loss_sum = 0
                trainer.my_loss_count = 0
                trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
                trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
                try:
                    print(f"\n{trainer.strategy.config}\n")
                    trainer.my_log.write(f"{trainer.strategy.config}\n")
                except:
                    pass
                trainer.my_log.flush()
                if len(args.wandb) > 0:
                    print("Login to wandb...")
                    import wandb
                    wandb.init(
                        project=args.wandb,
                        name=args.run_name + " " + args.my_timestamp,
                        config=args,
                        save_code=False,
                    )
                    trainer.my_wandb = wandb

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        args = self.args
        if trainer.is_global_zero:  # logging
            t_now = time.time_ns()
            token_per_step = args.ctx_len * args.real_bsz
            real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
            kt_s = 0
            try:
                t_cost = (t_now - trainer.my_time_ns) / 1e9
                kt_s = token_per_step / t_cost / 1000
                self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
                self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
            except:
                pass
            trainer.my_time_ns = t_now
            trainer.my_loss = trainer.my_loss_all.float().mean().item()
            trainer.my_loss_sum += trainer.my_loss
            trainer.my_loss_count += 1
            trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
            self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
            self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
            # self.log("s", real_step, prog_bar=True, on_step=True)

            if len(args.wandb) > 0:
                lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
                if kt_s > 0:
                    lll["kt/s"] = kt_s
                trainer.my_wandb.log(lll, step=int(real_step))
            if args.magic_prime > 0:
                expand_factor = 2 if args.my_qa_mask > 0 else 1
                if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1:
                    to_save_dict = pl_module.state_dict()
                    my_save(
                        to_save_dict,
                        f"{args.proj_dir}/rwkv-final.pth",
                    )

    def on_train_epoch_start(self, trainer, pl_module):
        args = self.args
        dataset = trainer.train_dataloader.dataset.datasets
        assert "MyDataset" in str(dataset)
        dataset.global_rank = trainer.global_rank
        dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
        dataset.world_size = trainer.world_size
        # print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')

    def on_train_epoch_end(self, trainer, pl_module):
        args = self.args
        if trainer.is_global_zero:  # logging & save state_dict
            if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
                if args.data_type == 'wds_img':
                    raw_dict = pl_module.state_dict()
                    to_save_dict = {}
                    for k in raw_dict:
                        if k.startswith('encoder.') or k.startswith('decoder.'):
                            to_save_dict[k] = raw_dict[k]
                else:
                    to_save_dict = pl_module.state_dict()
                try:
                    my_save(
                        to_save_dict,
                        f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
                    )
                except Exception as e:
                    print('Error\n\n', e, '\n\n')
            trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
            trainer.my_log.flush()

            trainer.my_loss_sum = 0
            trainer.my_loss_count = 0


@rank_zero_only
def generate_init_weight(model, init_weight_name):
    mm = model.generate_init_weight()

    if model.args.my_pile_stage == 1:
        if len(model.args.load_model) > 0:
            print(f"Combine weights from {model.args.load_model}...")
            load_dict = torch.load(model.args.load_model, map_location="cpu")
            for k in load_dict:
                assert k in mm
                src = load_dict[k]
                try:
                    mm[k] = src.reshape(mm[k].shape)
                except:
                    tmp = mm[k].squeeze().clone()
                    print(k, src.shape, '-->', mm[k].shape)
                    ss = src.shape[0]
                    dd = tmp.shape[0]
                    for i in range(dd):
                        pos = i / dd * ss
                        if pos >= ss - 1:
                            tmp[i] = src[ss-1]
                        else:
                            p0 = int(math.floor(pos))
                            ii = pos - p0
                            tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
                    mm[k] = tmp.reshape(mm[k].shape)
                    sss = src.squeeze().float().cpu().numpy()
                    print(sss[:10], '...', sss[-10:])
                    mmm = mm[k].squeeze().float().cpu().numpy()
                    print(mmm[:10], '...', mmm[-10:])

    print(f"Save to {init_weight_name}...")
    torch.save(mm, init_weight_name)

    if model.args.my_pile_stage == 1:
        print("Done. Now go for stage 2.")
        exit(0)
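
# When a loaded tensor's shape differs from the model's (e.g. a resized
# positional embedding), the except-branch above stretches the source vector
# to the new length by 1-D linear interpolation: index i maps to fractional
# position pos = i / dd * ss, and src[p0] and src[p0+1] are blended by the
# fraction ii = pos - p0.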
@ -1,130 +0,0 @@
import json, time, random, os
import numpy as np
import torch
from torch.nn import functional as F

time_slot = {}
time_ref = time.time_ns()

def record_time(name):
    if name not in time_slot:
        time_slot[name] = 1e20
    tt = (time.time_ns() - time_ref) / 1e9
    if tt < time_slot[name]:
        time_slot[name] = tt

class TOKENIZER():
    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        if 'list' in str(type(WORD_NAME)):
            self.charMode = False
            if WORD_NAME[0] == WORD_NAME[1]:
                from transformers import PreTrainedTokenizerFast
                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
            else:
                from transformers import GPT2TokenizerFast
                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
            self.vocab_size = len(self.tokenizer)
        else:
            self.charMode = True
            with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)

            self.vocab_size = len(self.word_table)

            self.stoi = {v: int(k) for k, v in self.word_table.items()}
            self.itos = {int(k): v for k, v in self.word_table.items()}

            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n'
        return context

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
        # out[self.UNKNOWN_CHAR] = -float('Inf')
        lastChar = int(x[-1])

        probs = F.softmax(out, dim=-1)

        if self.charMode:
            if self.itos[lastChar] == '\n':
                top_p = top_p_newline
            else:
                top_p = top_p_usual
        else:
            top_p = top_p_usual

        if os.environ["RWKV_RUN_DEVICE"] == "cpu":
            probs = probs.numpy()
            sorted_probs = np.sort(probs)[::-1]
            cumulative_probs = np.cumsum(sorted_probs)
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if temperature != 1.0:
                probs = np.power(probs, 1.0 / temperature)  # probs is a numpy array here, so np.power rather than Tensor.pow
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return out
        else:
            sorted_probs = torch.sort(probs, descending=True)[0]
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if temperature != 1.0:
                probs = probs.pow(1.0 / temperature)
            out = torch.multinomial(probs, num_samples=1)[0]
            return out
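
# sample_logits above implements top-p (nucleus) sampling: sort the
# probabilities, find the smallest prefix whose cumulative mass exceeds top_p,
# and zero out everything below that cutoff before renormalizing and sampling.
# In charMode a separate top_p is used right after a newline, so line starts
# can be sampled more (or less) greedily than the rest of the text.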

def MaybeIsPrime(number):
    if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
        return True
    else:
        return False


def FermatPrimalityTest(number):
    if number > 1:
        for time in range(3):
            randomNumber = random.randint(2, number) - 1
            if pow(randomNumber, number - 1, number) != 1:
                return False
        return True
    else:
        return False


def MillerRabinPrimalityTest(number):
    if number == 2:
        return True
    elif number == 1 or number % 2 == 0:
        return False
    oddPartOfNumber = number - 1
    timesTwoDividNumber = 0
    while oddPartOfNumber % 2 == 0:
        oddPartOfNumber = oddPartOfNumber // 2
        timesTwoDividNumber = timesTwoDividNumber + 1

    for time in range(3):
        while True:
            randomNumber = random.randint(2, number) - 1
            if randomNumber != 0 and randomNumber != 1:
                break

        randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)

        if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
            iterationNumber = 1

            while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
                randomNumberWithPower = pow(randomNumberWithPower, 2, number)
                iterationNumber = iterationNumber + 1
            if randomNumberWithPower != (number - 1):
                return False

    return True
@ -1,349 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

if __name__ == "__main__":
    from argparse import ArgumentParser
    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

    rank_zero_info("########## work in progress ##########")

    ########################################################################################################
    #
    # example: train a simple L12-D768 RWKV on dummy data
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "" --data_type "dummy" --vocab_size 0 \
    # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
    # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
    # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

    # example: train a simple L6-D512 RWKV from scratch on enwik8
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
    # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
    # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

    # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0

    # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
    # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1

    parser = ArgumentParser()

    parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
    parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
    parser.add_argument("--proj_dir", default="out", type=str)
    parser.add_argument("--random_seed", default="-1", type=int)

    parser.add_argument("--data_file", default="", type=str)
    parser.add_argument("--data_type", default="utf-8", type=str)
    parser.add_argument("--vocab_size", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)

    parser.add_argument("--ctx_len", default=1024, type=int)
    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has [epoch_steps] steps
    parser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_final
    parser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = x
    parser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"

    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
    parser.add_argument("--n_layer", default=6, type=int)
    parser.add_argument("--n_embd", default=512, type=int)
    parser.add_argument("--dim_att", default=0, type=int)
    parser.add_argument("--dim_ffn", default=0, type=int)
    parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer

    parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
    parser.add_argument("--lr_final", default=1e-5, type=float)
    parser.add_argument("--warmup_steps", default=-1, type=int)  # try 50 if you load a model
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
    parser.add_argument("--adam_eps", default=1e-8, type=float)
    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower

    parser.add_argument("--my_pile_version", default=1, type=int)  # my special pile version
    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
    parser.add_argument("--my_pile_edecay", default=0, type=int)
    parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
    parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
    # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)

    parser.add_argument("--my_img_version", default=0, type=str)
    parser.add_argument("--my_img_size", default=0, type=int)
    parser.add_argument("--my_img_bit", default=0, type=int)
    parser.add_argument("--my_img_clip", default='x', type=str)
    parser.add_argument("--my_img_clip_scale", default=1, type=float)
    parser.add_argument("--my_img_l1_scale", default=0, type=float)
    parser.add_argument("--my_img_encoder", default='x', type=str)
    # parser.add_argument("--my_img_noise_scale", default=0, type=float)
    parser.add_argument("--my_sample_len", default=0, type=int)
    parser.add_argument("--my_ffn_shift", default=1, type=int)
    parser.add_argument("--my_att_shift", default=1, type=int)
    parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--load_partial", default=0, type=int)
    parser.add_argument("--magic_prime", default=0, type=int)
    parser.add_argument("--my_qa_mask", default=0, type=int)
    parser.add_argument("--my_testing", default='', type=str)

    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    ########################################################################################################

    import os, warnings, math, datetime, sys, time, importlib
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    if "deepspeed" in args.strategy:
        import deepspeed
    import pytorch_lightning as pl
    from pytorch_lightning import seed_everything

    if args.random_seed >= 0:
        print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
        seed_everything(args.random_seed)

    np.set_printoptions(precision=4, suppress=True, linewidth=200)
    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
    # os.environ["WDS_SHOW_SEED"] = "1"

    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    args.enable_checkpointing = False
    args.replace_sampler_ddp = False
    args.logger = False
    args.gradient_clip_val = 1.0
    args.num_sanity_val_steps = 0
    args.check_val_every_n_epoch = int(1e20)
    args.log_every_n_steps = int(1e20)
    args.max_epochs = -1  # continue forever
    args.betas = (args.beta1, args.beta2)
    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
    os.environ["RWKV_T_MAX"] = str(args.ctx_len)
    os.environ["RWKV_MY_TESTING"] = args.my_testing
    if args.dim_att <= 0:
        args.dim_att = args.n_embd
    if args.dim_ffn <= 0:
        args.dim_ffn = args.n_embd * 4

    if args.data_type == "wds_img":
        args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
        args.proj_dir = f"{args.proj_dir}-{args.run_name}"
    else:
        args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
    if not os.path.exists(args.proj_dir):
        os.makedirs(args.proj_dir)

    if args.my_pile_stage > 0:
        magic_prime_bak = args.magic_prime

        if args.my_pile_version == 1:
            if args.ctx_len == 1024:
                args.magic_prime = 324331313
                args.epoch_count = 8043
            elif args.ctx_len == 2048:
                args.magic_prime = 162165671
                args.epoch_count = 4021
            elif args.ctx_len == 4096:
                args.magic_prime = 81082817
                args.epoch_count = 2010
            elif args.ctx_len == 8192:
                args.magic_prime = 40541399
                args.epoch_count = 1005
        else:
            if args.ctx_len == 1024:
                args.magic_prime = 1694947181
                args.epoch_count = 42036
            elif args.ctx_len == 2048:
                args.magic_prime = 847473509
                args.epoch_count = 21017
            elif args.ctx_len == 4096:
                args.magic_prime = 423736637
                args.epoch_count = 10508
            elif args.ctx_len == 6144:
                args.magic_prime = 282491051
                args.epoch_count = 7005
            elif args.ctx_len == 8192:
                args.magic_prime = 211868243
                args.epoch_count = 5253
        if args.my_pile_shift < 0:
            args.my_pile_shift = 0

        if magic_prime_bak > 0:
            args.magic_prime = magic_prime_bak

        args.epoch_steps = 40320 // args.real_bsz
        assert args.epoch_steps * args.real_bsz == 40320
        if args.my_pile_stage == 2:
            assert args.lr_final == args.lr_init
        if args.my_pile_stage >= 2:  # find latest saved model
            list_p = []
            for p in os.listdir(args.proj_dir):
                if p.startswith("rwkv") and p.endswith(".pth"):
                    p = ((p.split("-"))[1].split("."))[0]
                    if p == "init":
                        p = -1
                    else:
                        p = int(p)
                    list_p += [p]
            list_p.sort()
            max_p = list_p[-1]
            if len(list_p) > 1:
                args.my_pile_prev_p = list_p[-2]  # in case max_p is corrupted
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
                if args.warmup_steps < 0:
                    if args.my_pile_stage == 2:
                        args.warmup_steps = 10
                    else:
                        args.warmup_steps = 30
            args.epoch_begin = max_p + 1
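
    # magic_prime is a prime matched to the number of ctx_len-sized samples in
    # the pile data (the sampler that consumes it lives in src/dataset.py, not
    # shown here). Arithmetic modulo a prime permutes the sample indices,
    # giving a deterministic but pseudo-random, collision-free traversal of
    # the data; on_train_batch_end in src/trainer.py also uses it to detect
    # the final step and save rwkv-final.pth.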

    samples_per_epoch = args.epoch_steps * args.real_bsz
    tokens_per_epoch = samples_per_epoch * args.ctx_len
    rank_zero_info(
        f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
#
############################################################################
"""
    )
    rank_zero_info(str(vars(args)) + "\n")

    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]

    if args.lr_final == 0 or args.lr_init == 0:
        rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")

    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
    os.environ["RWKV_FLOAT_MODE"] = args.precision
    if args.precision == "fp32":
        for i in range(10):
            rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
    if args.precision == "fp16":
        rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")

    os.environ["RWKV_JIT_ON"] = "1"
    if "deepspeed_stage_3" in args.strategy:
        os.environ["RWKV_JIT_ON"] = "0"

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    if args.precision == "fp32":
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
    else:
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True

    if "32" in args.precision:
        args.precision = 32
    elif args.precision == "fp16":
        args.precision = 16
    else:
        args.precision = "bf16"

    ########################################################################################################

    from src.trainer import train_callback, generate_init_weight
    from src.dataset import MyDataset

    train_data = MyDataset(args)
    args.vocab_size = train_data.vocab_size

    if args.data_type == 'wds_img':
        from src.model_img import RWKV_IMG
        model = RWKV_IMG(args)
    else:
        from src.model import RWKV
        model = RWKV(args)

    if len(args.load_model) == 0 or args.my_pile_stage == 1:  # shall we build the initial weights?
        init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
        generate_init_weight(model, init_weight_name)  # save initial weights
        args.load_model = init_weight_name

    rank_zero_info(f"########## Loading {args.load_model}... ##########")
    try:
        load_dict = torch.load(args.load_model, map_location="cpu")
    except:
        rank_zero_info(f"Bad checkpoint {args.load_model}")
        if args.my_pile_stage >= 2:  # try again using another checkpoint
            max_p = args.my_pile_prev_p
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
            args.epoch_begin = max_p + 1
            rank_zero_info(f"Trying {args.load_model}")
            load_dict = torch.load(args.load_model, map_location="cpu")

    if args.load_partial == 1:
        load_keys = load_dict.keys()
        for k in model.state_dict():
            if k not in load_keys:
                load_dict[k] = model.state_dict()[k]
    model.load_state_dict(load_dict)

    trainer = Trainer.from_argparse_args(
        args,
        callbacks=[train_callback(args)],
    )

    if trainer.global_rank == 0:
        for n in model.state_dict():
            shape = model.state_dict()[n].shape
            shape = [i for i in shape if i != 1]
            if len(shape) > 1:
                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
            else:
                print(f"{str(shape[0]).ljust(5)} {n}")

    if "deepspeed" in args.strategy:
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000

    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
    data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)

    trainer.fit(model, data_loader)
@ -1,104 +0,0 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

# this is for verifying the results of different models and making sure they agree with each other

import os, sys, types
import numpy as np
import torch
np.set_printoptions(precision=4, suppress=True, linewidth=200)
try:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
    pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

os.environ['RWKV_FLOAT_MODE'] = 'bf16'  # bf16 or fp32
os.environ['RWKV_RUN_DEVICE'] = 'cuda'  # currently model_train requires CUDA
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']

TOKEN_MODE = 'pile'

if TOKEN_MODE == 'pile':
    WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
    MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
    n_layer = 32
    n_embd = 2560
    ctx_len = 1024
    UNKNOWN_CHAR = None

from src.utils import TOKENIZER
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == 'pile':
    tokenizer.vocab_size = 50277

########################################################################################################

os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_T_MAX"] = str(ctx_len)

from src.model_run import RWKV_RNN
from src.model import RWKV

args = types.SimpleNamespace()
args.vocab_size = tokenizer.vocab_size
args.ctx_len = ctx_len
args.n_embd = n_embd
args.n_layer = n_layer
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
model_train = RWKV(args).to(RUN_DEVICE)

if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
    model_train = model_train.bfloat16()

print('loading ' + MODEL_NAME)
m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu')
model_train.load_state_dict(m2)

if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
    model_train = model_train.bfloat16()

args.MODEL_NAME = MODEL_NAME
args.RUN_DEVICE = RUN_DEVICE
args.FLOAT_MODE = os.environ['RWKV_FLOAT_MODE']
model_rnn = RWKV_RNN(args)

########################################################################################################

print(f"\nVerifying {os.environ['RWKV_RUN_DEVICE']} {os.environ['RWKV_FLOAT_MODE']}")

# context = '\nIn a'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'

if TOKEN_MODE == 'pile':
    ctx = tokenizer.tokenizer.encode(context)
print(f'input len {len(ctx)} data {ctx}')

########################################################################################################

with torch.no_grad():
    print('\nRWKV-train output')
    out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].detach().cpu().float().numpy()
    print(out, '\n')

    print('\nRWKV-RNN output')
    state = None
    out = None
    src_len = len(ctx)
    for i in range(src_len):
        x = ctx[:i+1]
        out, state = model_rnn.forward(x, state)
        if i < 3 or i >= src_len - 3:
            print(out.detach().cpu().numpy())
        if i == 2:
            print('...')
Binary image file removed (534 KiB)