File size: 7,579 Bytes
7406911
 
9c3709d
 
7406911
9c3709d
 
4c66227
d803be1
9c3709d
8c28786
9c3709d
9847233
c7143b1
7e9684b
7402de3
9c3709d
 
 
 
 
 
8c28786
9c3709d
 
7406911
c7143b1
4c66227
6f80de5
c7143b1
7e9684b
7402de3
9c3709d
 
 
 
 
 
 
 
 
8d1e83e
d803be1
9c3709d
 
 
 
d594a38
 
8c28786
d803be1
c7143b1
9c3709d
7406911
9c3709d
7406911
9c3709d
 
542890e
7406911
542890e
 
6f80de5
542890e
 
6f80de5
542890e
 
 
 
 
 
 
 
6f80de5
 
 
 
 
 
542890e
7406911
9c3709d
7406911
8d1e83e
9c3709d
d594a38
9c3709d
7406911
d803be1
 
7406911
7402de3
8c28786
4c66227
d803be1
9c3709d
7406911
 
 
 
 
9c3709d
8d1e83e
7406911
d803be1
7406911
 
c7143b1
 
 
6f80de5
7406911
6f80de5
8d1e83e
7406911
7402de3
e5a770a
7406911
c7143b1
 
7406911
7402de3
7406911
9c3709d
c7143b1
 
 
 
8d1e83e
7406911
8d1e83e
c7143b1
8d1e83e
c7143b1
9c3709d
 
7406911
8d1e83e
 
 
d803be1
9c3709d
 
7406911
c7143b1
7406911
c7143b1
 
 
 
 
 
 
 
 
 
 
7406911
9c3709d
7406911
8c28786
 
d803be1
8c28786
 
d803be1
7406911
 
8c28786
7406911
 
 
 
 
 
 
 
 
d803be1
 
7406911
d803be1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""
search_agent.py

Usage:
    search_agent.py
        [--domain=domain]
        [--provider=provider]
        [--model=model]
        [--embedding_model=model]
        [--temperature=temp]
        [--copywrite]
        [--max_pages=num]
        [--max_extracts=num]
        [--use_browser]
        [--output=text]
        [--verbose]
        SEARCH_QUERY
    search_agent.py --version

Options:
    -h --help                           Show this screen.
    --version                           Show version.
    -c --copywrite                      First produce a draft, review it and rewrite for a final text
    -d domain --domain=domain           Limit search to a specific domain
    -t temp --temperature=temp          Set the temperature of the LLM [default: 0.0]
    -m model --model=model              Use a specific model [default: hf:Qwen/Qwen2.5-72B-Instruct]
    -e model --embedding_model=model    Use an embedding model
    -n num --max_pages=num              Max number of pages to retrieve [default: 10]
    -x num --max_extracts=num           Max number of page extracts to consider [default: 7]
    -b --use_browser                    Use browser to fetch content from the web [default: False]
    -o text --output=text               Output format (choices: text, markdown) [default: markdown]
    -v --verbose                        Print verbose output [default: False]

