no message

main
BlinkDL 3 years ago
parent 8d4fed7128
commit 61b7c429df

@ -18,6 +18,8 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
### Step 1: set model ##################################################################################
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' or 'fp16'
os.environ['RWKV_RUN_DEVICE'] = 'cpu' # 'cpu' (already very fast) or 'cuda'
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
ctx_len = 1024
n_layer = 6
@ -45,7 +47,6 @@ else:
### Step 3: other config ###############################################################################
RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # True False - show softmax output
NUM_TRIALS = 999

@ -18,14 +18,15 @@ DEBUG_TIME = False # True False - show trained time-coeffs
# CUDA Kernel
########################################################################################################
T_MAX = 4096 # increase this if your ctx_len is long
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
T_MAX = 4096 # increase this if your ctx_len is long
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
@ -59,7 +60,7 @@ if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
@ -94,7 +95,7 @@ elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
gu = torch.sum(gu, dim=0)
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v):
def RUN_CUDA(B, T, C, w, u, k, v):
return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
############################################################################################################

@ -9,14 +9,14 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
RUN_DEVICE = 'cuda'
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
os.environ['RWKV_RUN_DEVICE'] = 'cuda'
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
ctx_len = 1024
n_layer = 6
n_embd = 512

Loading…
Cancel
Save