# Residue from the hosted source view, kept as comments so the file parses:
# author: davidberenstein1957
# last commit: 53d7faf ("Update app.py", verified)
import os
import random
from dataset_viber import AnnotatorInterFace
from datasets import load_dataset
from huggingface_hub import InferenceClient
import time
# Text-generation models to query; more candidates can be found via the
# filter in this URL (warm inference, text-generation, sorted by trending):
# https://huggingface.co/models?inference=warm&pipeline_tag=text-generation&sort=trending
MODEL_IDS = [
    "microsoft/Phi-3-mini-4k-instruct",
]

# One inference client per model id. HF_TOKEN is read with os.environ[...],
# so a missing variable raises KeyError when this module is imported.
CLIENTS = [
    InferenceClient(repo_id, token=os.environ["HF_TOKEN"])
    for repo_id in MODEL_IDS
]

# Source of prompts and reference ("chosen") completions; loaded once at import.
dataset = load_dataset(
    "argilla/distilabel-capybara-dpo-7k-binarized",
    split="train",
)
def get_response(messages, *, max_retries=3, retry_delay=3, max_tokens=2000, clients=None):
    """Request one chat completion, retrying on failure.

    A client is drawn at random from ``clients`` on every attempt, so a retry
    may be served by a different model than the failed attempt.

    Args:
        messages: Chat history forwarded verbatim to ``chat_completion``
            (``next_input`` passes a list of message dicts).
        max_retries: Total number of attempts; must be at least 1.
        retry_delay: Seconds to sleep between a failed attempt and the next.
        max_tokens: Generation budget forwarded to ``chat_completion``.
        clients: Sequence of objects exposing ``chat_completion``; defaults
            to the module-level ``CLIENTS``.

    Returns:
        The text content of the first returned choice.

    Raises:
        ValueError: If ``max_retries`` < 1 or no clients are configured.
        Exception: Whatever the final attempt raised, once retries run out.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    if clients is None:
        clients = CLIENTS
    if not clients:
        # Fail fast: random.choice on an empty pool would otherwise raise
        # IndexError on every attempt and be pointlessly retried.
        raise ValueError("no inference clients configured")

    for attempt in range(1, max_retries + 1):
        client = random.choice(clients)
        try:
            completion = client.chat_completion(
                messages=messages,
                stream=False,
                max_tokens=max_tokens,
            )
            return completion.choices[0].message.content
        except Exception as e:
            # Broad catch kept from the original: any failure of the remote
            # call (presumably network/HTTP errors) is treated as retryable.
            if attempt == max_retries:
                print(f"Max retries reached. Last error: {e}")
                raise
            print(f"An error occurred: {e}. Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
    # Not reachable: with max_retries >= 1 the loop either returns or raises.
def next_input(_prompt, _completion_a, _completion_b):
    """Build the next prompt/completion pair to show the annotator.

    Picks one random row of ``dataset``. The row's "chosen" conversation
    minus its final turn becomes the prompt. The final turn's content (the
    dataset's reference answer) is paired with a fresh completion from
    ``get_response``. The two completions come back in random order, so
    position does not reveal which one is the reference.

    The three arguments (presumably the currently displayed prompt and
    completions, supplied by the annotator interface) are ignored.

    Returns:
        ``(messages, completion_a, completion_b)``.
    """
    # Index one random row directly: calling dataset.shuffle() on every
    # request builds a permutation of the whole dataset only to read row 0.
    row = dataset[random.randrange(len(dataset))]
    conversation = row["chosen"]
    messages = conversation[:-1]
    reference = conversation[-1]["content"]
    generated = get_response(messages)
    completion_a, completion_b = random.sample([reference, generated], 2)
    return messages, completion_a, completion_b
if __name__ == "__main__":
    # Chat-generation preference UI: next_input supplies each
    # (prompt, completion_a, completion_b) triple.
    # NOTE(review): interactive=[False, True, True] presumably makes the
    # prompt read-only and both completions editable; confirm against the
    # dataset_viber docs.
    annotator = AnnotatorInterFace.for_chat_generation_preference(
        fn_next_input=next_input,
        interactive=[False, True, True],
        dataset_name="dataset-viber-chat-generation-preference-inference-endpoints-battle",
    )
    annotator.launch()