main
BlinkDL 3 years ago
parent 8d4fed7128
commit 61b7c429df

@@ -18,6 +18,8 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
### Step 1: set model ##################################################################################
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' or 'fp16'
os.environ['RWKV_RUN_DEVICE'] = 'cpu' # 'cpu' (already very fast) or 'cuda'
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
ctx_len = 1024
n_layer = 6
@@ -45,7 +47,6 @@ else:
### Step 3: other config ###############################################################################
RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # True False - show softmax output
NUM_TRIALS = 999
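The two hunks above move the precision and device choice into environment variables that are set once in Step 1 and read back later, so Step 3 no longer hard-codes RUN_DEVICE. A minimal sketch of the resulting usage (illustrative only, not part of the commit):

# Illustrative sketch: pick device and precision once, before importing the model code,
# and reuse the value instead of redefining it later in the script.
import os
os.environ['RWKV_FLOAT_MODE'] = 'bf16'   # 'bf16' or 'fp16'
os.environ['RWKV_RUN_DEVICE'] = 'cpu'    # 'cpu' or 'cuda'
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']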

@@ -18,84 +18,85 @@ DEBUG_TIME = False # True False - show trained time-coeffs
# CUDA Kernel
########################################################################################################
T_MAX = 4096 # increase this if your ctx_len is long
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
w = -torch.exp(w.float().contiguous())
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
ctx.save_for_backward(w, u, k, v)
y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
return y.half()
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
w, u, k, v = ctx.saved_tensors
gw = torch.zeros((B, C), device='cuda')
gu = torch.zeros((B, C), device='cuda')
gk = torch.zeros((B, T, C), device='cuda')
gv = torch.zeros((B, T, C), device='cuda')
wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
w = -torch.exp(w.float().contiguous())
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
ctx.save_for_backward(w, u, k, v)
y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
return y.bfloat16()
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 1024) == 0
w, u, k, v = ctx.saved_tensors
gw = torch.zeros((B, C), device='cuda')
gu = torch.zeros((B, C), device='cuda')
gk = torch.zeros((B, T, C), device='cuda')
gv = torch.zeros((B, T, C), device='cuda')
wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v):
return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
    T_MAX = 4096 # increase this if your ctx_len is long
    # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
    from torch.utils.cpp_extension import load
    wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                    verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
    if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
        class WKV(torch.autograd.Function):
            @staticmethod
            def forward(ctx, B, T, C, w, u, k, v):
                ctx.B = B
                ctx.T = T
                ctx.C = C
                assert T <= T_MAX
                assert B * C % min(C, 1024) == 0
                w = -torch.exp(w.float().contiguous())
                u = u.float().contiguous()
                k = k.float().contiguous()
                v = v.float().contiguous()
                ctx.save_for_backward(w, u, k, v)
                y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
                wkv_cuda.forward(B, T, C, w, u, k, v, y)
                return y.half()
            @staticmethod
            def backward(ctx, gy):
                B = ctx.B
                T = ctx.T
                C = ctx.C
                assert T <= T_MAX
                assert B * C % min(C, 1024) == 0
                w, u, k, v = ctx.saved_tensors
                gw = torch.zeros((B, C), device='cuda')
                gu = torch.zeros((B, C), device='cuda')
                gk = torch.zeros((B, T, C), device='cuda')
                gv = torch.zeros((B, T, C), device='cuda')
                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
                gw = torch.sum(gw, dim=0)
                gu = torch.sum(gu, dim=0)
                return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
    elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
        class WKV(torch.autograd.Function):
            @staticmethod
            def forward(ctx, B, T, C, w, u, k, v):
                ctx.B = B
                ctx.T = T
                ctx.C = C
                assert T <= T_MAX
                assert B * C % min(C, 1024) == 0
                w = -torch.exp(w.float().contiguous())
                u = u.float().contiguous()
                k = k.float().contiguous()
                v = v.float().contiguous()
                ctx.save_for_backward(w, u, k, v)
                y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
                wkv_cuda.forward(B, T, C, w, u, k, v, y)
                return y.bfloat16()
            @staticmethod
            def backward(ctx, gy):
                B = ctx.B
                T = ctx.T
                C = ctx.C
                assert T <= T_MAX
                assert B * C % min(C, 1024) == 0
                w, u, k, v = ctx.saved_tensors
                gw = torch.zeros((B, C), device='cuda')
                gu = torch.zeros((B, C), device='cuda')
                gk = torch.zeros((B, T, C), device='cuda')
                gv = torch.zeros((B, T, C), device='cuda')
                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
                gw = torch.sum(gw, dim=0)
                gu = torch.sum(gu, dim=0)
                return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
    def RUN_CUDA(B, T, C, w, u, k, v):
        return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
############################################################################################################
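For reference, below is a rough pure-PyTorch sketch of the recurrence that the wkv kernel evaluates, stated as an assumption from the RWKV time-mix formula rather than taken from the repository; the CUDA kernel also rescales the exponentials for numerical stability, which this naive version omits. It makes explicit the running state (aa, bb) that the comment about slicing the ctx refers to: carrying that state from one T_MAX-sized slice to the next is what would let a longer context be processed in chunks.

import torch

def wkv_reference(w, u, k, v):
    # Naive per-timestep reference of the WKV mix (illustrative assumption, single sequence).
    # w, u: (C,) with w already negative, i.e. -exp(time_decay); k, v: (T, C), float32.
    T, C = k.shape
    y = torch.empty(T, C, dtype=torch.float32)
    aa = torch.zeros(C)   # running sum of exp(k_i) * v_i, decayed by exp(w) each step
    bb = torch.zeros(C)   # running sum of exp(k_i), decayed the same way
    for t in range(T):
        e = torch.exp(u + k[t])           # the current token gets the extra bonus u
        y[t] = (aa + e * v[t]) / (bb + e)
        decay = torch.exp(w)              # elementwise, between 0 and 1 since w < 0
        aa = decay * aa + torch.exp(k[t]) * v[t]
        bb = decay * bb + torch.exp(k[t])
    return y

Because the update of (aa, bb) is purely sequential, evaluating the sequence in slices and passing (aa, bb) between slices gives the same result as one long pass, up to floating-point error.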

@@ -9,14 +9,14 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
RUN_DEVICE = 'cuda'
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
os.environ['RWKV_RUN_DEVICE'] = 'cuda'
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
ctx_len = 1024
n_layer = 6
n_embd = 512
