web-crawling / main.py
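"""FastAPI web-crawling service.

Exposes crawl4ai-based endpoints: /crawl for schema-driven LLM extraction,
/basic-crawl and /test for raw markdown, and /basic-crawl-article for
trafilatura article extraction, plus a file-conversion router mounted under
/api/v1. All endpoints except /basic-crawl-article require an X-API-Key header.
"""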
import os
import asyncio
from fastapi import FastAPI, HTTPException, Security, Depends, Query
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field, create_model
from typing import List, Optional
from crawl4ai import AsyncWebCrawler
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy, LLMExtractionStrategy
import json
import logging
import trafilatura
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
from file_conversion import router as file_conversion_router
app.include_router(file_conversion_router, prefix="/api/v1")
# API key configuration
CHAT_AUTH_KEY = os.getenv("CHAT_AUTH_KEY")
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def verify_api_key(api_key: str = Security(api_key_header)):
    if api_key != CHAT_AUTH_KEY:
        logger.warning("Invalid API key used")
        raise HTTPException(status_code=403, detail="Could not validate credentials")
    return api_key

class CrawlerInput(BaseModel):
    url: str = Field(..., description="URL to crawl")
    columns: List[str] = Field(..., description="List of required columns")
    descriptions: List[str] = Field(..., description="Descriptions for each column")

class CrawlerOutput(BaseModel):
    data: List[dict]

async def simple_crawl(url: str):
    async with AsyncWebCrawler(verbose=True) as crawler:
        result = await crawler.arun(url=url, bypass_cache=True)
        logger.info("Crawled %s (%d characters of markdown)", url, len(result.markdown))
        return result

@app.post("/crawl", response_model=CrawlerOutput)
async def crawl(input: CrawlerInput, api_key: str = Depends(verify_api_key)):
if len(input.columns) != len(input.descriptions):
raise HTTPException(status_code=400, detail="Number of columns must match number of descriptions")
extraction_info = {col: desc for col, desc in zip(input.columns, input.descriptions)}
dynamic_model = create_model(
'DynamicModel',
**{col: (str, Field(..., description=desc)) for col, desc in extraction_info.items()}
)
instruction = f"Extract the following information: {json.dumps(extraction_info)}"
async with AsyncWebCrawler(verbose=True) as crawler:
result = await crawler.arun(
url=input.url,
extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-4o-mini",
api_token=os.getenv('OPENAI_API_KEY'),
schema=dynamic_model.schema(),
extraction_type="schema",
verbose=True,
instruction=instruction
)
)
extracted_data = json.loads(result.extracted_content)
return CrawlerOutput(data=extracted_data)
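# Example /crawl request body (illustrative values only):
#   {"url": "https://example.com/product",
#    "columns": ["title", "price"],
#    "descriptions": ["Product title", "Listed price"]}
# The response follows CrawlerOutput, e.g. {"data": [{"title": "...", "price": "..."}]}.
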
@app.get("/basic-crawl")
async def test_url(api_key: str = Depends(verify_api_key), url: str = Query(..., description="URL to crawl")):
"""
A test endpoint that takes a URL as input and returns the result of crawling it.
"""
result = await simple_crawl(url=url)
return {"markdown": result.markdown}
@app.get("/basic-crawl-article")
async def extract_article(
url: str,
record_id: Optional[str] = Query(None, description="Add an ID to the metadata."),
no_fallback: Optional[bool] = Query(False, description="Skip the backup extraction with readability-lxml and justext."),
favor_precision: Optional[bool] = Query(False, description="Prefer less text but correct extraction."),
favor_recall: Optional[bool] = Query(False, description="When unsure, prefer more text."),
include_comments: Optional[bool] = Query(True, description="Extract comments along with the main text."),
output_format: Optional[str] = Query('txt', description="Define an output format: 'csv', 'json', 'markdown', 'txt', 'xml', 'xmltei'.", enum=["csv", "json", "markdown", "txt", "xml", "xmltei"]),
target_language: Optional[str] = Query(None, description="Define a language to discard invalid documents (ISO 639-1 format)."),
include_tables: Optional[bool] = Query(True, description="Take into account information within the HTML <table> element."),
include_images: Optional[bool] = Query(False, description="Take images into account (experimental)."),
include_links: Optional[bool] = Query(False, description="Keep links along with their targets (experimental)."),
deduplicate: Optional[bool] = Query(False, description="Remove duplicate segments and documents."),
max_tree_size: Optional[int] = Query(None, description="Discard documents with too many elements.")
):
response = await simple_crawl(url=url)
filecontent = response.html
extracted = trafilatura.extract(
filecontent,
url=url,
record_id=record_id,
no_fallback=no_fallback,
favor_precision=favor_precision,
favor_recall=favor_recall,
include_comments=include_comments,
output_format=output_format,
target_language=target_language,
include_tables=include_tables,
include_images=include_images,
include_links=include_links,
deduplicate=deduplicate,
max_tree_size=max_tree_size
)
if extracted:
return {"article": trafilatura.utils.sanitize(extracted)}
else:
return {"error": "Could not extract the article"}
@app.get("/test")
async def test(api_key: str = Depends(verify_api_key)):
result = await simple_crawl("https://www.nbcnews.com/business")
return {"markdown": result.markdown}
from fastapi.middleware.cors import CORSMiddleware

# CORS middleware setup
# Note: allowed origins must match the browser's Origin header exactly,
# which never includes a trailing slash.
app.add_middleware(
    CORSMiddleware,
    #allow_origins=["*"],
    allow_origins=[
        "http://127.0.0.1:5501",
        "http://localhost:5501",
        "http://localhost:3000",
        "https://www.elevaticsai.com",
        "https://www.elevatics.cloud",
        "https://www.elevatics.online",
        "https://www.elevatics.ai",
        "https://elevaticsai.com",
        "https://elevatics.cloud",
        "https://elevatics.online",
        "https://elevatics.ai",
        "https://web.elevatics.cloud",
        "https://pvanand-specialized-agents.hf.space",
        "https://pvanand-audio-chat.hf.space"
    ],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
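
# Local run example (sketch: assumes fastapi, uvicorn, crawl4ai and trafilatura are
# installed; the key and OpenAI token values below are placeholders):
#   CHAT_AUTH_KEY=changeme OPENAI_API_KEY=your-openai-key python main.py
#   curl -H "X-API-Key: changeme" "http://localhost:8000/basic-crawl?url=https://example.com"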