Commit
·
119215e
1
Parent(s):
f20ab91
Add overview of magpie battle
Browse files
- .gitignore +1 -0
- README.md +2 -2
- app.py +31 -43
- requirements.txt +1 -1
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.env
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: red
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
+
title: Dataset Viber - chat preference magpie battle
|
3 |
+
emoji: ⚔️
|
4 |
colorFrom: red
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
app.py
CHANGED
@@ -1,54 +1,42 @@
|
|
1 |
import os
|
2 |
-
import io
|
3 |
import random
|
4 |
|
5 |
-
import requests
|
6 |
-
from PIL import Image
|
7 |
from dataset_viber import AnnotatorInterFace
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
"
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
api_url = f"{DATASET_SERVER_URL}/size?dataset={DATASET_NAME}"
|
29 |
-
response = requests.get(api_url, headers=HEADERS)
|
30 |
-
num_rows = response.json()["size"]["config"]["num_rows"]
|
31 |
-
return num_rows
|
32 |
-
|
33 |
-
|
34 |
-
def generate_response(prompt):
|
35 |
-
payload = {
|
36 |
-
"inputs": prompt,
|
37 |
-
}
|
38 |
-
response = requests.post(MODEL_URL, headers=HEADERS, json=payload)
|
39 |
-
image = Image.open(io.BytesIO(response.content))
|
40 |
-
return image
|
41 |
-
|
42 |
|
43 |
def next_input(_prompt, _completion_a, _completion_b):
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
if __name__ == "__main__":
|
50 |
-
interface = AnnotatorInterFace.
|
51 |
-
|
52 |
dataset_name=None,
|
53 |
)
|
54 |
interface.launch()
|
|
|
1 |
import os
|
|
|
2 |
import random
|
3 |
|
|
|
|
|
4 |
from dataset_viber import AnnotatorInterFace
|
5 |
+
from datasets import load_dataset
|
6 |
+
from huggingface_hub import InferenceClient
|
7 |
+
|
8 |
+
# Hosted chat models that produce the candidate completions.
MODEL_IDS = [
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "microsoft/Phi-3-mini-4k-instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
]

# One inference client per model. The token is read from the environment, so
# a missing HF_AUTH_TOKEN_PERSONAL raises KeyError as soon as the module loads.
CLIENTS = [
    InferenceClient(model_id, token=os.environ["HF_AUTH_TOKEN_PERSONAL"])
    for model_id in MODEL_IDS
]

# Rows supply the conversation ("messages") and a stored "response".
dataset = load_dataset("argilla/magpie-ultra-v0.1", split="train")
|
16 |
+
|
17 |
+
def _get_response(messages):
    """Return one completion for *messages* from a randomly chosen client.

    A client is drawn from ``CLIENTS`` on every call, so consecutive calls
    may be answered by different models or by the same one.

    Args:
        messages: Chat history passed through unchanged as the ``messages``
            argument of ``chat_completion``.

    Returns:
        The ``content`` of the first choice's message in the reply.
    """
    completion = random.choice(CLIENTS).chat_completion(
        messages=messages,
        stream=False,
        max_tokens=2000,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def next_input(_prompt, _completion_a, _completion_b):
    """Produce the next conversation and two completions to compare.

    The three parameters are part of the callback signature but are not
    read here; every call draws a fresh row from ``dataset``.

    Returns:
        A ``(messages, completion_a, completion_b)`` tuple. ``messages`` is
        the row's conversation without its last entry. The two completions
        are drawn at random from three candidates: the row's stored
        ``"response"`` and two fresh ``_get_response`` results. The third
        candidate is discarded.
    """
    # Index one random row directly. Shuffling the whole dataset just to read
    # element 0 permutes every row on each call for the same result.
    row = dataset[random.randrange(len(dataset))]

    # Drop the final message so the models answer the same prompt themselves.
    messages = row["messages"][:-1]

    completions = [
        row["response"],
        _get_response(messages),
        _get_response(messages),
    ]
    # Shuffle so the position of a completion does not reveal its origin.
    random.shuffle(completions)
    return messages, completions.pop(), completions.pop()
|
35 |
+
|
36 |
|
37 |
if __name__ == "__main__":
    # Build the chat-generation preference interface with next_input as its
    # fn_next_input callback, then start it.
    AnnotatorInterFace.for_chat_generation_preference(
        fn_next_input=next_input,
        dataset_name=None,
    ).launch()
|
requirements.txt
CHANGED
@@ -1 +1 @@
|
|
1 |
-
|
|
|
1 |
+
dataset-viber==0.2.1
|