"""search_agent.py
Usage:
search_agent.py
[--domain=domain]
[--provider=provider]
[--model=model]
[--embedding_model=model]
[--temperature=temp]
[--copywrite]
[--max_pages=num]
[--max_extracts=num]
[--use_selenium]
[--output=text]
[--verbose]
SEARCH_QUERY
search_agent.py --version
Options:
-h --help Show this screen.
--version Show version.
-c --copywrite First produce a draft, review it and rewrite for a final text
-d domain --domain=domain Limit search to a specific domain
-t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
-m model --model=model Use a specific model [default: openai/gpt-4o-mini]
-e model --embedding_model=model Use a specific embedding model [default: same provider as model]
-n num --max_pages=num Max number of pages to retrieve [default: 10]
-x num --max_extracts=num Max number of page extract to consider [default: 7]
-s --use_selenium Use selenium to fetch content from the web [default: False]
-o text --output=text Output format (choices: text, markdown) [default: markdown]
-v --verbose Print verbose output [default: False]
"""
import os
from docopt import docopt
#from schema import Schema, Use, SchemaError
import dotenv
from langchain.callbacks import LangChainTracer
from langsmith import Client, traceable
from rich.console import Console
from rich.markdown import Markdown
import web_rag as wr
import web_crawler as wc
import copywriter as cw
import models as md
console = Console()
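# Load environment variables (provider API keys and the LANGCHAIN_API_KEY used below)
# from a local .env file, if present.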
dotenv.load_dotenv()
def get_selenium_driver():
    """Create a headless Chrome WebDriver, or return None if it cannot be started."""
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options
    from selenium.common.exceptions import WebDriverException

    chrome_options = Options()
    chrome_options.add_argument("--headless")
    chrome_options.add_argument("--disable-extensions")
    chrome_options.add_argument("--disable-gpu")
    chrome_options.add_argument("--no-sandbox")
    chrome_options.add_argument("--disable-dev-shm-usage")
    chrome_options.add_argument("--remote-debugging-port=9222")
    chrome_options.add_argument("--blink-settings=imagesEnabled=false")
    chrome_options.add_argument("--window-size=1920,1080")

    try:
        driver = webdriver.Chrome(options=chrome_options)
        return driver
    except WebDriverException as e:
        print(f"Error creating Selenium WebDriver: {e}")
        return None
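# Optional LangSmith tracing: when a LANGCHAIN_API_KEY is configured, attach a
# LangChainTracer callback so chain runs are logged to LangSmith.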
callbacks = []
if os.getenv("LANGCHAIN_API_KEY"):
    callbacks.append(
        LangChainTracer(client=Client())
    )
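# Trace the whole agent run as a single "tool" call in LangSmith.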
@traceable(run_type="tool", name="search_agent")
def main(arguments):
    # Unpack the docopt arguments.
    verbose = arguments["--verbose"]
    copywrite_mode = arguments["--copywrite"]
    model = arguments["--model"]
    embedding_model = arguments["--embedding_model"]
    temperature = float(arguments["--temperature"])
    domain = arguments["--domain"]
    max_pages = int(arguments["--max_pages"])
    max_extract = int(arguments["--max_extracts"])
    output = arguments["--output"]
    use_selenium = arguments["--use_selenium"]
    query = arguments["SEARCH_QUERY"]

    chat = md.get_model(model, temperature)

    # Unless an embedding model is given explicitly, reuse the chat model's provider
    # (the part before the "/" in a "provider/model" identifier).
    if embedding_model.lower() == "same provider as model":
        provider = model.split('/')[0]
        embedding_model = md.get_embedding_model(f"{provider}/")
    else:
        embedding_model = md.get_embedding_model(embedding_model)

    if verbose:
        console.log(f"Using model: {chat.model_name}")
        console.log(f"Using embedding model: {embedding_model.model}")
with console.status(f"[bold green]Optimizing query for search: {query}"):
optimize_search_query = wr.optimize_search_query(chat, query)
if len(optimize_search_query) < 3:
optimize_search_query = query
console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
with console.status(
f"[bold green]Searching sources using the optimized query: {optimize_search_query}"
):
sources = wc.get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
with console.status(
f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
):
contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium)
console.log(f"Managed to extract content from {len(contents)} sources")
with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
vector_store = wc.vectorize(contents, embedding_model)
with console.status("[bold green]Writing content", spinner='dots8Bit'):
draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract)
console.rule(f"[bold green]Response")
if output == "text":
console.print(draft)
else:
console.print(Markdown(draft))
console.rule("[bold green]")
    # In copywrite mode, ask a reviewer for comments and rewrite the draft into a final text.
    if copywrite_mode:
        with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
            comments = cw.generate_comments(chat, query, draft)

        console.rule("[bold green]Response from reviewer")
        if output == "text":
            console.print(comments)
        else:
            console.print(Markdown(comments))
        console.rule("[bold green]")

        with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
            final_text = cw.generate_final_text(chat, query, draft, comments)

        console.rule("[bold green]Final text")
        if output == "text":
            console.print(final_text)
        else:
            console.print(Markdown(final_text))
        console.rule("[bold green]")
if __name__ == '__main__':
    arguments = docopt(__doc__, version='Search Agent 0.1')
    main(arguments)