brianjking commited on
Commit
f9fdaeb
·
1 Parent(s): 57025a9

Create multiple.py

Browse files
Files changed (1) hide show
  1. multiple.py +83 -0
multiple.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import tempfile
4
+ from llama_index import (
5
+ ServiceContext,
6
+ SimpleDirectoryReader,
7
+ VectorStoreIndex,
8
+ )
9
+ from llama_index.llms import OpenAI
10
+ import openai
11
+
12
+ st.title("Grounded Generation")
13
+
14
+ uploaded_files = st.file_uploader("Choose PDF files", type="pdf", accept_multiple_files=True)
15
+
16
+ @st.cache_resource(show_spinner=False)
17
+ def load_data(uploaded_files):
18
+ with st.spinner('Indexing documents...'):
19
+ temp_dir = tempfile.mkdtemp() # Create temporary directory
20
+ file_paths = [] # List to store paths of saved files
21
+
22
+ # Save the uploaded files temporarily
23
+ for i, uploaded_file in enumerate(uploaded_files):
24
+ temp_path = os.path.join(temp_dir, f"temp_{i}.pdf")
25
+ with open(temp_path, "wb") as f:
26
+ f.write(uploaded_file.read())
27
+ file_paths.append(temp_path)
28
+
29
+ # Read and index documents using SimpleDirectoryReader
30
+ reader = SimpleDirectoryReader(input_dir=temp_dir, recursive=False)
31
+ docs = reader.load_data()
32
+ service_context = ServiceContext.from_defaults(
33
+ llm=OpenAI(
34
+ model="gpt-3.5-turbo-16k",
35
+ temperature=0.1,
36
+ ),
37
+ system_prompt="You are an AI assistant that uses context from PDFs to assist the user in generating text."
38
+ )
39
+ index = VectorStoreIndex.from_documents(docs, service_context=service_context)
40
+
41
+ # Clean up temporary files and directory
42
+ for file_path in file_paths:
43
+ os.remove(file_path)
44
+ os.rmdir(temp_dir)
45
+
46
+ return index
47
+
48
+ if uploaded_files:
49
+ index = load_data(uploaded_files)
50
+
51
+ user_query = st.text_input("Search for the products/info you want to use to ground your generated text content:")
52
+
53
+ if 'retrieved_text' not in st.session_state:
54
+ st.session_state['retrieved_text'] = ''
55
+
56
+ if st.button("Retrieve"):
57
+ with st.spinner('Retrieving text...'):
58
+ query_engine = index.as_query_engine(similarity_top_k=1)
59
+ st.session_state['retrieved_text'] = query_engine.query(user_query)
60
+ st.write(f"Retrieved Text: {st.session_state['retrieved_text']}")
61
+
62
+ content_type = st.selectbox("Select content type:", ["Blog", "Tweet"])
63
+
64
+ if st.button("Generate") and content_type:
65
+ with st.spinner('Generating text...'):
66
+ openai.api_key = os.getenv("OPENAI_API_KEY")
67
+ try:
68
+ if content_type == "Blog":
69
+ prompt = f"Write a blog about 500 words in length using the {st.session_state['retrieved_text']}"
70
+ elif content_type == "Tweet":
71
+ prompt = f"Compose a tweet using the {st.session_state['retrieved_text']}"
72
+ response = openai.ChatCompletion.create(
73
+ model="gpt-3.5-turbo-16k",
74
+ messages=[
75
+ {"role": "system", "content": "You are a helpful assistant."},
76
+ {"role": "user", "content": prompt}
77
+ ]
78
+ )
79
+ generated_text = response['choices'][0]['message']['content']
80
+ st.write(f"Generated Text: {generated_text}")
81
+ except Exception as e:
82
+ st.write(f"An error occurred: {e}")
83
+