import gradio as gr
import logging, os, sys, threading, time
from dotenv import load_dotenv, find_dotenv
from rag_langchain import LangChainRAG
from rag_llamaindex import LlamaIndexRAG
from trace import trace_wandb
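# trace_wandb comes from the app's local trace.py (which shadows the stdlib trace
# module) and logs each request to Weights & Biases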
lock = threading.Lock()
_ = load_dotenv(find_dotenv())
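# load_dotenv reads a local .env file, if present, into the environment
# (e.g. the DESCRIPTION variable used in gr.Interface below)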
RAG_INGESTION = False # load, split, embed, and store documents
RAG_OFF = "Off"
RAG_LANGCHAIN = "LangChain"
RAG_LLAMAINDEX = "LlamaIndex"
config = {
    "chunk_overlap": 100,        # split documents
    "chunk_size": 2000,          # split documents
    "k": 2,                      # retrieve documents
    "model_name": "gpt-4-0314",  # llm
    "temperature": 0             # llm
}
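# one config dict drives both pipelines: chunk_size/chunk_overlap shape ingestion-time
# splitting, k caps the number of retrieved chunks, model_name/temperature set the LLM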
# log to stdout; basicConfig attaches the stream handler itself, so no extra
# StreamHandler is needed (a second one would print every record twice)
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
def invoke(openai_api_key, prompt, rag_option):
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")
    with lock:
        os.environ["OPENAI_API_KEY"] = openai_api_key
        if RAG_INGESTION:
            # optional one-off ingestion pass (see RAG_INGESTION above)
            if rag_option == RAG_LANGCHAIN:
                rag = LangChainRAG()
                rag.ingestion(config)
            elif rag_option == RAG_LLAMAINDEX:
                rag = LlamaIndexRAG()
                rag.ingestion(config)
        completion = ""
        result = ""
        callback = ""
        err_msg = ""
        try:
            start_time_ms = round(time.time() * 1000)
            if rag_option == RAG_LANGCHAIN:
                rag = LangChainRAG()
                completion, callback = rag.rag_chain(config, prompt)
                result = completion["result"]
            elif rag_option == RAG_LLAMAINDEX:
                rag = LlamaIndexRAG()
                result, callback = rag.retrieval(config, prompt)
            else:
                # RAG off: plain LLM chain, no retrieval
                rag = LangChainRAG()
                completion, callback = rag.llm_chain(config, prompt)
                result = completion.generations[0][0].text
        except Exception as e:
            err_msg = str(e)
            raise gr.Error(str(e))
        finally:
            end_time_ms = round(time.time() * 1000)
            # trace the request to Weights & Biases, on success and on error alike
            trace_wandb(
                config,
                rag_option,
                prompt,
                completion,
                result,
                callback,
                err_msg,
                start_time_ms,
                end_time_ms
            )
            del os.environ["OPENAI_API_KEY"]
    return result
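# A minimal sketch of exercising invoke() outside the UI, e.g. from a smoke test.
# MY_OPENAI_API_KEY is a hypothetical env var you would set yourself; illustrative only:
#
#   if os.environ.get("MY_OPENAI_API_KEY"):
#       print(invoke(os.environ["MY_OPENAI_API_KEY"],
#                    "List GPT-4's exam scores and benchmark results.",
#                    RAG_LANGCHAIN))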
gr.close_all()
demo = gr.Interface(
    fn = invoke,
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
              gr.Textbox(label = "Prompt", value = "List GPT-4's exam scores and benchmark results.", lines = 1),
              gr.Radio([RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX], label = "Retrieval-Augmented Generation", value = RAG_LANGCHAIN)],
    outputs = [gr.Textbox(label = "Completion")],
    title = "Context-Aware Reasoning Application",
    description = os.environ["DESCRIPTION"], # must be set in the environment (or .env) before launch
    examples = [["sk-<BringYourOwn>", "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", RAG_LLAMAINDEX],
                ["sk-<BringYourOwn>", "List GPT-4's exam scores and benchmark results.", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "Compare GPT-4 to GPT-3.5 in markdown table format.", RAG_LLAMAINDEX],
                ["sk-<BringYourOwn>", "Write a Python program that calls the GPT-4 API.", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", RAG_LLAMAINDEX]],
    cache_examples = False
)
demo.launch()