+ fp32 mode (slow but good for verification)

main
BlinkDL 3 years ago
parent 94f618c52a
commit 73b96705d7

@@ -34,25 +34,25 @@ class WKV(torch.autograd.Function):
         ctx.C = C
         assert T <= T_MAX
         assert B * C % min(C, 1024) == 0
-        if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
-            w = -torch.exp(w.float().contiguous())
-            u = u.float().contiguous()
-            k = k.float().contiguous()
-            v = v.float().contiguous()
-        else:
+        if '32' in os.environ['RWKV_FLOAT_MODE']:
             w = -torch.exp(w.contiguous())
             u = u.contiguous()
             k = k.contiguous()
             v = v.contiguous()
+        else:
+            w = -torch.exp(w.float().contiguous())
+            u = u.float().contiguous()
+            k = k.float().contiguous()
+            v = v.float().contiguous()
         ctx.save_for_backward(w, u, k, v)
         y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
         wkv_cuda.forward(B, T, C, w, u, k, v, y)
-        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+        if '32' in os.environ['RWKV_FLOAT_MODE']:
+            return y
+        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
             return y.half()
         elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
             return y.bfloat16()
-        elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-            return y
     @staticmethod
     def backward(ctx, gy):
@@ -66,18 +66,18 @@ class WKV(torch.autograd.Function):
         gu = torch.zeros((B, C), device='cuda').contiguous()
         gk = torch.zeros((B, T, C), device='cuda').contiguous()
         gv = torch.zeros((B, T, C), device='cuda').contiguous()
-        if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
-            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
-        else:
+        if '32' in os.environ['RWKV_FLOAT_MODE']:
             wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
+        else:
+            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
         gw = torch.sum(gw, dim=0)
         gu = torch.sum(gu, dim=0)
-        if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+        if '32' in os.environ['RWKV_FLOAT_MODE']:
+            return (None, None, None, gw, gu, gk, gv)
+        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
             return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
         elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
             return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
-        elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-            return (None, None, None, gw, gu, gk, gv)
 def RUN_CUDA(B, T, C, w, u, k, v):
     return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
@@ -356,12 +356,12 @@ class GPT(nn.Module):
             c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
-            if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
+                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size)
+            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                 c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
             elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                 c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16()
-            elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size)
             x = self.head(x) + c
         else:

@@ -34,25 +34,25 @@ if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
             ctx.C = C
             assert T <= T_MAX
             assert B * C % min(C, 1024) == 0
-            if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
-                w = -torch.exp(w.float().contiguous())
-                u = u.float().contiguous()
-                k = k.float().contiguous()
-                v = v.float().contiguous()
-            else:
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
                 w = -torch.exp(w.contiguous())
                 u = u.contiguous()
                 k = k.contiguous()
                 v = v.contiguous()
+            else:
+                w = -torch.exp(w.float().contiguous())
+                u = u.float().contiguous()
+                k = k.float().contiguous()
+                v = v.float().contiguous()
             ctx.save_for_backward(w, u, k, v)
             y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
             wkv_cuda.forward(B, T, C, w, u, k, v, y)
-            if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
+                return y
+            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                 return y.half()
             elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                 return y.bfloat16()
-            elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-                return y
         @staticmethod
         def backward(ctx, gy):
@@ -62,22 +62,22 @@ if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
             assert T <= T_MAX
             assert B * C % min(C, 1024) == 0
             w, u, k, v = ctx.saved_tensors
-            gw = torch.zeros((B, C), device='cuda')
-            gu = torch.zeros((B, C), device='cuda')
-            gk = torch.zeros((B, T, C), device='cuda')
-            gv = torch.zeros((B, T, C), device='cuda')
-            if os.environ['RWKV_FLOAT_MODE'] != 'fp32':
-                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
-            else:
+            gw = torch.zeros((B, C), device='cuda').contiguous()
+            gu = torch.zeros((B, C), device='cuda').contiguous()
+            gk = torch.zeros((B, T, C), device='cuda').contiguous()
+            gv = torch.zeros((B, T, C), device='cuda').contiguous()
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
                 wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
+            else:
+                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
             gw = torch.sum(gw, dim=0)
             gu = torch.sum(gu, dim=0)
-            if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
+                return (None, None, None, gw, gu, gk, gv)
+            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                 return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
             elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                 return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
-            elif os.environ['RWKV_FLOAT_MODE'] == 'fp32':
-                return (None, None, None, gw, gu, gk, gv)
     def RUN_CUDA(B, T, C, w, u, k, v):
         return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
@@ -222,7 +222,13 @@ class RWKV_GPT(nn.Module):
             c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
-            c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float()
+            if '32' in os.environ['RWKV_FLOAT_MODE']:
+                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size)
+            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
+                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half()
+            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
+                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16()
             x = self.head(x) + c
         else:
             x = self.head(x)

@@ -16,7 +16,12 @@ from pytorch_lightning.lite import LightningLite
 import gc
 logger = logging.getLogger(__name__)
 torch.backends.cudnn.benchmark = True
+if os.environ['RWKV_FLOAT_MODE'] == 'fp32':
+    torch.backends.cudnn.allow_tf32 = False
+    torch.backends.cuda.matmul.allow_tf32 = False
+else:
+    torch.backends.cudnn.allow_tf32 = True
+    torch.backends.cuda.matmul.allow_tf32 = True

@@ -12,9 +12,6 @@ from src.binidx import MMapIndexedDataset
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                     datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
-torch.backends.cudnn.benchmark = True
-torch.backends.cudnn.allow_tf32 = True
-torch.backends.cuda.matmul.allow_tf32 = True
 # if False: # True False ---> Set to False if you don't understand it
 #     print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n")
@@ -61,7 +58,12 @@ if EXPRESS_PILE_MODE:
 #
 os.environ['RWKV_NUM_GPUS'] = '1' # num of GPUs to use
-os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future) or 'fp32'
+#
+# 'bf16' (fast & stable)
+# 'fp16' (fast & will overflow after training a large model for very long. can be solved in the future)
+# 'tf32' (decent speed & stable)
+# 'fp32' (!!!very slow!!! only for verification)
+os.environ['RWKV_FLOAT_MODE'] = 'bf16'
 os.environ['RWKV_DEEPSPEED'] = '1' # Use DeepSpeed? 0 = False, 1 = True
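
A hedged usage sketch of the float-mode choice documented in the comments above (this wrapper is illustrative, not part of the repository): RWKV_FLOAT_MODE must be assigned before importing the modules that read it at module scope, and only the exact string 'fp32' turns TensorFloat-32 off.

import os
import torch

# Hypothetical launcher snippet, not repository code.
os.environ['RWKV_FLOAT_MODE'] = 'tf32'   # one of 'bf16', 'fp16', 'tf32', 'fp32'
os.environ['RWKV_DEEPSPEED'] = '1'

# The trainer module shown earlier (the one importing LightningLite) reads
# RWKV_FLOAT_MODE at import time to set these backend flags, so it must be
# imported only after the assignment above. Only the exact value 'fp32'
# disables TF32 (bit-accurate but very slow, for verification only).
if os.environ['RWKV_FLOAT_MODE'] == 'fp32':
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False
else:
    # 'tf32', 'bf16', 'fp16': keep TF32 matmuls enabled for speed on Ampere+ GPUs
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
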
@@ -159,6 +161,14 @@ if EXPRESS_PILE_MODE:
     betas = (0.9, 0.999)
     MODEL_NAME = EXPRESS_PILE_MODEL_NAME
+torch.backends.cudnn.benchmark = True
+if os.environ['RWKV_FLOAT_MODE'] == 'fp32':
+    torch.backends.cudnn.allow_tf32 = False
+    torch.backends.cuda.matmul.allow_tf32 = False
+else:
+    torch.backends.cudnn.allow_tf32 = True
+    torch.backends.cuda.matmul.allow_tf32 = True
 ########################################################################################################
 # Load data
 ########################################################################################################
