no message

main
BlinkDL 3 years ago
parent 8d4fed7128
commit 61b7c429df

@ -18,6 +18,8 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
### Step 1: set model ##################################################################################
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' or 'fp16'
os.environ['RWKV_RUN_DEVICE'] = 'cpu' # 'cpu' (already very fast) or 'cuda'
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
ctx_len = 1024
n_layer = 6
@ -45,7 +47,6 @@ else:
### Step 3: other config ###############################################################################
RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
DEBUG_DEBUG = False # True False - show softmax output
NUM_TRIALS = 999

@ -18,6 +18,7 @@ DEBUG_TIME = False # True False - show trained time-coeffs
# CUDA Kernel
########################################################################################################
if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
T_MAX = 4096 # increase this if your ctx_len is long
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

@ -9,14 +9,14 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
RUN_DEVICE = 'cuda'
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
os.environ['RWKV_RUN_DEVICE'] = 'cuda'
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
ctx_len = 1024
n_layer = 6
n_embd = 512

Loading…
Cancel
Save