os.environ['RWKV_FLOAT_MODE']='bf16'# 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
os.environ['RWKV_RUN_DEVICE']='cuda'
RUN_DEVICE=os.environ['RWKV_RUN_DEVICE']
importtorch
fromsrc.model_runimportRWKV_RNN,RWKV_GPT
fromsrc.modelimportGPT,GPTConfig
os.environ['RWKV_FLOAT_MODE']='bf16'# 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)