Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
import oxen | |
from oxen import RemoteRepo | |
from oxen.auth import config_auth | |
from oxen.user import config_user, current_user | |
import json | |
import argparse | |
# Handle to the Oxen RemoteRepo; stays None until main() assigns it.
repo = None
def download_prompts(branch):
    """Download ``prompts.jsonl`` from the remote repo at revision *branch*.

    Uses the module-level ``repo`` assigned in main(). Where the file lands
    locally is decided by ``repo.download`` (presumably the working
    directory, since add_prompt() opens "prompts.jsonl" afterwards -- TODO
    confirm). Any error raised by the download propagates to the caller.

    Returns a fixed success message string.
    """
    # TODO: handle the case where prompts.jsonl does not exist on the branch.
    remote_path = "prompts.jsonl"
    repo.download(remote_path, revision=branch)
    status = "Successfully downloaded prompts.jsonl from the main remote repository."
    return status
# Advisory lock file (relative to the working directory) that serializes
# prompt submissions so only one runs at a time.
_LOCK_FILE = "LOCK"
# Local copy of the dataset file that download_prompts() fetches and that is
# committed back to the repo.
_PROMPTS_FILE = "prompts.jsonl"


def _try_acquire_lock(lock_file):
    """Atomically create *lock_file*; return True on success, False if it exists."""
    try:
        # O_CREAT | O_EXCL performs "does it exist?" and "create it" as one
        # atomic step, so two concurrent submissions cannot both succeed.
        fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except FileExistsError:
        return False
    with os.fdopen(fd, "w") as f:
        f.write("locked")
    return True


def _release_lock(lock_file):
    """Remove *lock_file*, tolerating it having already been removed."""
    try:
        os.remove(lock_file)
    except FileNotFoundError:
        pass


def _merge_prompt_into_file(file_path, prompt, user_name):
    """Add a {"prompt": ..., "user": ...} JSON line to *file_path*, de-duplicated.

    Existing non-blank lines keep their original order and the new record is
    appended at the end unless an identical line is already present.
    Raises FileNotFoundError if *file_path* does not exist.
    """
    with open(file_path, encoding="utf-8") as reader:
        stripped = (line.strip() for line in reader)
        # A dict de-duplicates while preserving insertion order, so the file
        # is not reshuffled on every commit the way a set would reshuffle it.
        records = dict.fromkeys(s for s in stripped if s)
    records[json.dumps({"prompt": prompt, "user": user_name})] = None
    with open(file_path, "w", encoding="utf-8") as writer:
        writer.writelines(f"{record}\n" for record in records)


def add_prompt(api_key, prompt):
    """Save *prompt* to prompts.jsonl and commit it on the user's own branch.

    Steps:
    1. Validate the inputs.
    2. Take the lock.
    3. Authenticate with *api_key* and work out the branch name
       ``add-prompt-<user name with spaces replaced by dashes>``.
    4. Create or check out that branch.
    5. Download the current prompts.jsonl, merge the new prompt in, then
       add and commit the file.

    Args:
        api_key: Oxen.ai API key entered by the user.
        prompt: Prompt text to add to the dataset.

    Returns:
        A status message for display in the UI.

    The lock is always released once acquired, including on early returns
    and on exceptions, which propagate to the caller.
    """
    if not api_key:
        return "Please enter an API key"
    if not (prompt and prompt.strip()):
        return "Please enter a prompt"
    if not _try_acquire_lock(_LOCK_FILE):
        return "Someone else is currently adding prompts. Please try again later."

    try:
        # NOTE: config_auth/current_user act on process-wide credentials;
        # the lock above is what keeps concurrent users from interleaving.
        try:
            config_auth(api_key)
        except Exception:
            return "Invalid API Key"
        user = current_user()
        branch_name = f"add-prompt-{user.name.replace(' ', '-')}"

        # Intended to create the branch if missing and switch to it otherwise;
        # any failure is reported as a permissions problem.
        try:
            repo.create_checkout_branch(branch_name)
        except Exception:
            return "You do not have permission to create branches. Please contact @greg.schoeninger in the Oxen.ai discord for access. https://discord.com/invite/s3tBEn7Ptg"

        download_prompts(branch_name)
        _merge_prompt_into_file(_PROMPTS_FILE, prompt, user.name)
        repo.add(_PROMPTS_FILE)
        repo.commit(f"Adding '{prompt}' to dataset")
        return "Prompt added to prompts.jsonl"
    finally:
        _release_lock(_LOCK_FILE)
def is_configured():
    """Report whether Oxen credentials are fully configured.

    If ``oxen.is_configured()`` returns a falsy value, that value is returned
    as-is. Otherwise the result is True only when OXEN_USER_NAME,
    OXEN_USER_EMAIL and OXEN_AUTH_TOKEN are all present in ``os.environ``.
    """
    oxen_ready = oxen.is_configured()
    if not oxen_ready:
        return oxen_ready
    required_vars = ("OXEN_USER_NAME", "OXEN_USER_EMAIL", "OXEN_AUTH_TOKEN")
    return all(var in os.environ for var in required_vars)
def print_error_not_configured():
    """Print instructions for configuring Oxen credentials to stdout."""
    message = (
        "Please configure your Oxen credentials:\n"
        "\n"
        "\toxen config --auth API_KEY\n"
        "\toxen config --name YOUR_NAME --email YOUR_EMAIL\n"
        "\n"
        "or set the OXEN_AUTH_TOKEN, OXEN_USER_NAME and OXEN_USER_EMAIL environment variables."
    )
    print(message)
def main():
    """Connect to the Oxen dataset repo and serve the prompt-collection UI.

    Assigns the module-level ``repo`` used by download_prompts() and
    add_prompt(). It then builds a Gradio Blocks app with an API-key box, a
    prompt box and a "Save Prompt" button wired to add_prompt(), and launches
    it. If the repository does not exist, it prints a message and returns
    without launching.
    """
    global repo
    repo_name = "datasets/UnanswerableQuestions"
    repo = RemoteRepo(repo_name)
    if not repo.exists():
        # Without a target repository every submission would fail, so do not
        # start the UI at all.
        print(f"Repository {repo_name} does not exist")
        return
    print(f"Saving prompts to {repo_name}")

    with gr.Blocks() as demo:
        gr.Markdown("# Unanswerable Questions")
        gr.Markdown("This is a dataset of unanswerable questions to test whether an LLM \"knows when it does not know\" and minimize hallucinations. Click Save to Oxen to save the prompts to Oxen.ai")
        gr.Markdown("To view the dataset, visit https://oxen.ai/datasets/UnanswerableQuestions")
        gr.Markdown("## Save prompt to Oxen")
        success_message = gr.Markdown()
        api_key = gr.Textbox(label="Enter Oxen.ai API Key here", lines=1, type="password")
        prompt = gr.Textbox(label="Enter prompt here", lines=2)
        save_button = gr.Button(value="Save Prompt")
        save_button.click(add_prompt, inputs=[api_key, prompt], outputs=success_message)

    # launch() normally blocks while the server runs; make sure the app is
    # closed even if it is interrupted or raises.
    try:
        demo.launch()
    finally:
        demo.close()


if __name__ == "__main__":
    main()