Spaces:
Runtime error
Runtime error
Commit
·
818b0c6
1
Parent(s):
cd5a657
Update app.py
Browse files
app.py
CHANGED
@@ -21,7 +21,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
21 |
ctx_limit = 3500
|
22 |
########################## text rwkv ################################################################
|
23 |
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
24 |
-
from rwkv.model import RWKV
|
25 |
title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
|
26 |
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
|
27 |
model = RWKV(model=model_path, strategy='cuda fp16')
|
@@ -60,7 +59,8 @@ def evaluate(
|
|
60 |
occurrence = {}
|
61 |
state = None
|
62 |
for i in range(int(token_count)):
|
63 |
-
|
|
|
64 |
for n in occurrence:
|
65 |
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
|
66 |
|
|
|
21 |
ctx_limit = 3500
|
22 |
########################## text rwkv ################################################################
|
23 |
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
|
|
24 |
title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
|
25 |
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
|
26 |
model = RWKV(model=model_path, strategy='cuda fp16')
|
|
|
59 |
occurrence = {}
|
60 |
state = None
|
61 |
for i in range(int(token_count)):
|
62 |
+
input_ids = pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token]
|
63 |
+
out, state = model.forward(tokens=input_ids, state=state)
|
64 |
for n in occurrence:
|
65 |
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
|
66 |
|