@@ -4,8 +4,8 @@
import os, math, gc
import torch
-torch._C._jit_set_profiling_executor(True)
-torch._C._jit_set_profiling_mode(True)
+# torch._C._jit_set_profiling_executor(True)
+# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
@@ -42,7 +42,7 @@ T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
from torch.utils.cpp_extension import load
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-lineinfo", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
    class WKV(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):