File size: 3,272 Bytes
7d6d701 7ddcfd9 7d6d701 99bbf81 a627434 44a256c 004cf23 eb978fe 7ddcfd9 eb978fe 44a256c eb978fe 7ddcfd9 3ddc880 51605e1 5b9fc25 7ddcfd9 08b6d98 f6fcf7f eb1e69b f6fcf7f 917e125 ebcdcac 044c0a3 ebcdcac 044c0a3 917e125 7ddcfd9 7c151aa 44a256c 1e517cc acf522c 26b6a5b ddfaa69 c6bc22c 12d440a 1e517cc 1283168 9102fcd 99bbf81 7ddcfd9 cd4364b 99bbf81 ddfaa69 7ddcfd9 cd4364b 09018fc 7ddcfd9 1283168 12d440a 7ddcfd9 c2e6078 37ab520 043b829 99bbf81 004cf23 7ddcfd9 9294813 7ddcfd9 44a256c 8d60a3f 7d6d701 7ddcfd9 1f3b512 14e92f6 99bbf81 14e92f6 b7d5b27 c42c62b 4c13268 621d4dd 4c13268 7ddcfd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import gradio as gr
import os, time
from dotenv import load_dotenv, find_dotenv
from rag import llm_chain, rag_chain, rag_ingestion
from trace import trace_wandb
# Load environment variables (OPENAI_API_KEY fallback, DESCRIPTION, W&B keys, ...)
# from a .env file if one exists; find_dotenv() searches upward from the CWD.
_ = load_dotenv(find_dotenv())
# One-time switch: when True, invoke() runs the full ingestion pipeline
# (load, split, embed, and store documents) before answering.
RAG_INGESTION = False # load, split, embed, and store documents
# Shared settings passed to llm_chain / rag_chain / rag_ingestion.
config = {
"k": 3, # number of documents the retriever returns
"model_name": "gpt-4-0314", # OpenAI chat model identifier
"temperature": 0 # deterministic sampling for reproducible answers
}
# UI choices for the retrieval backend; the string doubles as the radio label.
RAG_OFF = "Off"
RAG_MONGODB = "MongoDB" # serverless vector store
RAG_CHROMA = "Chroma" # requires persistent storage (small is $0.01/hour)
def invoke(openai_api_key, prompt, rag_option):
    """Run the prompt through the LLM (optionally RAG-augmented) and trace the call.

    Validates the three UI inputs, optionally runs document ingestion,
    executes either the plain LLM chain or the RAG chain, and always logs
    the outcome to Weights & Biases — even on failure.

    Args:
        openai_api_key: User-supplied OpenAI API key (required, non-empty).
        prompt: The user prompt (required, non-empty).
        rag_option: Retrieval backend — RAG_OFF, RAG_MONGODB, or RAG_CHROMA.

    Returns:
        The completion text, or "" if the model produced no generations.

    Raises:
        gr.Error: On missing inputs or any failure during chain execution.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if rag_option is None:
        raise gr.Error("Retrieval Augmented Generation is required.")
    os.environ["OPENAI_API_KEY"] = openai_api_key
    if RAG_INGESTION:
        rag_ingestion(config)
    chain = None
    completion = ""
    result = ""
    cb = ""
    err_msg = ""
    # Taken before the try block so it is always bound when finally runs.
    start_time_ms = round(time.time() * 1000)
    try:
        if rag_option == RAG_OFF:
            completion, chain, cb = llm_chain(config, prompt)
            # Truthiness guard instead of `!= None` on list elements: an
            # empty generations list would otherwise raise IndexError.
            if completion.generations and completion.generations[0]:
                result = completion.generations[0][0].text
        else:
            completion, chain, cb = rag_chain(config, rag_option, prompt)
            result = completion["result"]
    except Exception as e:
        # Trace a readable string (not the exception object) and chain the
        # original exception into the UI-facing error.
        err_msg = str(e)
        raise gr.Error(str(e)) from e
    finally:
        end_time_ms = round(time.time() * 1000)
        trace_wandb(config,
                    rag_option == RAG_OFF,
                    prompt,
                    completion,
                    result,
                    chain,
                    cb,
                    err_msg,
                    start_time_ms,
                    end_time_ms)
    return result
# Close any Gradio servers left over from a previous run, then build and
# launch the UI. Examples pre-fill the prompt but leave the API key blank.
gr.close_all()
demo = gr.Interface(fn = invoke,
                    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
                              gr.Textbox(label = "Prompt", value = "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines = 1),
                              gr.Radio([RAG_OFF, RAG_MONGODB], label = "Retrieval-Augmented Generation", value = RAG_MONGODB)],
                    outputs = [gr.Textbox(label = "Completion", lines = 1)],
                    title = "Context-Aware Reasoning Application",
                    # .get with a default: a missing DESCRIPTION env var should not
                    # crash the app at startup with a KeyError.
                    description = os.environ.get("DESCRIPTION", ""),
                    examples = [["", "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", RAG_MONGODB],
                                ["", "List GPT-4's exam scores and benchmark results.", RAG_MONGODB],
                                ["", "Compare GPT-4 to GPT-3.5 in markdown table format.", RAG_MONGODB],
                                ["", "Write a Python program that calls the GPT-4 API.", RAG_MONGODB],
                                ["", "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", RAG_MONGODB]],
                    cache_examples = False)
demo.launch()