# app.py — last commit "Update app.py" (ae0e08b, verified) by davidberenstein1957; 1.44 kB
import os
import random
from dataset_viber import AnnotatorInterFace
from datasets import load_dataset
from huggingface_hub import InferenceClient
# Chat models that answer prompts; the list was taken from
# https://huggingface.co/models?inference=warm&pipeline_tag=text-generation&sort=trending
MODEL_IDS = [
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "microsoft/Phi-3-mini-4k-instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
]

# One inference client per model, all authenticated with the HF_TOKEN
# environment variable (a missing variable raises KeyError at import time).
_HF_TOKEN = os.environ["HF_TOKEN"]
CLIENTS = [InferenceClient(name, token=_HF_TOKEN) for name in MODEL_IDS]

# Train split of the dataset whose rows provide "messages" and "response".
dataset = load_dataset("argilla/magpie-ultra-v0.1", split="train")
def _get_response(messages):
    """Generate one chat completion for *messages* with a randomly picked client.

    Args:
        messages: Chat messages, passed unchanged to ``chat_completion``.

    Returns:
        The ``content`` of the message in the first choice of the response.
    """
    chosen_client = random.choice(CLIENTS)
    # Ask for the whole reply at once (no streaming), capped at 2000 tokens.
    request = {"messages": messages, "stream": False, "max_tokens": 2000}
    reply = chosen_client.chat_completion(**request)
    first_choice = reply.choices[0]
    return first_choice.message.content
def next_input(_prompt, _completion_a, _completion_b):
    """Produce the next prompt and two completions to compare.

    A random dataset row supplies the conversation. Three candidate
    completions are gathered (the row's stored ``response`` plus two freshly
    generated ones), shuffled, and two of them are returned; the remaining
    one is discarded.

    Args:
        _prompt: Unused.
        _completion_a: Unused.
        _completion_b: Unused.

    Returns:
        A ``(messages, completion, completion)`` tuple, where ``messages`` is
        the row's ``messages`` list without its last entry.
    """
    # Read one uniformly random row by index. Shuffling the whole dataset
    # just to take its first row rebuilds a full permutation on every call.
    row = dataset[random.randrange(len(dataset))]
    # Drop the final message so the models answer the same prompt.
    messages = row["messages"][:-1]
    completions = [
        row["response"],
        _get_response(messages),
        _get_response(messages),
    ]
    random.shuffle(completions)
    return messages, completions.pop(), completions.pop()
if __name__ == "__main__":
    # Settings for the chat-generation preference interface: next_input
    # supplies each new record, and the interactive flags are
    # [False, True, True].
    interface_options = {
        "fn_next_input": next_input,
        "interactive": [False, True, True],
        "dataset_name": "dataset-viber-chat-generation-preference-inference-endpoints-battle",
    }
    AnnotatorInterFace.for_chat_generation_preference(**interface_options).launch()