|
|
|
|
@ -8,8 +8,19 @@ So it's combining the best of RNN and transformer - **great performance, fast in
|
|
|
|
|
|
|
|
|
|
**Download RWKV-4 0.1/0.4/1.5/3/7/14B weights**: https://huggingface.co/BlinkDL
|
|
|
|
|
|
|
|
|
|
**RWKV chatbot**: https://github.com/BlinkDL/ChatRWKV
|
|
|
|
|
|
|
|
|
|
**ChatRWKV v2:** with "stream" and "split" strategies. **3G VRAM is enough to run RWKV 14B :)** https://github.com/BlinkDL/ChatRWKV/tree/main/v2
|
|
|
|
|
```
|
|
|
|
|
os.environ["RWKV_JIT_ON"] = '1'
|
|
|
|
|
from rwkv.model import RWKV # everything in /v2/rwkv folder
|
|
|
|
|
model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16')
|
|
|
|
|
|
|
|
|
|
out, state = model.forward([187, 510, 1563, 310, 247], None) # use 20B_tokenizer.json
|
|
|
|
|
print(out.detach().cpu().numpy()) # get logits
|
|
|
|
|
out, state = model.forward([187, 510], None)
|
|
|
|
|
out, state = model.forward([1563], state) # RNN has state (use deepcopy if you want to clone it)
|
|
|
|
|
out, state = model.forward([310, 247], state)
|
|
|
|
|
print(out.detach().cpu().numpy()) # same result as above
|
|
|
|
|
```
|
|
|
|
|
**HF space**: https://huggingface.co/spaces/yahma/rwkv-14b
|
|
|
|
|
|
|
|
|
|

|
|
|
|
|
|