# Source: jin-e / streamlit_interface.py
# Commit 6330947 by hamxahbhattii — "added Jine" (1.69 kB)
import streamlit as st
from streamlit_chat import message
import os
##### Importing JIN-e
from jine import Jine
from dotenv import load_dotenv
import os
load_dotenv()
# Configuration for Jine, read from the process environment (which load_dotenv()
# above may populate from a .env file). Each value is a string, or None if unset.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
DATA_DIRECTORY = os.getenv("DATA_DIRECTORY")
# Prefer the correctly spelled variable name; fall back to the historical
# misspelling "VECTOR_STORE_DIRCTORY" so existing .env files keep working.
VECTOR_STORE_DIRECTORY = os.getenv(
    "VECTOR_STORE_DIRECTORY", os.getenv("VECTOR_STORE_DIRCTORY")
)
VECTOR_STORE_CHECK = os.getenv("VECTOR_STORE_CHECK")
# NOTE(review): DEBUG and USE_HYDE are handed to Jine as raw strings (so "False"
# is truthy); confirm Jine parses them rather than testing their truthiness.
DEBUG = os.getenv("DEBUG")
USE_HYDE = os.getenv("USE_HYDE")
# Initialize Jine
@st.cache_resource()
def load_model():
    """Construct the Jine assistant, load its model, and return it.

    The function is wrapped in ``st.cache_resource``, so Streamlit memoises the
    returned object instead of rebuilding it on every script rerun.
    """
    # Positional order matters: it mirrors Jine's constructor signature.
    settings = (
        OPENAI_API_KEY,
        VECTOR_STORE_DIRECTORY,
        VECTOR_STORE_CHECK,
        DATA_DIRECTORY,
        DEBUG,
        USE_HYDE,
    )
    assistant = Jine(*settings)
    assistant.load_model()
    return assistant
jine = load_model()
import streamlit as st
from streamlit_chat import message
# st.set_page_config(
# page_title="JIN-e",
# page_icon=":robot:"
# )
# #
# Page heading.
st.header("JIN-e")
st.markdown("Powered by People Analytics")
# Chat history is kept in session state so it survives Streamlit reruns.
# 'generated' holds the bot replies and 'past' holds the user's messages.
# Each one is created empty the first time this session runs the script.
for history_key in ('generated', 'past'):
    if history_key not in st.session_state:
        st.session_state[history_key] = []
# def query(payload):
# response = requests.post(API_URL, headers=headers, json=payload)
# return response.json()
def get_text():
    """Render the chat input box and return its current text.

    The widget is keyed "input", and it starts pre-filled with a greeting.
    """
    return st.text_input("You: ", "Hello, how are you?", key="input")
user_input = get_text()

# Streamlit re-executes this whole script on every rerun, and the keyed text box
# keeps its last value. Without a guard, the same message would be re-sent to
# Jine and appended to the history again on each rerun. Only new input is
# processed. (Trade-off: an identical question asked twice in a row is not resent.)
# 'last_input' is recorded only after a successful chat() call, so a failed
# request can be retried by rerunning.
if user_input and user_input != st.session_state.get('last_input'):
    response = jine.chat(user_input)
    st.session_state.past.append(user_input)
    st.session_state.generated.append(response)
    st.session_state['last_input'] = user_input

# Render the conversation newest-first. Each turn shows the user's message,
# then Jine's reply. Keys are unique per turn so the message widgets don't
# collide. The range is empty when there is no history yet.
for turn in reversed(range(len(st.session_state['generated']))):
    message(st.session_state['past'][turn], is_user=True, key=f"{turn}_user")
    message(st.session_state['generated'][turn], key=str(turn))