From aa4e2a68f43f4e85aacabf8bd84d19ec6997b310 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Mon, 9 Aug 2021 13:52:19 +0800
Subject: [PATCH] first commit

---
 .gitignore      |   7 ++
 RWKV-vs-MHA.png | Bin 0 -> 9634 bytes
 src/__init__.py |   0
 src/model.py    | 290 ++++++++++++++++++++++++++++++++++++++++++++++++
 src/trainer.py  | 128 +++++++++++++++++++++
 src/utils.py    |  46 ++++++++
 train.py        | 117 +++++++++++++++++++
 7 files changed, 588 insertions(+)
 create mode 100644 RWKV-vs-MHA.png
 create mode 100644 src/__init__.py
 create mode 100644 src/model.py
 create mode 100644 src/trainer.py
 create mode 100644 src/utils.py
 create mode 100644 train.py

diff --git a/.gitignore b/.gitignore
index b6e4761..2de160e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,10 @@
+*.txt
+*.csv
+*.pth
+*.xlsb
+*.xlsx
+*.xls
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/RWKV-vs-MHA.png b/RWKV-vs-MHA.png
new file mode 100644
index 0000000000000000000000000000000000000000..21fad12f7fca0684213a35c8e069222475fc3a40
GIT binary patch
[literal 9634: base85-encoded PNG image data omitted]
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/model.py b/src/model.py
new file mode 100644
index 0000000..40fdbe9
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,290 @@
+import math
+import logging
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+logger = logging.getLogger(__name__)
+
+########################################################################################################
+# Block: RWKV Time-mix + RWKV Channel-mix
+########################################################################################################
+
+class RWKV_TimeMix(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+        self.ctx_size = config.ctx_size
+        self.n_head = config.n_head
+        self.head_size = config.n_embd // config.n_head
+
+        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_size))
+        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_size))
+        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_size, 1))
+        self.time_gamma = nn.Parameter(torch.ones(config.ctx_size, 1))
+        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size)))
+
+        self.time_shift = nn.ZeroPad2d((0,0,1,0))
+
+        self.key = nn.Linear(config.n_embd, config.n_embd)
+        self.value = nn.Linear(config.n_embd, config.n_embd)
+        self.receptance = nn.Linear(config.n_embd, config.n_embd)
+
+        self.output = nn.Linear(config.n_embd, config.n_embd)
+
+    def forward(self, x):
+        B, T, C = x.size()
+        TT = self.ctx_size
+        w = F.pad(self.time_w, (0, TT))
+        w = torch.tile(w, [TT])
+        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
+        w = w[:, :, TT-1:] # w is now a circulant matrix
+        w = 
w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :] + w = w.masked_fill(self.mask[:T, :T] == 0, 0) + + x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) + k = self.key(x) + v = self.value(x) + r = self.receptance(x) + + k = torch.exp(k) + sum_k = torch.cumsum(k, dim=1) + + k = k.view(B, T, self.n_head, self.head_size) + v = v.view(B, T, self.n_head, self.head_size) + + wkv = (torch.einsum('htu,buhc->bthc', w, k * v)).contiguous().view(B, T, C) + y = torch.sigmoid(r) * wkv / sum_k + + y = self.output(y) * self.time_gamma[:T, :] + return y + +class RWKV_ChannelMix(nn.Module): + def __init__(self, config): + super().__init__() + self.time_shift = nn.ZeroPad2d((0,0,1,0)) + + self.key = nn.Linear(config.n_embd, 3 * config.n_embd) + self.value = nn.Linear(config.n_embd, 3 * config.n_embd) + self.weight = nn.Linear(3 * config.n_embd, config.n_embd) + self.receptance = nn.Linear(config.n_embd, config.n_embd) + + def forward(self, x): + B, T, C = x.size() + + x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) + k = self.key(x) + v = self.value(x) + r = self.receptance(x) + + wkv = self.weight(F.gelu(k) * v) + y = torch.sigmoid(r) * wkv + + return y + +######################################################################################################## +# Block: Multi-head Attention + Rotary Encoding + GeGLU FFN +######################################################################################################## + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + + def forward(self, x, seq_len=None): + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=x.device) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() + return self.cos_cached, self.sin_cached + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), -1) + +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin): + cos, sin = cos[...,:q.shape[2],:], sin[...,:q.shape[2],:] + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + +class RotaryMHA(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + self.n_head = config.n_head + self.ctx_size = config.ctx_size + self.head_size = config.n_embd // config.n_head + + self.query = nn.Linear(config.n_embd, config.n_embd) + self.key = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + + self.register_buffer("mask", torch.tril(torch.ones(config.ctx_size, config.ctx_size))) + + self.rotary_ndims = int(self.head_size * 0.5) + self.rotary_emb = RotaryEmbedding(self.rotary_ndims) + + self.output = nn.Linear(config.n_embd, config.n_embd) + + def forward(self, x): + B, T, C = x.size() + + q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) + k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) + + q, query_pass = q[..., 
:self.rotary_ndims], q[..., self.rotary_ndims:] + k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:] + cos, sin = self.rotary_emb(q, seq_len=T) + q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding + q = torch.cat((q, query_pass), dim=-1) + k = torch.cat((k, key_pass), dim=-1) + + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T) + att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask + att = F.softmax(att, dim = -1) # softmax + + x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs) + x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C) + + x = self.output(x) # output projection + return x + +class GeGLU(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.key = nn.Linear(config.n_embd, 3 * config.n_embd) + self.value = nn.Linear(config.n_embd, 3 * config.n_embd) + self.weight = nn.Linear(3 * config.n_embd, config.n_embd) + + def forward(self, x): + k = self.key(x) + v = self.value(x) + y = self.weight(F.gelu(k) * v) + return y + +######################################################################################################## +# The GPT Model with our blocks +######################################################################################################## + +class LabelSmoothingCrossEntropy(nn.Module): # might be able to avoid nan loss + def __init__(self, smoothing=0.0): + super().__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, pred, target): + pred = pred.log_softmax(dim=-1) + with torch.no_grad(): + true_dist = torch.zeros_like(pred) + true_dist.fill_(self.smoothing / (pred.size(-1) - 1)) + true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) + return torch.mean(torch.sum(-true_dist * pred, dim=-1)) + +class GPTConfig: + def __init__(self, vocab_size, ctx_size, **kwargs): + self.vocab_size = vocab_size + self.ctx_size = ctx_size + for k,v in kwargs.items(): + setattr(self, k, v) + +class Block(nn.Module): + def __init__(self, config): + super().__init__() + + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + + if config.model_type == 'RWKV': + self.attn = RWKV_TimeMix(config) + self.mlp = RWKV_ChannelMix(config) + else: + self.attn = RotaryMHA(config) + self.mlp = GeGLU(config) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class GPT(nn.Module): + def __init__(self, config): + super().__init__() + + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + self.ctx_size = config.ctx_size + self.apply(self._init_weights) + + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + def get_ctx_size(self): + return self.ctx_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.01) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def configure_optimizers(self, train_config): + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + + whitelist_weight_modules = (nn.Linear, ) + 
blacklist_weight_modules = (nn.LayerNorm, nn.Embedding) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if pn.endswith('bias') or ('time' in fpn) or ('head' in fpn): + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) + return optimizer + + def forward(self, idx, targets=None): + B, T = idx.size() + assert T <= self.ctx_size, "Cannot forward, model block size is exhausted." + + x = self.tok_emb(idx) + + x = self.blocks(x) + + x = self.ln_f(x) + logits = self.head(x) + + loss = None + if targets is not None: + loss = LabelSmoothingCrossEntropy(smoothing=1e-6)(logits.view(-1, logits.size(-1)), targets.view(-1)) + + return logits, loss diff --git a/src/trainer.py b/src/trainer.py new file mode 100644 index 0000000..e9618f4 --- /dev/null +++ b/src/trainer.py @@ -0,0 +1,128 @@ +import math +import logging +import numpy as np +from tqdm.auto import tqdm +import torch +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data.dataloader import DataLoader +logger = logging.getLogger(__name__) + +class TrainerConfig: + max_epochs = 10 + batch_size = 64 + learning_rate = 3e-4 + betas = (0.9, 0.95) + grad_norm_clip = 1.0 + weight_decay = 0.01 + lr_decay = False # learning rate decay params: linear warmup followed by cosine decay + warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere + final_tokens = 260e9 # (at what point we reach 10% of original LR) + ckpt_path = None + num_workers = 0 # for DataLoader + + def __init__(self, **kwargs): + for k,v in kwargs.items(): + setattr(self, k, v) + +class Trainer: + + def __init__(self, model, train_dataset, test_dataset, config): + self.model = model + self.train_dataset = train_dataset + self.test_dataset = test_dataset + self.config = config + self.avg_loss = -1 + + # take over whatever gpus are on the system + self.device = 'cpu' + if torch.cuda.is_available(): + self.device = torch.cuda.current_device() + self.model = torch.nn.DataParallel(self.model).to(self.device) + + def save_checkpoint(self): + # DataParallel wrappers keep raw model object in .module attribute + raw_model = self.model.module if hasattr(self.model, "module") else self.model + logger.info("saving %s", self.config.ckpt_path) + torch.save(raw_model.state_dict(), self.config.ckpt_path) + + def train(self): + model, config = self.model, self.config + raw_model = model.module if hasattr(self.model, "module") else model + optimizer = raw_model.configure_optimizers(config) + + def run_epoch(split): + 
is_train = split == 'train' + model.train(is_train) + data = self.train_dataset if is_train else self.test_dataset + loader = DataLoader(data, shuffle=True, pin_memory=True, + batch_size=config.batch_size, + num_workers=config.num_workers) + + losses = [] + pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader) + for it, (x, y) in pbar: + + # place data on the correct device + x = x.to(self.device) + y = y.to(self.device) + + # forward the model + with torch.set_grad_enabled(is_train): + logits, loss = model(x, y) + loss = loss.mean() # collapse all losses if they are scattered on multiple gpus + losses.append(loss.item()) + + if is_train: + + # backprop and update the parameters + model.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) + optimizer.step() + + # decay the learning rate based on our progress + if config.lr_decay: + self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100) + if self.tokens < config.warmup_tokens: + # linear warmup + lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens)) + progress = 0 + else: + # cosine learning rate decay + progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) + lr_final_factor = config.lr_final / config.learning_rate + lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1 + lr = config.learning_rate * lr_mult + for param_group in optimizer.param_groups: + param_group['lr'] = lr + else: + lr = config.learning_rate + + # report progress + now_loss = loss.item() + if self.avg_loss < 0: + self.avg_loss = now_loss + else: + factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1)) + self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor + pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}") + + if not is_train: + test_loss = float(np.mean(losses)) + logger.info("test loss: %f", test_loss) + return test_loss + + best_loss = float('inf') + self.tokens = 0 # counter used for learning rate decay + for epoch in range(config.max_epochs): + + run_epoch('train') + if self.test_dataset is not None: + test_loss = run_epoch('test') + + # supports early stopping based on the test loss, or just save always if no test set is provided + good_model = self.test_dataset is None or test_loss < best_loss + if self.config.ckpt_path is not None and good_model: + best_loss = test_loss + self.save_checkpoint() diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..6192589 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,46 @@ +import random +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +def top_k_logits(logits, k): + v, ix = torch.topk(logits, k) + out = logits.clone() + out[out < v[:, [-1]]] = -float('Inf') + return out + +def top_p_probs(probs, p): + out = probs.clone() + + sorted_probs, sorted_indices = torch.sort(out, descending=True) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + sorted_indices_to_remove = cumulative_probs > p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices[sorted_indices_to_remove] + out[indices_to_remove] = 0 + + return out + +# top-p + top-k + pow&ratio sampling 
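+# temperature rescales the logits before softmax; top_k keeps only the k most
+# likely tokens; top_p keeps the smallest token set whose cumulative probability
+# exceeds p; min_p_pow / min_p_ratio mask out (set to -Inf) any token whose
+# probability is below max_prob ** min_p_pow * min_p_ratio.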
+def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None):
+    logits = logits[:, pos, :] / temperature
+    probs = F.softmax(logits, dim=-1)
+    if min_p_ratio is not None:
+        limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
+        logits[probs < limit] = -float('Inf')
+    if top_k is not None:
+        logits = top_k_logits(logits, top_k)
+    probs = F.softmax(logits, dim=-1)
+    if top_p is not None:
+        probs[0] = top_p_probs(probs[0], top_p)
+    ix = torch.multinomial(probs, num_samples=1)
+
+    return ix[0][0].cpu()
+
+def set_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..82e5802
--- /dev/null
+++ b/train.py
@@ -0,0 +1,117 @@
+import os, sys, time, math, random, json, datetime
+import logging
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.utils.data import Dataset
+from src.trainer import Trainer, TrainerConfig
+from src.model import GPT, GPTConfig
+from src.utils import set_seed
+
+set_seed(42)
+np.set_printoptions(precision=4, suppress=True, linewidth=200)
+logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
+
+model_type = 'RWKV' # 'RWKV' or 'RotaryMHA'
+
+datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
+model_level = 'character' # 'character' or 'word'
+
+ctx_size = 256 if model_level == 'character' else 128
+nLayers = 5
+nHead = 8
+nEmb = 512
+
+nepoch = 50
+nbatchsz = 64
+epoch_length_fixed = 10000 # make an epoch very short, so we can see the training progress
+
+########################################################################################################
+
+print("loading data...", end="")
+
+class Dataset(Dataset):
+    def __init__(self, data, model_level, ctx_size):
+        if model_level == 'word':
+            data = data.replace('\n', ' \n ').replace('  ', ' ').split(' ')
+
+        unique = sorted(list(set(data)))
+        data_size, vocab_size = len(data), len(unique)
+        self.stoi = { ch:i for i,ch in enumerate(unique) }
+        self.itos = { i:ch for i,ch in enumerate(unique) }
+        print('data has %d %ss, %d unique.' 
% (data_size, model_level, vocab_size)) + self.ctx_size = ctx_size + self.vocab_size = vocab_size + self.data = data + + def __len__(self): + return epoch_length_fixed + + def __getitem__(self, idx): + i = np.random.randint(0, len(self.data) - (self.ctx_size + 1)) # CHEAT: pick a spot in the dataset at random + chunk = self.data[i:i+self.ctx_size+1] + dix = [self.stoi[s] for s in chunk] + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) + return x, y + +train_dataset = Dataset(open(datafile, "r", encoding="utf-8").read(), model_level, ctx_size) + +######################################################################################################## + +model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type=model_type, + n_layer=nLayers, n_head=nHead, n_embd=nEmb)) + +print('model', model_type, 'total epoch', nepoch, 'batchsz', nbatchsz, 'nLayers', nLayers, 'nHead', nHead, 'nEmb', nEmb, 'len', ctx_size) +tconf = TrainerConfig(model_type=model_type, max_epochs=nepoch, batch_size=nbatchsz, + learning_rate=6e-4 if model_type == 'RWKV' else 4e-4, betas=(0.9, 0.99), # RWKV can use higher LR + lr_decay=True, lr_final=2e-4, warmup_tokens=0, final_tokens=nepoch*len(train_dataset)*ctx_size, num_workers=0) +trainer = Trainer(model, train_dataset, None, tconf) + +trainer.train() + +torch.save(model, 'trained-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth') + +######################################################################################################## + +from src.utils import sample_logits + +MAX_LEN = ctx_size +NUM_OF_RUNS = 5 +LENGTH_OF_EACH = 300 + +for run in range(NUM_OF_RUNS): + context = "It was" + + x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64) + + real_len = len(x) + if real_len < MAX_LEN: + x = np.pad(x, (0, MAX_LEN - real_len)) + print_begin = 0 + + for i in range(LENGTH_OF_EACH): + + if i == 0: + print(('-' * 80) + '\n' + context, end = '') + print_begin = real_len + + with torch.no_grad(): + xxx = torch.tensor(x[-MAX_LEN:], dtype=torch.long)[None,...].to("cuda:0") + out, _ = model(xxx) + pos = -1 if real_len >= MAX_LEN else real_len - 1 + + char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) + + if real_len < MAX_LEN: + x[real_len] = char + else: + x = np.append(x, char) + real_len += 1 + + if i % 10 == 9 or i == LENGTH_OF_EACH-1: + completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]]) + print(completion, end = '') + print_begin = real_len + print()
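
Note on the RWKV_TimeMix time-weight construction above: the F.pad / torch.tile / reshape sequence is the least obvious piece of this commit. The short sketch below is not part of the patch; it is a toy check (head count H and context length TT are arbitrary) showing that the sequence builds a causal Toeplitz matrix with w[h, t, u] = time_w[h, TT-1-(t-u)] for u <= t and 0 for u > t, i.e. each earlier position is weighted purely by its distance from the current token.

import torch
from torch.nn import functional as F

H, TT = 2, 5                                             # arbitrary toy sizes
time_w = torch.arange(1, TT + 1).float().repeat(H, 1)    # stand-in for the learned nn.Parameter

w = F.pad(time_w, (0, TT))                   # (H, 2*TT): right half is zero padding
w = torch.tile(w, [TT])                      # repeat TT times along the last dim
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)   # (H, TT, 2*TT-1)
w = w[:, :, TT - 1:]                         # (H, TT, TT) causal Toeplitz matrix

for t in range(TT):
    for u in range(TT):
        expected = time_w[0, TT - 1 - (t - u)].item() if u <= t else 0.0
        assert w[0, t, u].item() == expected
print(w[0])  # lower triangular; each diagonal holds one entry of time_w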