rwkv-v4neo test
parent 50587bd65f
commit ba6e9e6264

cuda/wkv_cuda.cu
@@ -0,0 +1,125 @@
#include <stdio.h>
#include <assert.h>

#define MIN_VALUE (-1e38)

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    F p = 0, q = 0, o = MIN_VALUE;
    // p and q are running sums divided by exp(o) (to avoid overflows)
    for (int i = 0; i < T; i++) {
        const int ii = i * C;

        // output at step i: the current token enters through the exp(u + k) bonus term
        F no = max(o, u + k[ii]);
        F A = exp(o - no);
        F B = exp(u + k[ii] - no);
        y[ii] = (A * p + B * v[ii]) / (A * q + B);

        // decay the running state by exp(w) and absorb the current (k, v) pair
        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }
}

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const gy = _gy + _offset;

    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    // Tmax is injected at compile time via -DTmax=... (see the extension build flags)
    F y[Tmax], z[Tmax], zexp[Tmax];

    F gw = 0, gu = 0;
    F p = 0, q = 0;
    F dpdw = 0, dqdw = 0;
    F o = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        F no = max(o, k[ii] + u);
        F A = exp(o - no);
        F B = exp(k[ii] + u - no);

        F num = A * p + B * v[ii];
        F iden = 1 / (A * q + B);

        y[i] = num * iden;
        z[i] = iden;
        zexp[i] = k[ii] + u - no;

        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
        gu += gy[ii] * (v[ii] - y[i]) * B * iden;

        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        dpdw = A * (p + dpdw);
        dqdw = A * (q + dqdw);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }

    F gp = 0, gq = 0;
    o = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        F A = gy[ii] * z[i] * exp(zexp[i]);
        F B = exp(k[ii] + o);
        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
        gv[ii] = A + B * gp;

        F no = max(w + o, zexp[i] - k[ii] - u);
        A = exp(w + o - no);
        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
        gp = A * gp + B;
        gq = A * gq - B * y[i];
        o = no;
    }

    // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] += gw * _w[_c];
    _gu[_offsetBC] += gu;
}

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
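
For reference, the recurrence the forward kernel implements for each (batch, channel) slice is y_t = (sum_{i<t} e^{(t-1-i)w + k_i} v_i + e^{u + k_t} v_t) / (sum_{i<t} e^{(t-1-i)w + k_i} + e^{u + k_t}), evaluated entirely in log space via the running maximum o. A minimal NumPy sketch of the same computation, assuming w is already the negated-exponential decay (w = -exp(w_raw), as the Python wrapper in src/model.py below applies before launching the kernel):

import numpy as np

def wkv_forward_ref(w, u, k, v):
    """Single (batch, channel) slice; mirrors kernel_forward above."""
    T = len(k)
    y = np.empty(T)
    p, q, o = 0.0, 0.0, -1e38  # p, q are running sums rescaled by exp(o)
    for t in range(T):
        # output at step t: the current token enters through the exp(u + k[t]) bonus
        no = max(o, u + k[t])
        A, B = np.exp(o - no), np.exp(u + k[t] - no)
        y[t] = (A * p + B * v[t]) / (A * q + B)
        # decay the state by exp(w) and absorb (k[t], v[t])
        no = max(w + o, k[t])
        A, B = np.exp(w + o - no), np.exp(k[t] - no)
        p, q, o = A * p + B * v[t], A * q + B, no
    return y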

cuda/wkv_op.cpp
@@ -0,0 +1,21 @@
#include <torch/extension.h>

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("backward", &backward, "wkv backward");
}

TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
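
The extension is exposed twice: PYBIND11_MODULE makes forward/backward attributes of the loaded module object, and TORCH_LIBRARY additionally registers them as torch.ops.wkv.forward/backward. A hedged usage sketch, assuming src/model.py below has already compiled this via torch.utils.cpp_extension.load(name="wkv", ...) into the handle it calls wkv_cuda, and a CUDA device is available; shapes follow cuda_forward (w, u: (C,); k, v, y: (B, T, C); all contiguous float32 on the GPU, with y written in place):

import torch

B, T, C = 2, 16, 64
w = -torch.exp(torch.randn(C, device="cuda"))  # the decay must be negative
u = torch.randn(C, device="cuda")
k = torch.randn(B, T, C, device="cuda").contiguous()
v = torch.randn(B, T, C, device="cuda").contiguous()
y = torch.empty(B, T, C, device="cuda")
wkv_cuda.forward(B, T, C, w, u, k, v, y)  # or torch.ops.wkv.forward(...)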

src/binidx.py
@@ -0,0 +1,216 @@
import os
import torch
import numpy as np
import shutil
import struct
from functools import lru_cache
from itertools import accumulate


def print_rank_0(*message):
    """If distributed is initialized print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(*message, flush=True)
    else:
        print(*message, flush=True)


def _warmup_mmap_file(path):
    pass
    # with open(path, "rb") as stream:
    #     while stream.read(100 * 1024 * 1024):
    #         pass


dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,  # was np.float, which is removed in modern NumPy
    7: np.double,
    8: np.uint16,
}


def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


def index_file_path(prefix_path):
    return prefix_path + ".idx"


def data_file_path(prefix_path):
    return prefix_path + ".bin"


class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                # Little endian unsigned 64 Bit integer
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                # Little endian unsigned 8 Bit integer
                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print_rank_0("    warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print_rank_0("    reading sizes...")
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            print_rank_0("    reading pointers...")
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )
            print_rank_0("    reading document index...")
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        # _do_init has no default for its second argument; skip the
        # (disabled) warmup when re-initializing from a pickled path
        self._do_init(state, skip_warmup=True)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print_rank_0("    warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print_rank_0("    creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        print_rank_0("    creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
            )
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError(
                    "Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
            )
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
        )
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(
            data_file_path(path)
        )
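
The Index reader above pins down the on-disk layout of the .idx file: a 9-byte magic, a u64 version, a u8 dtype code, two u64 counts, then the int32 sizes, int64 byte pointers, and int64 document index, all little endian. A minimal writer sketch derived from those reads (write_index is a hypothetical helper, not part of this commit; the token payload itself goes in the companion .bin file):

import struct
import numpy as np

def write_index(path, dtype_code, sizes, pointers, doc_idx):
    with open(path, "wb") as f:
        f.write(b"MMIDIDX\x00\x00")              # _HDR_MAGIC
        f.write(struct.pack("<Q", 1))            # version
        f.write(struct.pack("<B", dtype_code))   # key into `dtypes`
        f.write(struct.pack("<Q", len(sizes)))   # sequence count
        f.write(struct.pack("<Q", len(doc_idx))) # document count
        f.write(np.asarray(sizes, dtype=np.int32).tobytes(order="C"))
        f.write(np.asarray(pointers, dtype=np.int64).tobytes(order="C"))
        f.write(np.asarray(doc_idx, dtype=np.int64).tobytes(order="C"))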

src/dataset.py
@@ -0,0 +1,69 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import json
import numpy as np
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from .binidx import MMapIndexedDataset


class MyDataset(Dataset):
    def __init__(self, args):
        self.args = args

        if args.data_type == "binidx":
            self.data = MMapIndexedDataset(args.data_file)
            self.vocab_size = args.vocab_size
            print("current vocab size =", self.vocab_size, "(make sure it's correct)")
            # the .bin payload is assumed to hold 2-byte (e.g. uint16) tokens
            self.data_size = len(self.data._bin_buffer) // 2
            print(f"data has {self.data_size} tokens.")
        elif args.data_type == "numpy":
            self.data = np.load(args.data_file).astype("int")
            self.vocab_size = args.vocab_size
            print("current vocab size =", self.vocab_size, "(make sure it's correct)")
            self.data_size = len(self.data)
            print(f"data has {self.data_size} tokens.")
        else:
            # plain text: args.data_type doubles as the file encoding (utf-8 / utf-16le)
            self.data = open(args.data_file, "r", encoding=args.data_type).read()
            print("building token list...", end=" ")
            unique = sorted(list(set(self.data)))
            self.vocab_size = len(unique)
            # print()
            # for u in unique:
            #     print(u, end=' ')
            # print('\n\n')
            xx = 0
            xxObj = {}
            for u in unique:
                xxObj[xx] = u
                xx += 1
            with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
                vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
            self.data_size = len(self.data)
            print("data has %d tokens, %d unique." % (self.data_size, self.vocab_size))
            self.stoi = {ch: i for i, ch in enumerate(unique)}
            self.itos = {i: ch for i, ch in enumerate(unique)}

    def __len__(self):
        return self.args.epoch_steps * int(self.args.devices) * self.args.micro_bsz

    def __getitem__(self, idx):
        # we are cheating: ignore idx and pick a random spot in the dataset
        ctx_len = self.args.ctx_len
        req_len = ctx_len + 1
        i = np.random.randint(0, self.data_size - req_len)
        if "MMapIndexedDataset" in str(type(self.data)):
            dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
        elif "numpy" in str(type(self.data)):
            dix = self.data[i : i + req_len]
        else:
            dix = [self.stoi[s] for s in self.data[i : i + req_len]]

        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
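
A hypothetical usage sketch; the SimpleNamespace fields mirror exactly what MyDataset reads, and the file name is a placeholder:

from types import SimpleNamespace

args = SimpleNamespace(data_file="train.txt", data_type="utf-8", vocab_size=0,
                       ctx_len=1024, epoch_steps=1000, devices=1, micro_bsz=12,
                       proj_dir="out")  # proj_dir must exist: the text branch writes vocab.json there
ds = MyDataset(args)
x, y = ds[0]  # x: (ctx_len,) token ids; y: the same sequence shifted by one step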

src/model.py
@@ -0,0 +1,378 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os, math, gc
import torch
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = 1024  # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

from torch.utils.cpp_extension import load

wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"])


class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            w = -torch.exp(w.contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        else:
            # the kernel is fp32-only, so cast fp16/bf16 inputs up before the call
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device="cuda", memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            return y
        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
            return y.half()
        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
            return y.bfloat16()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device="cuda").contiguous()
        gu = torch.zeros((B, C), device="cuda").contiguous()
        gk = torch.zeros((B, T, C), device="cuda").contiguous()
        gv = torch.zeros((B, T, C), device="cuda").contiguous()
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        else:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        # the kernel produces per-(batch, channel) grads for w and u; reduce over batch
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            return (None, None, None, gw, gu, gk, gv)
        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())


def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())


########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################


class RWKV_TimeMix(torch.jit.ScriptModule):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        with torch.no_grad():  # fancy init
            ratio_0_to_1 = layer_id / (config.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)  # 1 to ~0

            # fancy time_decay
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            # fancy time_first
            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

            # fancy time_mix
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

    @torch.jit.script_method
    def jit_func(self, x):
        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)

        return sr, k, v

    def forward(self, x):
        B, T, C = x.size()  # x = (Batch,Time,Channel)

        sr, k, v = self.jit_func(x)

        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        rwkv = self.output(rwkv)
        return rwkv


class RWKV_ChannelMix(torch.jit.ScriptModule):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)  # 1 to ~0

            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

        hidden_sz = 4 * config.n_embd
        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)

    @torch.jit.script_method
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv


########################################################################################################
# The RWKV Model with our blocks
########################################################################################################


class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0 and self.config.pre_ffn > 0:
            self.ffnPre = RWKV_ChannelMix(config, 0)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and self.config.pre_ffn > 0:
            x = x + self.ffnPre(self.ln1(x))  # better in some cases
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        # push down only the largest logit at each position, in proportion to its size
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)


class RWKV(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        if args.head_qk > 0:
            self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

    def configure_optimizers(self):
        optim_groups = [
            {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
        ]
        if self.deepspeed_offload:
            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)

    @property
    def deepspeed_offload(self) -> bool:
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            config = strategy.config["zero_optimization"]
            return config.get("offload_optimizer") or config.get("offload_param")
        return False

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)

        for block in self.blocks:
            if self.args.grad_cp == 1:
                x = deepspeed.checkpointing.checkpoint(block, x)
            else:
                x = block(x)

        x = self.ln_out(x)

        if self.args.head_qk > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / self.args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size)
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=self.args.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x

    def training_step(self, batch, batch_idx):
        idx, targets = batch
        logits = self(idx)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        self.trainer.my_loss = loss.item()
        self.trainer.my_epoch_loss = loss.item()
        self.log("lr", self.trainer.my_lr, prog_bar=True, on_step=True)
        self.log("loss", self.trainer.my_epoch_loss, prog_bar=True, on_step=True)

        return L2Wrap.apply(loss, logits)

    def generate_init_weight(self):
        print(
            f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        m = {}
        for n in self.state_dict():
            p = self.state_dict()[n]
            shape = p.shape

            gain = 1.0
            scale = 1.0
            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n:
                m[n] = p.cpu()
                continue
            elif n == "emb.weight":
                # a negative scale signals normal init with std = -scale (see below)
                scale = -25 * self.args.lr_init
            else:
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                for kk in [".att.key.", ".att.receptance.", ".att.output.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]:
                    if kk in n:
                        scale = 0
                if n == "head.weight":
                    scale = 0.5
                if "head_k." in n:
                    scale = 0.1
                if "head_q." in n:
                    scale = 0

            print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")

            if self.args.accelerator.upper() == "GPU":
                m[n] = torch.empty((shape[0], shape[1]), device="cuda")
            else:
                m[n] = torch.empty((shape[0], shape[1]))

            if scale == 0:
                nn.init.zeros_(m[n])
            elif scale < 0:
                nn.init.normal_(m[n], mean=0.0, std=-scale)
            else:
                nn.init.orthogonal_(m[n], gain=gain * scale)

            # if n == "emb.weight":
            #     print(m[n])

        gc.collect()
        torch.cuda.empty_cache()
        return m
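
A hypothetical smoke test for the model class (the SimpleNamespace field names mirror what RWKV, Block and the mix modules read above; it requires a CUDA device since RUN_CUDA always launches the kernel, and RWKV_FLOAT_MODE must be set before the forward pass):

import os
os.environ["RWKV_FLOAT_MODE"] = "fp32"

from types import SimpleNamespace

args = SimpleNamespace(vocab_size=100, n_embd=64, n_layer=2, ctx_len=128,
                       head_qk=0, pre_ffn=0, grad_cp=0)
model = RWKV(args).cuda()
idx = torch.randint(0, args.vocab_size, (1, args.ctx_len), device="cuda")
logits = model(idx)  # shape (1, ctx_len, vocab_size)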

train.py
@@ -0,0 +1,209 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

if __name__ == "__main__":
    print()
    import os, warnings, math, datetime
    import numpy as np
    from argparse import ArgumentParser
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning import seed_everything
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
    from pytorch_lightning.callbacks import TQDMProgressBar
    from pytorch_lightning import Callback

    seed_everything(42)
    np.set_printoptions(precision=4, suppress=True, linewidth=200)
    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")

    ########################################################################################################

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument("--wandb", default="", type=str)
    parser.add_argument("--proj_dir", default="out", type=str)
    parser.add_argument("--n_layer", default=6, type=int)
    parser.add_argument("--n_embd", default=512, type=int)
    parser.add_argument("--pre_ffn", default=0, type=int)
    parser.add_argument("--head_qk", default=0, type=int)
    parser.add_argument("--lr_init", default=6e-4, type=float)
    parser.add_argument("--lr_final", default=1e-5, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)
    parser.add_argument("--epoch_steps", default=1000, type=int)
    parser.add_argument("--epoch_bias", default=0, type=int)
    parser.add_argument("--epoch_save", default=5, type=int)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.99, type=float)
    parser.add_argument("--adam_eps", default=1e-8, type=float)
    parser.add_argument("--ctx_len", default=1024, type=int)
    parser.add_argument("--micro_bsz", default=12, type=int)
    parser.add_argument("--data_workers", default=1, type=int)
    parser.add_argument("--grad_cp", default=0, type=int)
    parser.add_argument("--load_model", default="", type=str)
    parser.add_argument("--data_file", default="", type=str)
    parser.add_argument("--data_type", default="utf-8", type=str)
    parser.add_argument("--vocab_size", default=0, type=int)
    args = parser.parse_args()

    args.enable_checkpointing = False
    args.logger = False
    args.gradient_clip_val = 1.0
    args.num_sanity_val_steps = 0
    args.betas = (args.beta1, args.beta2)
    args.proj_dir = args.proj_dir.strip().strip("\\/")

    samples_per_epoch = args.epoch_steps * int(args.devices) * args.micro_bsz
    tokens_per_epoch = samples_per_epoch * args.ctx_len
    rank_zero_info(
        f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.devices} x {args.accelerator.upper()} {args.strategy.upper()} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_bias} to {args.epoch_bias + args.max_epochs - 1}, save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, β {args.betas}, eps {args.adam_eps}
#
############################################################################
"""
    )
    rank_zero_info(str(vars(args)) + "\n")

    if not os.path.exists(args.proj_dir):
        os.makedirs(args.proj_dir)

    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx"]
    assert len(args.data_file) > 0

    if args.lr_final == 0 or args.lr_init == 0:
        rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule.\n\n")

    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
    os.environ["RWKV_FLOAT_MODE"] = args.precision
    if args.precision == "fp32":
        rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
    if args.precision == "fp16":
        rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")

    import torch

    torch.backends.cudnn.benchmark = True
    if args.precision == "fp32":
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
    else:
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True

    # translate our precision names into what Lightning expects
    if "32" in args.precision:
        args.precision = 32
    elif args.precision == "fp16":
        args.precision = 16
    else:
        args.precision = "bf16"

    ########################################################################################################

    class train_callback(pl.Callback):
        def __init__(self, args):
            super().__init__()
            self.args = args

        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
            args = self.args
            g_step = trainer.global_step

            # logging
            if trainer.global_rank == 0:
                if g_step == 0:
                    trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
                    trainer.my_log.write(f"NEW RUN {datetime.datetime.now()}\n{vars(self.args)}\n")
                    trainer.my_log.flush()
                    if len(args.wandb) > 0:
                        print("Login to wandb...")
                        import wandb

                        model_name = str(args.vocab_size) + "-" + str(args.ctx_len) + "-" + str(args.n_layer) + "-" + str(args.n_embd)
                        wandb.init(project=args.wandb, name=model_name + "-" + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"), config=args, save_code=False)
                        trainer.my_wandb = wandb

            # LR schedule
            w_step = args.warmup_steps
            if g_step < w_step:
                lr = args.lr_init * (g_step / w_step)
            else:
                progress = (g_step - w_step) / (args.max_epochs * args.epoch_steps - w_step - 1)
                progress = min(1, max(0, progress))

                if args.lr_final == 0 or args.lr_init == 0:  # linear decay
                    lr = args.lr_init + (args.lr_final - args.lr_init) * progress
                else:  # exp decay
                    lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * progress)

            for param_group in trainer.optimizers[0].param_groups:
                param_group["lr"] = lr

            trainer.my_lr = lr
            # rank_zero_info(f"{g_step} {lr}")

        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            args = self.args
            # logging
            if trainer.global_rank == 0:
                if len(args.wandb) > 0:
                    trainer.my_wandb.log({"loss": trainer.my_loss, "lr": trainer.my_lr}, step=trainer.global_step)

        def on_train_epoch_end(self, trainer, pl_module):
            args = self.args
            if trainer.global_rank == 0:  # my_log is only opened on rank 0
                if trainer.current_epoch % args.epoch_save == 0 or trainer.current_epoch == args.max_epochs - 1:
                    torch.save(pl_module.state_dict(), f"{args.proj_dir}/rwkv-{args.epoch_bias + trainer.current_epoch}.pth")
                trainer.my_log.write(f"{args.epoch_bias + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
                trainer.my_log.flush()

    @rank_zero_only
    def generate_init_weight(model, temp_name):
        try:
            os.remove(temp_name)
        except OSError:
            pass
        mm = model.generate_init_weight()
        print(f"Saving to {temp_name}...")
        torch.save(mm, temp_name)

    ########################################################################################################

    from torch.utils.data import DataLoader
    from src.dataset import MyDataset
    from src.model import RWKV

    train_data = MyDataset(args)
    args.vocab_size = train_data.vocab_size

    model = RWKV(args)

    if len(args.load_model) == 0:
        args.load_model = f"{args.proj_dir}/rwkv-init.pth"  # init weights to tmp file
        generate_init_weight(model, args.load_model)
    else:
        args.load_model = f"{args.proj_dir}/{args.load_model}"

    print(f"\nLoading {args.load_model}...\n")
    load_dict = torch.load(args.load_model, map_location="cpu")
    model.load_state_dict(load_dict)

    trainer = Trainer.from_argparse_args(
        args,
        callbacks=[train_callback(args)],
    )

    train_loader = DataLoader(train_data, batch_size=args.micro_bsz, num_workers=args.data_workers)
    trainer.fit(model, train_loader)
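
A hedged example invocation (every flag is either defined by the parser above or added by Trainer.add_argparse_args; the data path is a placeholder):

python train.py --data_file data/corpus.txt --data_type utf-8 --proj_dir out \
    --n_layer 6 --n_embd 512 --ctx_len 1024 --micro_bsz 12 \
    --accelerator gpu --devices 1 --precision bf16 --strategy ddp --max_epochs 20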