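"""Preference-annotation "battle" for chat generations.

Serves a prompt from argilla/magpie-ultra-v0.1 together with its original
response and fresh completions from randomly chosen Inference API models,
and lets the annotator pick the preferred completion via dataset_viber.

Requires an HF_TOKEN environment variable with Inference API access.
"""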
import os
import random

from dataset_viber import AnnotatorInterFace
from datasets import load_dataset
from huggingface_hub import InferenceClient

# Warm text-generation models on the serverless Inference API:
# https://huggingface.co/models?inference=warm&pipeline_tag=text-generation&sort=trending
MODEL_IDS = [
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "microsoft/Phi-3-mini-4k-instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
CLIENTS = [InferenceClient(model_id, token=os.environ["HF_TOKEN"]) for model_id in MODEL_IDS]

dataset = load_dataset("argilla/magpie-ultra-v0.1", split="train")


def _get_response(messages):
    # Route each request to a randomly chosen model so the battle
    # compares completions from different backends.
    client = random.choice(CLIENTS)
    message = client.chat_completion(
        messages=messages,
        stream=False,
        max_tokens=2000,
    )
    return message.choices[0].message.content


def next_input(_prompt, _completion_a, _completion_b):
    # Sample a random conversation and drop its final assistant turn,
    # leaving the context the models should respond to.
    new_dataset = dataset.shuffle()
    row = new_dataset[0]
    messages = row["messages"][:-1]
    # Pool the dataset's original response with two fresh generations,
    # then shuffle and present two of the three as anonymous candidates.
    completions = [row["response"]]
    completions.append(_get_response(messages))
    completions.append(_get_response(messages))
    random.shuffle(completions)
    return messages, completions.pop(), completions.pop()


if __name__ == "__main__":
    interface = AnnotatorInterFace.for_chat_generation_preference(
        fn_next_input=next_input,
        interactive=[False, True, True],  # prompt is fixed; both completions stay editable
        dataset_name="dataset-viber-chat-generation-preference-inference-endpoints-battle",
    )
    interface.launch()