# Author: gschoeni
# Commit: 3bf779c (verified) — "Create app.py"
import gradio as gr
import os
import oxen
from oxen import RemoteRepo
from oxen.auth import config_auth
from oxen.user import config_user, current_user
import json
import argparse
# Handle to the remote Oxen repository. It stays None until main() assigns it.
repo = None


def download_prompts(branch):
    """Fetch ``prompts.jsonl`` at revision *branch* from the remote repository.

    Args:
        branch: Revision (branch name) to download the file from.

    Returns:
        A human-readable status string for the UI.
    """
    # TODO: handle the case where prompts.jsonl does not exist on *branch* yet.
    target = "prompts.jsonl"
    repo.download(target, revision=branch)
    status = "Successfully downloaded prompts.jsonl from the main remote repository."
    return status
def _try_acquire_lock(lock_file):
    """Atomically create *lock_file*; return True if created, False if it already exists.

    ``O_CREAT | O_EXCL`` makes "does it exist?" and "create it" one step, so
    two concurrent requests cannot both believe they hold the lock.
    """
    try:
        fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except FileExistsError:
        return False
    with os.fdopen(fd, "w") as f:
        f.write("locked")
    return True


def _read_prompt_lines(file_path):
    """Return the stripped, non-blank lines of *file_path*, or [] if it does not exist."""
    try:
        with open(file_path, encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        return []


def add_prompt(api_key, prompt):
    """Append *prompt* to prompts.jsonl on the caller's branch and commit it.

    Authenticates with *api_key*, then creates or switches to the branch
    ``add-prompt-<user name with spaces replaced by dashes>``. It downloads
    that branch's prompts.jsonl and appends ``{"prompt": ..., "user": ...}`` as
    one JSON line. Exact duplicate lines are dropped (first occurrence kept,
    order preserved) before the file is added and committed.

    A ``LOCK`` file in the working directory serializes callers, because the
    local prompts.jsonl is shared by every request. The lock is released on
    every path out of this function, including early returns and exceptions.
    A lock that someone else holds is never removed.

    Args:
        api_key: Oxen.ai API key entered by the user.
        prompt: The question to add to the dataset.

    Returns:
        A status message to display in the UI.

    Raises:
        Whatever download_prompts / repo.add / repo.commit raise; the lock is
        released before the exception propagates.
    """
    # Validate cheap inputs before taking the lock, so bad input never blocks others.
    if not api_key:
        return "Please enter an API key"
    if not prompt or not prompt.strip():
        return "Please enter a prompt"

    lock_file = "LOCK"
    if not _try_acquire_lock(lock_file):
        return "Someone else is currently adding prompts. Please try again later."
    try:
        # NOTE(review): config_auth appears to set process-wide credentials, which
        # is only safe while the lock serializes requests. Confirm against oxen docs.
        try:
            config_auth(api_key)
        except Exception:
            return "Invalid API Key"
        user = current_user()

        # Replace spaces in the user's name with dashes to form the branch name.
        username = user.name.replace(" ", "-")
        branch_name = f"add-prompt-{username}"
        # Create the branch if it doesn't exist; switch to it if it does.
        try:
            repo.create_checkout_branch(branch_name)
        except Exception:
            return "You do not have permission to create branches. Please contact @greg.schoeninger in the Oxen.ai discord for access. https://discord.com/invite/s3tBEn7Ptg"

        # Start from the latest prompts on this branch.
        download_prompts(branch_name)
        file_path = "prompts.jsonl"
        lines = _read_prompt_lines(file_path)
        lines.append(json.dumps({"prompt": prompt, "user": user.name}))
        # dict.fromkeys drops exact duplicates while keeping file order stable.
        # A set would reshuffle the whole file on every commit.
        unique_lines = dict.fromkeys(lines)
        with open(file_path, "w", encoding="utf-8") as writer:
            writer.writelines(f"{line}\n" for line in unique_lines)

        repo.add(file_path)
        repo.commit(f"Adding '{prompt}' to dataset")
        return "Prompt added to prompts.jsonl"
    finally:
        os.remove(lock_file)
def is_configured():
    """Return True if Oxen credentials are available through either supported setup.

    print_error_not_configured() offers two alternatives, and either is
    accepted here:

    * all of OXEN_USER_NAME, OXEN_USER_EMAIL and OXEN_AUTH_TOKEN set in the
      environment, or
    * ``oxen.is_configured()`` reporting a configured client (presumably the
      ``oxen config`` CLI setup; confirm against oxen docs).
    """
    required_env = ("OXEN_USER_NAME", "OXEN_USER_EMAIL", "OXEN_AUTH_TOKEN")
    # Check the cheap environment lookup first; only consult oxen if it fails.
    if all(name in os.environ for name in required_env):
        return True
    return oxen.is_configured()
def print_error_not_configured():
    """Print instructions to stdout explaining how to supply Oxen credentials."""
    instructions = (
        "Please configure your Oxen credentials:\n"
        "\n"
        "\toxen config --auth API_KEY\n"
        "\toxen config --name YOUR_NAME --email YOUR_EMAIL\n"
        "\n"
        "or set the OXEN_AUTH_TOKEN, OXEN_USER_NAME and OXEN_USER_EMAIL environment variables."
    )
    print(instructions)
def main():
    """Connect to the Oxen dataset repository and serve the "save prompt" Gradio UI.

    Assigns the module-level ``repo`` used by the prompt handlers, builds the
    Blocks layout, launches it, and closes it once ``launch()`` returns. If
    the remote repository does not exist, prints a message and returns
    without starting the UI.
    """
    global repo
    repo_name = "datasets/UnanswerableQuestions"
    repo = RemoteRepo(repo_name)
    if not repo.exists():
        print(f"Repository {repo_name} does not exist")
        # Without the repository every save would fail, so don't serve the UI.
        return
    print(f"Saving prompts to {repo_name}")

    with gr.Blocks() as demo:
        gr.Markdown("# Unanswerable Questions")
        gr.Markdown("This is a dataset of unanswerable questions to test whether an LLM \"knows when it does not know\" and minimize hallucinations. Click Save to Oxen to save the prompts to Oxen.ai")
        gr.Markdown("To view the dataset, visit https://oxen.ai/datasets/UnanswerableQuestions")
        gr.Markdown("## Save prompt to Oxen")
        success_message = gr.Markdown()
        api_key = gr.Textbox(label="Enter Oxen.ai API Key here", lines=1, type="password")
        prompt = gr.Textbox(label="Enter prompt here", lines=2)
        save_button = gr.Button(value="Save Prompt")
        # add_prompt's returned status string is rendered in the Markdown area.
        save_button.click(add_prompt, inputs=[api_key, prompt], outputs=success_message)

    demo.launch()
    demo.close()


if __name__ == "__main__":
    main()