BlinkDL 3 years ago
parent 6ed3a3db09
commit 7476c69f32

@ -14,6 +14,10 @@ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
# Make sure the RWKV_MY_TESTING environment variable is always defined after
# this point: report its value when the caller supplied one, otherwise
# default it to the empty string.
# An explicit membership check replaces the former bare `except:`, which
# would also have hidden unrelated failures (e.g. KeyboardInterrupt or an
# error raised while printing) behind the "variable is missing" fallback.
if "RWKV_MY_TESTING" in os.environ:
    print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
else:
    os.environ["RWKV_MY_TESTING"] = ''
def __nop(ob): def __nop(ob):
return ob return ob
@ -346,6 +350,14 @@ class RWKV(pl.LightningModule):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
self.args = args self.args = args
if not hasattr(args, 'dim_att'):
args.dim_att = args.n_embd
if not hasattr(args, 'dim_ffn'):
args.dim_ffn = args.n_embd * 4
if not hasattr(args, 'tiny_att_layer'):
args.tiny_att_layer = -1
if not hasattr(args, 'tiny_att_dim'):
args.tiny_att_dim = -1
self.emb = nn.Embedding(args.vocab_size, args.n_embd) self.emb = nn.Embedding(args.vocab_size, args.n_embd)

Loading…
Cancel
Save