File size: 3,544 Bytes
7d6d701 d6396cd 7d6d701 99bbf81 a627434 f166d62 8730199 004cf23 eb978fe 7ddcfd9 eb978fe 4ba26b3 eb978fe d6396cd 7ddcfd9 83e17fa 7ddcfd9 08b6d98 8ded9c8 d6396cd f6fcf7f 917e125 ebcdcac 044c0a3 ebcdcac 044c0a3 917e125 7ddcfd9 7c151aa 44a256c ce136c7 f166d62 ce136c7 8730199 ee1b26f 26b6a5b ddfaa69 ce136c7 12d440a 1e517cc 1283168 9102fcd 99bbf81 e0ddc02 f166d62 76f77bf e0ddc02 5e724ee 76f77bf e0ddc02 f166d62 76f77bf e497043 1283168 12d440a 7ddcfd9 c2e6078 37ab520 043b829 99bbf81 4873e9b ee1b26f 4873e9b e99e5be 44a256c 8d60a3f 7d6d701 7ddcfd9 4873e9b 98d47e1 e99e5be 7ddcfd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import gradio as gr
import logging, os, sys, time
from dotenv import load_dotenv, find_dotenv
from rag_langchain import LangChainRAG
from rag_llamaindex import LlamaIndexRAG
from trace import trace_wandb
# Load environment variables (OPENAI_API_KEY fallback, DESCRIPTION, W&B keys)
# from the nearest .env file, if any.
_ = load_dotenv(find_dotenv())

# When True, run the one-time ingestion pipeline (load, split, embed, and
# store documents) before answering; normally left False after first run.
RAG_INGESTION = False

# Labels for the RAG-mode radio selector; also used as dispatch keys in invoke().
RAG_OFF = "Off"
RAG_LANGCHAIN = "LangChain"
RAG_LLAMAINDEX = "LlamaIndex"

# Shared settings handed to both RAG implementations and the LLM.
config = {
    "chunk_overlap": 100,        # split documents
    "chunk_size": 2000,          # split documents
    "k": 2,                      # retrieve documents
    "model_name": "gpt-4-0314",  # llm
    "temperature": 0             # llm
}

# Route all log output to stdout. basicConfig() already installs one stdout
# StreamHandler on the root logger; the previous extra addHandler(...) call
# duplicated every log record, so it has been removed.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
def invoke(openai_api_key, prompt, rag_option):
    """Answer *prompt* with the selected RAG strategy and trace the run to W&B.

    Args:
        openai_api_key: OpenAI API key; exported to the environment for the
            downstream LangChain / LlamaIndex clients.
        prompt: The user question to complete.
        rag_option: One of RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX.

    Returns:
        The completion text shown in the UI.

    Raises:
        gr.Error: On missing inputs or any failure inside the chain.
    """
    # Guard clauses: fail fast with a user-visible Gradio error.
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if rag_option is None:
        raise gr.Error("Retrieval Augmented Generation is required.")
    # The LangChain/LlamaIndex clients read the key from the environment.
    os.environ["OPENAI_API_KEY"] = openai_api_key
    # Optional one-time ingestion (load, split, embed, store documents).
    if RAG_INGESTION:
        if rag_option == RAG_LANGCHAIN:
            LangChainRAG().ingestion(config)
        elif rag_option == RAG_LLAMAINDEX:
            LlamaIndexRAG().ingestion(config)
    completion = ""
    result = ""
    callback = ""
    err_msg = ""
    # Taken before try so the finally block can always compute a duration,
    # even if the very first statement inside try were to raise.
    start_time_ms = round(time.time() * 1000)
    try:
        if rag_option == RAG_LANGCHAIN:
            completion, callback = LangChainRAG().rag_chain(config, prompt)
            result = completion["result"]
        elif rag_option == RAG_LLAMAINDEX:
            result, callback = LlamaIndexRAG().retrieval(config, prompt)
        else:
            completion, callback = LangChainRAG().llm_chain(config, prompt)
            result = completion.generations[0][0].text
    except Exception as e:
        # Record the message (not the exception object) so the trace payload
        # is serializable, and surface the error to the UI with its cause.
        err_msg = str(e)
        raise gr.Error(str(e)) from e
    finally:
        # Trace every invocation — successful or not.
        end_time_ms = round(time.time() * 1000)
        trace_wandb(
            config,
            rag_option,
            prompt,
            completion,
            result,
            callback,
            err_msg,
            start_time_ms,
            end_time_ms
        )
    return result
# Close any Gradio servers left over from a previous run, then build the UI.
gr.close_all()

demo = gr.Interface(
    fn = invoke,
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
              gr.Textbox(label = "Prompt", value = "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines = 1),
              gr.Radio([RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX], label = "Retrieval-Augmented Generation", value = RAG_LANGCHAIN)],
    outputs = [gr.Textbox(label = "Completion", lines = 1)],
    title = "Context-Aware Reasoning Application",
    # NOTE(review): raises KeyError if DESCRIPTION is missing from the
    # environment / .env file — intentional fail-fast at startup.
    description = os.environ["DESCRIPTION"],
    # First example column (the API key) is deliberately blank; users supply
    # their own key in the password box.
    examples = [["", "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", RAG_LLAMAINDEX],
                ["", "List GPT-4's exam scores and benchmark results.", RAG_LANGCHAIN],
                ["", "Compare GPT-4 to GPT-3.5 in markdown table format.", RAG_LLAMAINDEX],
                ["", "Write a Python program that calls the GPT-4 API.", RAG_LANGCHAIN],
                ["", "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", RAG_LLAMAINDEX]],
    cache_examples = False
)
# Removed the stray "|" extraction residue that followed this call — it made
# the line a syntax error.
demo.launch()