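"""A "battle"-style preference annotator built with dataset-viber.

Each round shows a prompt from argilla/magpie-ultra-v0.1 alongside two
completions drawn at random from three candidates: the response stored in
the dataset and two fresh generations served through huggingface_hub's
InferenceClient.

Requires an HF_TOKEN environment variable with access to the models below.
"""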
import os
import random

from dataset_viber import AnnotatorInterFace
from datasets import load_dataset
from huggingface_hub import InferenceClient

# https://huggingface.co/models?inference=warm&pipeline_tag=text-generation&sort=trending
MODEL_IDS = [
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "microsoft/Phi-3-mini-4k-instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
# One inference client per model; HF_TOKEN must be set in the environment.
CLIENTS = [
    InferenceClient(model_id, token=os.environ["HF_TOKEN"]) for model_id in MODEL_IDS
]

# Source of prompts and reference completions for the battles.
dataset = load_dataset("argilla/magpie-ultra-v0.1", split="train")

def _get_response(messages):
    """Generate a completion for `messages` with a randomly chosen model."""
    client = random.choice(CLIENTS)
    completion = client.chat_completion(
        messages=messages,
        stream=False,
        max_tokens=2000,
    )
    return completion.choices[0].message.content

def next_input(_prompt, _completion_a, _completion_b):
    """Sample a conversation and return it with two of three candidate completions."""
    row = dataset[random.randrange(len(dataset))]  # a single random row is cheaper than reshuffling
    messages = row["messages"][:-1]  # keep the conversation up to the last user turn
    # Three candidates: the response stored in the dataset plus two fresh generations.
    completions = [row["response"]]
    completions.append(_get_response(messages))
    completions.append(_get_response(messages))
    # Shuffle and return a random pair of the three candidates.
    random.shuffle(completions)
    return messages, completions.pop(), completions.pop()


if __name__ == "__main__":
    interface = AnnotatorInterFace.for_chat_generation_preference(
        fn_next_input=next_input,
        interactive=[False, True, True],  # prompt is read-only; the two completions stay editable
        dataset_name="dataset-viber-chat-generation-preference-inference-endpoints-battle",
    )
    interface.launch()
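
# Usage sketch (the filename is illustrative):
#   HF_TOKEN=hf_... python battle.py
# Then open the local URL printed on launch to start annotating preferences.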