File size: 3,272 Bytes
7d6d701
7ddcfd9
7d6d701
99bbf81
a627434
44a256c
004cf23
eb978fe
7ddcfd9
eb978fe
44a256c
eb978fe
7ddcfd9
3ddc880
51605e1
5b9fc25
7ddcfd9
08b6d98
f6fcf7f
eb1e69b
 
f6fcf7f
917e125
ebcdcac
044c0a3
ebcdcac
044c0a3
917e125
 
7ddcfd9
7c151aa
 
44a256c
 
1e517cc
acf522c
26b6a5b
ddfaa69
c6bc22c
12d440a
1e517cc
1283168
9102fcd
99bbf81
7ddcfd9
cd4364b
99bbf81
ddfaa69
 
7ddcfd9
cd4364b
09018fc
7ddcfd9
1283168
12d440a
7ddcfd9
c2e6078
37ab520
043b829
99bbf81
004cf23
7ddcfd9
 
 
 
 
9294813
7ddcfd9
 
 
44a256c
8d60a3f
7d6d701
 
7ddcfd9
1f3b512
14e92f6
99bbf81
14e92f6
b7d5b27
c42c62b
4c13268
621d4dd
 
 
 
 
4c13268
7ddcfd9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import os, time

from dotenv import load_dotenv, find_dotenv

from rag import llm_chain, rag_chain, rag_ingestion
from trace import trace_wandb

# Populate os.environ from the nearest .env file found by find_dotenv(), if any.
_ = load_dotenv(find_dotenv())

# When True, every invoke() call first runs rag_ingestion(config)
# (load, split, embed, and store documents); False skips that step.
RAG_INGESTION = False

# Settings handed to the retrieval/LLM helpers on every request.
config = dict(
    k=3,                      # number of documents to retrieve
    model_name="gpt-4-0314",  # LLM model name
    temperature=0,            # LLM sampling temperature
)

# Retrieval-augmented generation back-end options; RAG_OFF disables retrieval.
RAG_OFF = "Off"
RAG_MONGODB = "MongoDB"  # serverless
RAG_CHROMA = "Chroma"    # requires persistent storage (small is $0.01/hour)

def invoke(openai_api_key, prompt, rag_option):
    """Answer a prompt with the LLM, optionally using retrieval-augmented generation.

    Args:
        openai_api_key: OpenAI API key entered by the user; exported to the
            ``OPENAI_API_KEY`` environment variable before the chains run.
        prompt: The user prompt.
        rag_option: ``RAG_OFF`` for a plain LLM call; any other value is passed
            to ``rag_chain`` as the vector-store choice (e.g. ``RAG_MONGODB``).

    Returns:
        The completion text. For the plain LLM path this is an empty string
        when the model returned no generation.

    Raises:
        gr.Error: If a required input is missing, or if the chain call fails.
            Every chain call, successful or failed, is reported to
            ``trace_wandb`` with its start/end timestamps.
    """
    # `not x` rejects both "" and None.
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if rag_option is None:
        raise gr.Error("Retrieval Augmented Generation is required.")

    # NOTE(review): this sets process-wide state; if Gradio serves requests
    # concurrently, two users' keys could race here — confirm queue settings.
    os.environ["OPENAI_API_KEY"] = openai_api_key

    if RAG_INGESTION:
        rag_ingestion(config)

    # Defaults so trace_wandb always receives every argument, even on failure.
    chain = None
    completion = ""
    result = ""
    cb = ""
    err_msg = ""

    start_time_ms = round(time.time() * 1000)

    try:
        if rag_option == RAG_OFF:
            completion, chain, cb = llm_chain(config, prompt)

            # Guard against a None or empty generations list instead of
            # letting an IndexError surface as an opaque UI error.
            generations = completion.generations
            if generations and generations[0] and generations[0][0] is not None:
                result = generations[0][0].text
        else:
            completion, chain, cb = rag_chain(config, rag_option, prompt)

            result = completion["result"]
    except Exception as e:
        err_msg = e

        # gr.Error takes a message string; keep the original exception as
        # the cause so the server-side traceback stays intact.
        raise gr.Error(str(e)) from e
    finally:
        end_time_ms = round(time.time() * 1000)

        trace_wandb(config,
                    rag_option == RAG_OFF,
                    prompt,
                    completion,
                    result,
                    chain,
                    cb,
                    err_msg,
                    start_time_ms,
                    end_time_ms)

    return result

# Close any Gradio apps still running in this process before building a new one.
gr.close_all()

# Single-page UI: API key, prompt, and RAG choice in; completion text out.
demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password", lines=1),
        gr.Textbox(label="Prompt", value="What are GPT-4's media capabilities in 5 emojis and 1 sentence?", lines=1),
        gr.Radio([RAG_OFF, RAG_MONGODB], label="Retrieval-Augmented Generation", value=RAG_MONGODB),
    ],
    outputs=[gr.Textbox(label="Completion", lines=1)],
    title="Context-Aware Reasoning Application",
    # Use an empty description rather than crashing at import time with a
    # KeyError when the DESCRIPTION environment variable is not set.
    description=os.environ.get("DESCRIPTION", ""),
    # Each example row is [api_key, prompt, rag_option]; the API key is left
    # empty. NOTE(review): selecting an example may clear a key the user
    # already typed — confirm in the UI.
    examples=[
        ["", "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", RAG_MONGODB],
        ["", "List GPT-4's exam scores and benchmark results.", RAG_MONGODB],
        ["", "Compare GPT-4 to GPT-3.5 in markdown table format.", RAG_MONGODB],
        ["", "Write a Python program that calls the GPT-4 API.", RAG_MONGODB],
        ["", "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format.", RAG_MONGODB],
    ],
    cache_examples=False,
)

demo.launch()