|
import os |
|
import random |
|
|
|
from dataset_viber import AnnotatorInterFace |
|
from datasets import load_dataset |
|
from huggingface_hub import InferenceClient |
|
import time |
|
|
|
|
|
# Hugging Face model repo IDs whose serverless inference endpoints compete
# in the preference battle; one InferenceClient is created per model.
MODEL_IDS = [

    "microsoft/Phi-3-mini-4k-instruct"

]

# NOTE: requires the HF_TOKEN environment variable — raises KeyError at import time if unset.
CLIENTS = [InferenceClient(model_id, token=os.environ["HF_TOKEN"]) for model_id in MODEL_IDS]



# DPO preference dataset: each row's "chosen" field is a chat transcript whose
# last turn serves as the reference completion (see next_input below).
dataset = load_dataset("argilla/distilabel-capybara-dpo-7k-binarized", split="train")
|
|
|
|
|
def get_response(messages):
    """Generate a chat completion for *messages* via a randomly chosen client.

    Retries transient failures up to ``max_retries`` times, sleeping
    ``retry_delay`` seconds between attempts. On the final failed attempt the
    last exception is re-raised, so this function never returns ``None``.

    Args:
        messages: Chat history in the OpenAI-style list-of-dicts format
            expected by ``InferenceClient.chat_completion``.

    Returns:
        str: The generated assistant message content.

    Raises:
        Exception: Whatever the underlying client raised, once retries are
            exhausted.
    """
    max_retries = 3
    retry_delay = 3  # seconds between attempts

    for attempt in range(max_retries):
        try:
            # Pick a client at random so the battle samples across MODEL_IDS.
            client = random.choice(CLIENTS)
            message = client.chat_completion(
                messages=messages,
                stream=False,
                max_tokens=2000,
            )
            return message.choices[0].message.content
        except Exception as e:  # broad on purpose: hub errors vary (HTTP, timeout, rate limit)
            if attempt < max_retries - 1:
                print(f"An error occurred: {e}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                print(f"Max retries reached. Last error: {e}")
                raise
    # NOTE: the original ended with an unreachable `return None` here — on the
    # last attempt the loop either returns or re-raises, so it was removed to
    # avoid implying callers must handle a None result.
|
|
|
def next_input(_prompt, _completion_a, _completion_b):
    """Produce the next battle item for the annotator interface.

    Picks a uniformly random row from ``dataset``, uses all but the last turn
    of its "chosen" transcript as the prompt, and pairs the dataset's
    reference completion with a freshly generated one (in random order so the
    annotator cannot tell which is which).

    Args:
        _prompt: Previous prompt (ignored; present to satisfy the callback signature).
        _completion_a: Previous completion A (ignored).
        _completion_b: Previous completion B (ignored).

    Returns:
        tuple: ``(messages, completion_1, completion_2)``.
    """
    # O(1) random row selection — the original shuffled the entire dataset on
    # every call just to read its first row.
    row = dataset[random.randrange(len(dataset))]
    messages = row["chosen"][:-1]
    completions = [
        row["chosen"][-1]["content"],  # reference completion from the dataset
        get_response(messages),        # freshly generated challenger
    ]
    random.shuffle(completions)  # hide which completion is the reference
    return messages, completions[0], completions[1]
|
|
|
|
|
if __name__ == "__main__":

    # Build the preference-annotation UI (dataset_viber is Gradio-based —
    # presumably this serves a local web app; confirm against dataset_viber docs).
    # interactive=[False, True, True]: the prompt is read-only while both
    # candidate completions remain editable by the annotator.
    interface = AnnotatorInterFace.for_chat_generation_preference(

        fn_next_input=next_input,

        interactive=[False, True, True],

        dataset_name="dataset-viber-chat-generation-preference-inference-endpoints-battle",

    )

    interface.launch()
|
|