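"""Streamlit front end for a simple web-search RAG agent.

The app rewrites the user's question into a search query, crawls the top
results, vectorizes their contents, and streams a grounded answer; an
optional draft/comment/rewrite loop reviews the answer before it is shown.
"""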
import datetime
import os

import dotenv
import streamlit as st
from langchain_core.tracers.langchain import LangChainTracer
from langchain.callbacks.base import BaseCallbackHandler
from langsmith.client import Client

import web_rag as wr
import web_crawler as wc
import copywriter as cw
import models as md

dotenv.load_dotenv()
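
# Trace every chain run to LangSmith; the project name comes from the environment.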
ls_tracer = LangChainTracer(
    project_name=os.getenv("LANGSMITH_PROJECT_NAME"),
    client=Client()
)


class StreamHandler(BaseCallbackHandler):
    """Stream handler that appends tokens to the container."""

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs):
        self.text += token
        self.container.markdown(self.text)


def create_links_markdown(sources_list):
    """
    Create a markdown string for each source in the provided list.

    Args:
        sources_list (list): A list of dictionaries representing the sources.
            Each dictionary should have 'title', 'link', and 'snippet' keys.

    Returns:
        str: A markdown string with a bullet point for each source,
        including the title linked to the URL and the snippet.
    """
    markdown_list = []
    for source in sources_list:
        title = source['title']
        link = source['link']
        snippet = source['snippet']
        markdown = f"- [{title}]({link})\n {snippet}"
        markdown_list.append(markdown)
    return "\n".join(markdown_list)
st.set_page_config(layout="wide")
st.title("π Simple Search Agent π¬")
if "models" not in st.session_state:
models = []
if os.getenv("FIREWORKS_API_KEY"):
models.append("fireworks")
if os.getenv("TOGETHER_API_KEY"):
models.append("together")
if os.getenv("COHERE_API_KEY"):
models.append("cohere")
if os.getenv("OPENAI_API_KEY"):
models.append("openai")
if os.getenv("GROQ_API_KEY"):
models.append("groq")
if os.getenv("OLLAMA_API_KEY"):
models.append("ollama")
if os.getenv("CREDENTIALS_PROFILE_NAME"):
models.append("bedrock")
st.session_state["models"] = models
with st.sidebar.expander("Options", expanded=False):
model_provider = st.selectbox("Model provider π§ ", st.session_state["models"])
temperature = st.slider("Model temperature π‘οΈ", 0.0, 1.0, 0.1, help="The higher the more creative")
max_pages = st.slider("Max pages to retrieve π", 1, 20, 10, help="How many web pages to retrive from the internet")
top_k_documents = st.slider("Nbr of doc extracts to consider π", 1, 20, 10, help="How many of the top extracts to consider")
reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode βοΈ", value=False, help="First generate a draft, then comments and then rewrite")
with st.sidebar.expander("Links", expanded=False):
links_md = st.markdown("")
if reviewer_mode:
    with st.sidebar.expander("Answer review", expanded=False):
        st.caption("Draft")
        draft_md = st.markdown("")
        st.divider()
        st.caption("Comments")
        comments_md = st.markdown("")
        st.divider()
        st.caption("Comparison")
        comparison_md = st.markdown("")
if "messages" not in st.session_state:
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]

for message in st.session_state.messages:
    st.chat_message(message["role"]).write(message["content"])
    if message["role"] == "assistant" and 'message_id' in message:
        st.download_button(
            label="Download",
            data=message["content"],
            file_name=f"{message['message_id']}.txt",
            mime="text/plain"
        )
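
# Main turn: read the user's instructions and instantiate the chat and
# embedding models for the selected provider.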
if prompt := st.chat_input("Enter your instructions..."):
    st.chat_message("user").write(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    chat = md.get_model(model_provider, temperature)
    embedding_model = md.get_embedding_model(model_provider)
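
    # Research phase: rewrite the question into a search query, fetch and
    # read the top results, and index their contents into a vector store
    # (ntotal below suggests a FAISS index, but that depends on web_crawler).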
with st.status("Thinking", expanded=True):
st.write("I first need to do some research")
optimize_search_query = wr.optimize_search_query(chat, query=prompt, callbacks=[ls_tracer])
st.write(f"I should search the web for: {optimize_search_query}")
sources = wc.get_sources(optimize_search_query, max_pages=max_pages)
links_md.markdown(create_links_markdown(sources))
st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
contents = wc.get_links_contents(sources, use_selenium=False)
st.write( f"Reading through the {len(contents)} sources I managed to retrieve")
vector_store = wc.vectorize(contents, embedding_model=embedding_model)
st.write(f"I collected {vector_store.index.ntotal} chunk of data and I can now answer")
        if reviewer_mode:
            st.write("Creating a draft")
            draft_prompt = wr.build_rag_prompt(
                chat, prompt, optimized_search_query,
                vector_store, top_k=top_k_documents, callbacks=[ls_tracer])
            draft = chat.invoke(draft_prompt, stream=False, config={"callbacks": [ls_tracer]})
            draft_md.markdown(draft.content)

            st.write("Sending draft for review")
            comments = cw.generate_comments(chat, prompt, draft, callbacks=[ls_tracer])
            comments_md.markdown(comments)

            st.write("Reviewing comments and generating final answer")
            rag_prompt = cw.get_final_text_prompt(prompt, draft, comments)
        else:
            rag_prompt = wr.build_rag_prompt(
                chat, prompt, optimized_search_query, vector_store,
                top_k=top_k_documents, callbacks=[ls_tracer]
            )
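
    # Stream the final answer token by token into the chat message, then
    # keep the cleaned-up text as the response.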
with st.chat_message("assistant"):
st_cb = StreamHandler(st.empty())
response = ""
for chunk in chat.stream(rag_prompt, config={"callbacks": [ls_tracer]}):
if isinstance(chunk, dict):
chunk_text = chunk.get('text') or chunk.get('content', '')
elif isinstance(chunk, str):
chunk_text = chunk
elif hasattr(chunk, 'content'):
chunk_text = chunk.content
else:
chunk_text = str(chunk)
if isinstance(chunk_text, list):
chunk_text = ' '.join(
item['text'] if isinstance(item, dict) and 'text' in item
else str(item)
for item in chunk_text if item is not None
)
elif chunk_text is not None:
chunk_text = str(chunk_text)
else:
continue
response += chunk_text
st_cb.on_llm_new_token(chunk_text)
response = response.strip()
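
    # Persist the answer with an id (prompt plus timestamp) so the matching
    # download button can be re-rendered from history on later reruns.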
    message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
    st.session_state.messages.append({"role": "assistant", "content": response, "message_id": message_id})
    if st.session_state.messages[-1]["role"] == "assistant":
        st.download_button(
            label="Download",
            data=st.session_state.messages[-1]["content"],
            file_name=f"{message_id}.txt",
            mime="text/plain"
        )
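
    # In reviewer mode, also show how the final answer differs from the draft.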
    if reviewer_mode:
        compare_prompt = cw.get_compare_texts_prompts(prompt, draft_text=draft, final_text=response)
        result = chat.invoke(compare_prompt, stream=False, config={"callbacks": [ls_tracer]})
        comparison_md.markdown(result.content)