+ bf16 mode (more stable)

main
BlinkDL 3 years ago
parent 46eebd98ca
commit 083f9504c6

@@ -3,7 +3,7 @@
########################################################################################################
import numpy as np
import math
import math, os
import time
import types
import copy
@@ -18,6 +18,8 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200)
### Step 1: set model ##################################################################################
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
ctx_len = 1024
n_layer = 6
n_embd = 512
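For reference, a minimal sketch (not part of this commit; the helper name is hypothetical) of how the RWKV_FLOAT_MODE setting above can be read back into a torch dtype:

import os
import torch

def rwkv_dtype() -> torch.dtype:
    # 'bf16' is the more stable choice; 'fp16' can overflow when training a large model for very long.
    mode = os.environ.get('RWKV_FLOAT_MODE', 'bf16')
    return torch.float16 if mode == 'fp16' else torch.bfloat16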

@@ -26,39 +26,74 @@ from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w = -torch.exp(w.float().contiguous())
        u = u.float().contiguous()
        k = k.float().contiguous()
        v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return y.half()
    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda')
        gu = torch.zeros((B, C), device='cuda')
        gk = torch.zeros((B, T, C), device='cuda')
        gv = torch.zeros((B, T, C), device='cuda')
        wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    class WKV(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
            ctx.B = B
            ctx.T = T
            ctx.C = C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
            ctx.save_for_backward(w, u, k, v)
            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
            wkv_cuda.forward(B, T, C, w, u, k, v, y)
            return y.half()
        @staticmethod
        def backward(ctx, gy):
            B = ctx.B
            T = ctx.T
            C = ctx.C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w, u, k, v = ctx.saved_tensors
            gw = torch.zeros((B, C), device='cuda')
            gu = torch.zeros((B, C), device='cuda')
            gk = torch.zeros((B, T, C), device='cuda')
            gv = torch.zeros((B, T, C), device='cuda')
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
    class WKV(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
            ctx.B = B
            ctx.T = T
            ctx.C = C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
            ctx.save_for_backward(w, u, k, v)
            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
            wkv_cuda.forward(B, T, C, w, u, k, v, y)
            return y.bfloat16()
        @staticmethod
        def backward(ctx, gy):
            B = ctx.B
            T = ctx.T
            C = ctx.C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w, u, k, v = ctx.saved_tensors
            gw = torch.zeros((B, C), device='cuda')
            gu = torch.zeros((B, C), device='cuda')
            gk = torch.zeros((B, T, C), device='cuda')
            gv = torch.zeros((B, T, C), device='cuda')
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
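Note that the fp16 and bf16 WKV branches above are identical except for the dtype of the tensors returned from forward and backward. A sketch (not the commit's code) of driving that single difference from one variable:

import os
import torch

# Hypothetical: choose the output dtype once, then reuse it for every cast.
OUT_DTYPE = torch.float16 if os.environ['RWKV_FLOAT_MODE'] == 'fp16' else torch.bfloat16
# in forward():   return y.to(OUT_DTYPE)
# in backward():  return (None, None, None, gw.to(OUT_DTYPE), gu.to(OUT_DTYPE), gk.to(OUT_DTYPE), gv.to(OUT_DTYPE))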
@@ -336,7 +371,12 @@ class GPT(nn.Module):
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
            c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
            if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16()
            x = self.head(x) + c
        else:
            x = self.head(x)
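The branch above only exists to make the one-hot matrix match c's dtype before the matmul; a dtype-agnostic sketch (not part of the commit) would be:

# Hypothetical equivalent: cast the one-hot to whatever dtype c already has.
c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).to(c.dtype)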

@@ -5,7 +5,7 @@
import types
import copy
import torch
import math
import math, os
from torch.nn import functional as F
import torch.nn as nn
@@ -25,39 +25,74 @@ from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}'])
class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w = -torch.exp(w.float().contiguous())
        u = u.float().contiguous()
        k = k.float().contiguous()
        v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return y.half()
    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda')
        gu = torch.zeros((B, C), device='cuda')
        gk = torch.zeros((B, T, C), device='cuda')
        gv = torch.zeros((B, T, C), device='cuda')
        wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    class WKV(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
            ctx.B = B
            ctx.T = T
            ctx.C = C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
            ctx.save_for_backward(w, u, k, v)
            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
            wkv_cuda.forward(B, T, C, w, u, k, v, y)
            return y.half()
        @staticmethod
        def backward(ctx, gy):
            B = ctx.B
            T = ctx.T
            C = ctx.C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w, u, k, v = ctx.saved_tensors
            gw = torch.zeros((B, C), device='cuda')
            gu = torch.zeros((B, C), device='cuda')
            gk = torch.zeros((B, T, C), device='cuda')
            gv = torch.zeros((B, T, C), device='cuda')
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
    class WKV(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
            ctx.B = B
            ctx.T = T
            ctx.C = C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
            ctx.save_for_backward(w, u, k, v)
            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
            wkv_cuda.forward(B, T, C, w, u, k, v, y)
            return y.bfloat16()
        @staticmethod
        def backward(ctx, gy):
            B = ctx.B
            T = ctx.T
            C = ctx.C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w, u, k, v = ctx.saved_tensors
            gw = torch.zeros((B, C), device='cuda')
            gu = torch.zeros((B, C), device='cuda')
            gk = torch.zeros((B, T, C), device='cuda')
            gv = torch.zeros((B, T, C), device='cuda')
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())

@@ -6,6 +6,8 @@ import os
os.environ['USE_WANDB'] = '0' # 0 = False, 1 = True
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
### This is using DeepSpeed stage2 + FP16 ##############################################################
#
# Currently it's slow to initialize a new model. Hence I suggest this procedure for multi-GPU training:
@@ -116,8 +118,8 @@ train_dataset = Dataset(open(
if __name__ == '__main__':
    from src.trainer import Trainer, TrainerConfig
    print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
          betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
    print('\nmodel', model_type, os.environ['RWKV_FLOAT_MODE'], 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
          betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, '\n')
    tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
                          learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps,
@@ -132,9 +134,52 @@ if __name__ == '__main__':
    from pytorch_lightning.strategies import DeepSpeedStrategy
    # you can set grad_norm_clip in deepspeed.json
    trainer = Trainer(strategy=DeepSpeedStrategy(config='deepspeed.json'), devices=NUM_GPUS, accelerator="gpu", precision=16)
    DEEPSPEED_CFG = {
        "zero_allow_untested_optimizer":True,
        "zero_optimization":{
            "stage":2,
            "contiguous_gradients":True,
            "overlap_comm":True,
            "allgather_partitions":True,
            "reduce_scatter":True,
            "allgather_bucket_size":200000000,
            "reduce_bucket_size":200000000,
            "sub_group_size":1000000000000
        },
        "activation_checkpointing":{
            "partition_activations":False,
            "cpu_checkpointing":False,
            "contiguous_memory_optimization":False,
            "synchronize_checkpoint_boundary":False
        },
        "aio":{
            "block_size":1048576,
            "queue_depth":8,
            "single_submit":False,
            "overlap_events":True,
            "thread_count":1
        },
        "gradient_clipping": 1.0,
        "gradient_accumulation_steps": 1,
    }
    if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
        DEEPSPEED_CFG["fp16"] = {
            "fp16": True,
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 12,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1
        }
        trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16)
    elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
        DEEPSPEED_CFG["bf16"] = {
            "enabled": True
        }
        trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16')
    print(trainer._strategy.config)
    trainer.run(m_cfg, train_dataset, None, tconf)
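The only float-mode-dependent pieces in the trainer setup above are the DeepSpeed "fp16"/"bf16" section and the Lightning precision argument; a compact sketch of the same selection (not the commit's code; the precision variable is hypothetical):

# Hypothetical condensed form of the if/elif above.
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    DEEPSPEED_CFG["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 12,
                             "loss_scale_window": 1000, "hysteresis": 2, "min_loss_scale": 1}
    precision = 16
else:  # 'bf16'
    DEEPSPEED_CFG["bf16"] = {"enabled": True}
    precision = 'bf16'
trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=precision)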

@@ -15,6 +15,8 @@ import torch
from src.model_run import RWKV_RNN, RWKV_GPT
from src.model import GPT, GPTConfig
os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future)
ctx_len = 1024
n_layer = 6
n_embd = 512
@@ -27,7 +29,13 @@ tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ')
########################################################################################################
model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda().half()
model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda()
if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
    model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
    model_train = model_train.bfloat16()
print('loading ' + model_name)
m2 = torch.load(model_name + '.pth', map_location=RUN_DEVICE)
model_train.load_state_dict(m2)
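The same mode-dependent cast can be written in a single line, since nn.Module.to accepts a dtype; a sketch (not the commit's code):

# Hypothetical equivalent of the if/elif above.
model_train = model_train.to(torch.float16 if os.environ['RWKV_FLOAT_MODE'] == 'fp16' else torch.bfloat16)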
