import gradio as gr
import logging, os, sys, threading, time
from dotenv import load_dotenv, find_dotenv
from rag_langchain import LangChainRAG
#from rag_llamaindex import LlamaIndexRAG
from trace import trace_wandb
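
# Serialize requests: the OpenAI key is passed through a process-wide
# environment variable below, so concurrent requests must not interleave.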
lock = threading.Lock()
_ = load_dotenv(find_dotenv())
RAG_INGESTION = False # load, split, embed, and store documents
RAG_OFF = "Off"
RAG_LANGCHAIN = "LangChain"
RAG_LLAMAINDEX = "LlamaIndex"
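
# UI choices; the LlamaIndex path is currently disabled (see the
# commented-out import above and branches below).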
config = {
    "chunk_overlap": 100,       # split documents
    "chunk_size": 2000,         # split documents
    "k": 2,                     # retrieve documents
    "model_name": "gpt-4-0314", # llm
    "temperature": 0            # llm
}
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
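
# Handle one request: validate the inputs, optionally (re-)ingest the
# documents, run the selected chain, and trace the call (timings, completion,
# and any error) via trace_wandb.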
def invoke(openai_api_key, prompt, rag_option):
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")
    with lock:
        os.environ["OPENAI_API_KEY"] = openai_api_key
        if (RAG_INGESTION):
            if (rag_option == RAG_LANGCHAIN):
                rag = LangChainRAG()
                rag.ingestion(config)
            #elif (rag_option == RAG_LLAMAINDEX):
            #    rag = LlamaIndexRAG()
            #    rag.ingestion(config)
        completion = ""
        result = ""
        callback = ""
        err_msg = ""
        try:
            start_time_ms = round(time.time() * 1000)
            if (rag_option == RAG_LANGCHAIN):
                rag = LangChainRAG()
                completion, callback = rag.rag_chain(config, prompt)
                result = completion["result"]
            #elif (rag_option == RAG_LLAMAINDEX):
            #    rag = LlamaIndexRAG()
            #    result, callback = rag.retrieval(config, prompt)
            else:
                rag = LangChainRAG()
                completion, callback = rag.llm_chain(config, prompt)
                result = completion.generations[0][0].text
        except Exception as e:
            err_msg = str(e)
            raise gr.Error(str(e))
        finally:
            end_time_ms = round(time.time() * 1000)
            trace_wandb(
                config,
                rag_option,
                prompt,
                completion,
                result,
                callback,
                err_msg,
                start_time_ms,
                end_time_ms
            )
            del os.environ["OPENAI_API_KEY"]
    print("###")
    print(result)
    print("###")
    return result
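
# Build the Gradio UI: a password field for the key, a prompt text box, and a
# radio button to select the RAG implementation. The description and the
# initial completion text are read from environment variables (loaded from
# .env above).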
gr.close_all()
demo = gr.Interface(
    fn = invoke,
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
              gr.Textbox(label = "Prompt", value = "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", lines = 1),
              gr.Radio([RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX], label = "Retrieval-Augmented Generation", value = RAG_LANGCHAIN)],
    outputs = [gr.Textbox(label = "Completion", value = os.environ["COMPLETION"])],
    title = "Context-Aware Reasoning Application",
    description = os.environ["DESCRIPTION"],
    examples = [["sk-<BringYourOwn>", "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "List GPT-4's exam scores and benchmark results.", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "Compare GPT-4 to GPT-3.5 in markdown table format.", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "Write a Python program that calls the GPT-4 API.", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", RAG_LANGCHAIN]],
    cache_examples = False
)
demo.launch()