File size: 4,056 Bytes
7d6d701
caeaee0
7d6d701
99bbf81
a627434
f166d62
8b16668
004cf23
eb978fe
caeaee0
 
7ddcfd9
eb978fe
4ba26b3
eb978fe
d6396cd
 
 
 
7ddcfd9
83e17fa
 
 
 
 
7ddcfd9
08b6d98
8ded9c8
d6396cd
f6fcf7f
917e125
4d74cbf
044c0a3
4d74cbf
044c0a3
4d74cbf
20c453c
7ddcfd9
caeaee0
 
 
 
 
 
 
8b16668
 
 
1e517cc
caeaee0
 
 
 
99bbf81
caeaee0
 
 
 
 
 
 
8b16668
 
 
caeaee0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44a256c
caeaee0
70fec3e
 
 
 
 
caeaee0
7d6d701
 
7ddcfd9
4873e9b
 
 
70fec3e
7b3e7b6
70fec3e
4873e9b
 
4587e33
32147ab
4587e33
32147ab
4587e33
e99e5be
 
7ddcfd9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import logging, os, sys, threading, time

from dotenv import load_dotenv, find_dotenv

from rag_langchain import LangChainRAG
#from rag_llamaindex import LlamaIndexRAG
from trace import trace_wandb

# Serializes requests: invoke() mutates os.environ["OPENAI_API_KEY"], which is
# process-global state, so only one request may run at a time.
lock = threading.Lock()

# Load environment variables (e.g. DESCRIPTION, COMPLETION) from a .env file.
_ = load_dotenv(find_dotenv())

RAG_INGESTION = False # load, split, embed, and store documents

# UI option labels for the "Retrieval-Augmented Generation" radio control.
RAG_OFF        = "Off"
RAG_LANGCHAIN  = "LangChain"
RAG_LLAMAINDEX = "LlamaIndex"

# Shared settings passed to every RAG pipeline call below.
config = {
    "chunk_overlap": 100,       # split documents
    "chunk_size": 2000,         # split documents
    "k": 2,                     # retrieve documents
    "model_name": "gpt-4-0314", # llm
    "temperature": 0            # llm
}

# NOTE(review): basicConfig already attaches a stdout handler; the explicit
# addHandler below adds a second one, so every record is emitted twice.
# Left as-is here — confirm whether the duplication is intentional.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))

def invoke(openai_api_key, prompt, rag_option):
    """Run one completion request, with or without retrieval augmentation.

    Args:
        openai_api_key: user-supplied OpenAI key; exported to the environment
            for the duration of the call and scrubbed afterwards.
        prompt: the user prompt to complete.
        rag_option: one of RAG_OFF / RAG_LANGCHAIN / RAG_LLAMAINDEX selecting
            the pipeline (LlamaIndex branch is currently disabled).

    Returns:
        The completion text.

    Raises:
        gr.Error: on missing inputs or any failure inside the pipeline.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    # The lock serializes requests because the API key lives in process-global
    # os.environ while the pipeline runs.
    with lock:
        os.environ["OPENAI_API_KEY"] = openai_api_key

        # One-time document ingestion (load, split, embed, store); normally off.
        if (RAG_INGESTION):
            if (rag_option == RAG_LANGCHAIN):
                rag = LangChainRAG()
                rag.ingestion(config)
            #elif (rag_option == RAG_LLAMAINDEX):
            #    rag = LlamaIndexRAG()
            #    rag.ingestion(config)

        completion = ""
        result = ""
        callback = ""
        err_msg = ""

        # Assigned BEFORE the try block so the finally clause can never hit a
        # NameError on start_time_ms, even if the first statement raises.
        start_time_ms = round(time.time() * 1000)

        try:
            if (rag_option == RAG_LANGCHAIN):
                rag = LangChainRAG()
                completion, callback = rag.rag_chain(config, prompt)
                result = completion["result"]
            #elif (rag_option == RAG_LLAMAINDEX):
            #    rag = LlamaIndexRAG()
            #    result, callback = rag.retrieval(config, prompt)
            else:
                # RAG off: plain LLM chain without retrieval.
                rag = LangChainRAG()
                completion, callback = rag.llm_chain(config, prompt)
                result = completion.generations[0][0].text
        except Exception as e:
            err_msg = e

            # str(e) gives gr.Error the message text it expects; "from e"
            # preserves the original traceback for debugging.
            raise gr.Error(str(e)) from e
        finally:
            end_time_ms = round(time.time() * 1000)

            # Trace the run (including failures) to Weights & Biases.
            trace_wandb(
                config,
                rag_option,
                prompt, 
                completion, 
                result, 
                callback, 
                err_msg, 
                start_time_ms, 
                end_time_ms
            )

            # Always scrub the key so it cannot leak into later requests.
            del os.environ["OPENAI_API_KEY"]

        print("###")
        print(result)
        print("###")
        
        return result

# Close any Gradio servers left over from a previous run (e.g. notebook reloads).
gr.close_all()

# Wire the invoke() handler to the UI: key + prompt + RAG mode in, completion out.
# NOTE(review): os.environ["COMPLETION"] and os.environ["DESCRIPTION"] raise
# KeyError if the .env file does not define them — presumably a deliberate
# fail-fast on misconfiguration; confirm.
demo = gr.Interface(
    fn = invoke, 
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1), 
              gr.Textbox(label = "Prompt", value = "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", lines = 1),
              gr.Radio([RAG_OFF, RAG_LANGCHAIN, RAG_LLAMAINDEX], label = "Retrieval-Augmented Generation", value = RAG_LANGCHAIN)],
    outputs = [gr.Textbox(label = "Completion", value = os.environ["COMPLETION"])],
    title = "Context-Aware Reasoning Application",
    description = os.environ["DESCRIPTION"],
    # Examples are not cached: each one needs a live, user-supplied API key.
    examples = [["sk-<BringYourOwn>", "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "List GPT-4's exam scores and benchmark results.", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "Compare GPT-4 to GPT-3.5 in markdown table format.", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "Write a Python program that calls the GPT-4 API.", RAG_LANGCHAIN],
                ["sk-<BringYourOwn>", "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", RAG_LANGCHAIN]],
               cache_examples = False
)

demo.launch()