"""

import os

from docopt import docopt
import dotenv

from langchain.callbacks import LangChainTracer

from langsmith import Client, traceable

from rich.console import Console
from rich.markdown import Markdown

import web_rag as wr
import web_crawler as wc
import copywriter as cw
import models as md
import nlp_rag as nr

# Load environment variables (API keys, etc.) from a local .env file.
dotenv.load_dotenv()
# Rich console used for all user-facing output in this script.
console = Console()

def get_selenium_driver():
    """Create and return a headless Chrome WebDriver, or None on failure.

    Selenium is imported lazily so the dependency is only required when
    browser-based fetching is actually requested.
    """
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options
    from selenium.common.exceptions import WebDriverException

    # Flags for a lean, container-friendly headless Chrome session.
    flags = (
        "--headless",
        "--disable-extensions",
        "--disable-gpu",
        "--no-sandbox",
        "--disable-dev-shm-usage",
        "--remote-debugging-port=9222",
        "--blink-settings=imagesEnabled=false",
        "--window-size=1920,1080",
    )
    opts = Options()
    for flag in flags:
        opts.add_argument(flag)

    try:
        return webdriver.Chrome(options=opts)
    except WebDriverException as e:
        print(f"Error creating Selenium WebDriver: {e}")
        return None

# LangChain callbacks: enable LangSmith tracing only when an API key is
# configured in the environment; otherwise run with no callbacks.
callbacks = (
    [LangChainTracer(client=Client())]
    if os.getenv("LANGCHAIN_API_KEY")
    else []
)

def _model_label(model_obj):
    """Best-effort human-readable name for a LangChain model object.

    Different providers expose the name under different attributes
    (model_name / model / model_id); fall back to str() as a last resort.
    """
    return (
        getattr(model_obj, 'model_name', None)
        or getattr(model_obj, 'model', None)
        or getattr(model_obj, 'model_id', None)
        or str(model_obj)
    )


@traceable(run_type="tool", name="search_agent")
def main(arguments):
    """Run the search agent end to end and return the final answer text.

    Pipeline: optimize the query -> search for sources -> fetch page
    content -> rank relevant extracts (spaCy or embeddings) -> draft an
    answer -> optionally review/rewrite it -> print and return it.

    Args:
        arguments: docopt-parsed CLI arguments (see module docstring).

    Returns:
        The final answer text (markdown or plain text, per --output).
    """
    verbose = arguments["--verbose"]
    copywrite_mode = arguments["--copywrite"]
    model = arguments["--model"]
    embedding_model = arguments["--embedding_model"]
    temperature = float(arguments["--temperature"])
    domain = arguments["--domain"]
    max_pages = int(arguments["--max_pages"])
    max_extract = int(arguments["--max_extracts"])
    output = arguments["--output"]
    use_selenium = arguments["--use_browser"]
    query = arguments["SEARCH_QUERY"]

    # Get the language model based on the provided model name and temperature
    chat = md.get_model(model, temperature)

    # If no embedding model is provided, use spaCy for semantic search
    if embedding_model is None:
        use_nlp = True
        nlp = nr.get_nlp_model()
    else:
        use_nlp = False
        embedding_model = md.get_embedding_model(embedding_model)

    # Log model details if verbose mode is enabled
    if verbose:
        console.log(f"Using model: {_model_label(chat)}")
        if not use_nlp:
            console.log(f"Using embedding model: {_model_label(embedding_model)}")

    # Optimize the search query; fall back to the raw query if the
    # optimizer returns something implausibly short.
    with console.status(f"[bold green]Optimizing query for search: {query}"):
        optimized_search_query = wr.optimize_search_query(chat, query)
        if len(optimized_search_query) < 3:
            optimized_search_query = query
    console.log(f"Optimized search query: [bold blue]{optimized_search_query}")

    # Retrieve sources using the optimized query
    with console.status(
            f"[bold green]Searching sources using the optimized query: {optimized_search_query}"
        ):
        sources = wc.get_sources(optimized_search_query, max_pages=max_pages, domain=domain)
    console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")

    # Fetch content from the retrieved sources
    with console.status(
        f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
    ):
        contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium)
    console.log(f"Managed to extract content from {len(contents)} sources")

    # Rank extracts and draft an answer, using spaCy or the embedding model
    if use_nlp:
        with console.status(f"[bold green]Splitting {len(contents)} sources for content", spinner="growVertical"):
            chunks = nr.recursive_split_documents(contents)
            console.log(f"Split {len(contents)} sources into {len(chunks)} chunks")
        with console.status("[bold green]Searching relevant chunks", spinner="growVertical"):
            relevant_results = nr.semantic_search(optimized_search_query, chunks, nlp, top_n=max_extract)
            console.log(f"Found {len(relevant_results)} relevant chunks")
        with console.status("[bold green]Writing content", spinner="growVertical"):
            draft = nr.query_rag(chat, query, relevant_results)
    else:
        with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
            vector_store = wc.vectorize(contents, embedding_model)
        with console.status("[bold green]Writing content", spinner='dots8Bit'):
            draft = wr.query_rag(chat, query, optimized_search_query, vector_store, top_k=max_extract)

    # If copywrite mode is enabled, have a reviewer pass refine the draft
    if copywrite_mode:
        with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
            comments = cw.generate_comments(chat, query, draft)

        with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
            final_text = cw.generate_final_text(chat, query, draft, comments)
    else:
        final_text = draft

    # Output the answer
    console.rule("[bold green]Response")
    if output == "text":
        console.print(final_text)
    else:
        console.print(Markdown(final_text))
    console.rule("[bold green]")

    return final_text

if __name__ == '__main__':
    # Parse CLI arguments via docopt and run the agent.
    main(docopt(__doc__, version='Search Agent 0.1'))