File size: 3,881 Bytes
7d6d701
caeaee0
7d6d701
99bbf81
a627434
f166d62
8730199
004cf23
eb978fe
caeaee0
 
7ddcfd9
eb978fe
4ba26b3
eb978fe
d6396cd
 
 
 
7ddcfd9
83e17fa
 
 
 
 
7ddcfd9
08b6d98
8ded9c8
d6396cd
f6fcf7f
917e125
4d74cbf
044c0a3
4d74cbf
044c0a3
4d74cbf
20c453c
7ddcfd9
caeaee0
 
 
 
 
 
 
 
 
 
1e517cc
caeaee0
 
 
 
99bbf81
caeaee0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44a256c
caeaee0
 
 
7d6d701
 
7ddcfd9
4873e9b
 
 
b8c92ce
7b3e7b6
de6b55b
4873e9b
 
b8c92ce
 
 
 
 
e99e5be
 
7ddcfd9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import logging, os, sys, threading, time

from dotenv import load_dotenv, find_dotenv

from rag_langchain import LangChainRAG
from rag_llamaindex import LlamaIndexRAG
from trace import trace_wandb

# Serializes calls to invoke(): the handler mutates the process-wide
# environment (OPENAI_API_KEY), so concurrent Gradio requests must not overlap.
lock = threading.Lock()

# Load variables from a .env file into os.environ (e.g. DESCRIPTION below;
# presumably also W&B settings used by trace_wandb — verify against trace.py).
_ = load_dotenv(find_dotenv())

RAG_INGESTION = False # load, split, embed, and store documents

# Radio-button labels for the RAG mode; invoke() dispatches on these values.
RAG_OFF        = "Off"
RAG_LANGCHAIN  = "LangChain"
RAG_LLAMAINDEX = "LlamaIndex"

# Shared settings passed to both RAG backends and the plain LLM chain.
config = {
    "chunk_overlap": 100,       # split documents
    "chunk_size": 2000,         # split documents
    "k": 2,                     # retrieve documents
    "model_name": "gpt-4-0314", # llm
    "temperature": 0            # llm
}

# basicConfig(stream=sys.stdout) already attaches a StreamHandler on stdout
# to the root logger; adding a second StreamHandler(sys.stdout) — as the
# previous version did — made every log record appear twice.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)

def invoke(openai_api_key, prompt, rag_option):
    """Gradio handler: answer *prompt* using the selected RAG pipeline.

    Args:
        openai_api_key: user-supplied OpenAI key; exported to the process
            environment for the duration of the call, removed afterwards.
        prompt: the user's question.
        rag_option: one of RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX.

    Returns:
        The completion text produced by the chosen chain.

    Raises:
        gr.Error: if any input is missing, or if the underlying chain fails.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    # os.environ is process-global, so requests must not interleave.
    with lock:
        os.environ["OPENAI_API_KEY"] = openai_api_key

        if (RAG_INGESTION):
            # One-off document ingestion (load, split, embed, store);
            # disabled by default via the RAG_INGESTION flag.
            if (rag_option == RAG_LANGCHAIN):
                rag = LangChainRAG()
                rag.ingestion(config)
            elif (rag_option == RAG_LLAMAINDEX):
                rag = LlamaIndexRAG()
                rag.ingestion(config)

        completion = ""
        result = ""
        callback = ""
        err_msg = ""

        try:
            start_time_ms = round(time.time() * 1000)

            if (rag_option == RAG_LANGCHAIN):
                rag = LangChainRAG()
                completion, callback = rag.rag_chain(config, prompt)
                result = completion["result"]
            elif (rag_option == RAG_LLAMAINDEX):
                rag = LlamaIndexRAG()
                result, callback = rag.retrieval(config, prompt)
            else:
                # RAG off: plain LLM chain without retrieval.
                rag = LangChainRAG()
                completion, callback = rag.llm_chain(config, prompt)
                result = completion.generations[0][0].text
        except Exception as e:
            # Was `err_msg = e` / `raise gr.Error(e)`: store and raise a
            # *string* so tracing serializes cleanly and the UI shows a
            # proper message; `from e` preserves the original traceback.
            err_msg = str(e)

            raise gr.Error(str(e)) from e
        finally:
            end_time_ms = round(time.time() * 1000)

            # Always trace the request (success or failure) to W&B.
            trace_wandb(
                config,
                rag_option,
                prompt, 
                completion, 
                result, 
                callback, 
                err_msg, 
                start_time_ms, 
                end_time_ms
            )

            # Never leave the user's key in the process environment.
            del os.environ["OPENAI_API_KEY"]

        return result

# Close any Gradio servers left over from a previous run in this process.
gr.close_all()

# UI: key + prompt + RAG-mode radio in, completion text out.
demo = gr.Interface(
    fn = invoke, 
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1), 
              gr.Textbox(label = "Prompt", value = "List GPT-4's exam scores and benchmark results.", lines = 1),
              gr.Radio([RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX], label = "Retrieval-Augmented Generation", value = RAG_LANGCHAIN)],
    outputs = [gr.Textbox(label = "Completion")],
    title = "Context-Aware Reasoning Application",
    # DESCRIPTION must be present in the environment (loaded from .env above);
    # a missing key raises KeyError at startup.
    description = os.environ["DESCRIPTION"],
    examples = [["sk-<BringYourOwn>", "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", RAG_LLAMAINDEX],
                ["sk-<BringYourOwn>", "List GPT-4's exam scores and benchmark results.", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "Compare GPT-4 to GPT-3.5 in markdown table format.", RAG_LLAMAINDEX],
                ["sk-<BringYourOwn>", "Write a Python program that calls the GPT-4 API.", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", RAG_LLAMAINDEX]],
               # examples call the real API with a placeholder key; never cache.
               cache_examples = False
)

demo.launch()