import requests
from bs4 import BeautifulSoup
from urllib.parse import urljoin, urlparse
import os
from transformers import BertModel, BertTokenizer
import torch
import numpy as np
import faiss
from concurrent.futures import ThreadPoolExecutor
from retrying import retry
import time
from ratelimit import limits, sleep_and_retry
import threading
# Global counters for URLs and FAISS index initialization
total_urls_crawled = 0
index_file = 'faiss_index.bin'  # FAISS index file path
# Set of discovered URLs to prevent duplicates (updated under `lock`)
visited_urls = set()
# Set of URLs whose pages have actually been fetched and parsed; kept separate
# from visited_urls so that crawl_site() can still recurse into links that
# get_links() has already discovered
crawled_pages = set()
# Directory to save crawled URLs
urls_dir = 'crawled_urls'
os.makedirs(urls_dir, exist_ok=True)
urls_file = os.path.join(urls_dir, 'crawled_urls.txt')
# Initialize FAISS index
def initialize_faiss_index(dimension):
    if os.path.exists(index_file):
        os.remove(index_file)
        print("Deleted previous FAISS index file.")
    index = faiss.IndexFlatL2(dimension)
    return index
# Initialize the FAISS index (any previously saved index file is deleted first)
dimension = 768  # Dimension of BERT embeddings
index = initialize_faiss_index(dimension)
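# Note: initialize_faiss_index() always starts from an empty index. A minimal
# sketch of the alternative (resuming from a previously saved index instead of
# deleting it) could look like the commented-out lines below:
#
#   if os.path.exists(index_file):
#       index = faiss.read_index(index_file)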
# Initialize tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
# Lock for thread-safe updates to the shared crawler state
lock = threading.Lock()
# Function to update and print a live count of crawled URLs
def update_live_count():
    global total_urls_crawled
    while True:
        with lock:
            print(f"\rURLs crawled: {total_urls_crawled}", end='', flush=True)
        time.sleep(1)  # Update every second (outside the lock so workers are not blocked)
# Start live count update thread
live_count_thread = threading.Thread(target=update_live_count, daemon=True)
live_count_thread.start()
# Function to save crawled URLs to a file
def save_crawled_urls(url):
    with open(urls_file, 'a') as f:
        f.write(f"{url}\n")
        f.flush()  # Flush buffer to ensure immediate write
        os.fsync(f.fileno())  # Ensure write is flushed to disk
# Function to get all links from a webpage, with a retry mechanism (3 attempts,
# 2 s apart) and rate limiting applied to every fetch.
# Adjust calls and period based on the website's rate limits.
@retry(stop_max_attempt_number=3, wait_fixed=2000)
@sleep_and_retry
@limits(calls=10, period=60)
def get_links(url, domain):
    global total_urls_crawled
    links = []
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        response = requests.get(url, headers=headers, timeout=50)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, 'html.parser')
        for link in soup.find_all('a', href=True):
            href = link['href']
            normalized_url = normalize_url(href, domain)
            if not normalized_url:
                continue
            # Check and update the shared state atomically: the visited set, the
            # live counter, and the URL file are all touched by many worker threads
            with lock:
                if normalized_url in visited_urls:
                    continue
                visited_urls.add(normalized_url)
                total_urls_crawled += 1
                save_crawled_urls(normalized_url)  # Save crawled URL to file
            links.append(normalized_url)
        # Convert the page text to BERT embeddings and add them to the FAISS index
        # (once per page, after the link loop)
        try:
            text = soup.get_text()
            if text:
                embeddings = convert_text_to_bert_embeddings(text, tokenizer, model)
                with lock:  # the index is shared by all worker threads
                    index.add(np.array([embeddings], dtype='float32'))
        except Exception as e:
            print(f"Error adding embeddings to FAISS index: {e}")
    except requests.HTTPError as e:
        if e.response.status_code == 404:
            print(f"HTTP 404 Error: {e}")
        else:
            print(f"HTTP error occurred: {e}")
    except requests.RequestException as e:
        print(f"Error accessing {url}: {e}")
        raise  # Re-raise so the @retry decorator can re-attempt transient failures
    return links
# Function to normalize and validate URLs
def normalize_url(url, domain):
    parsed_url = urlparse(url)
    if not parsed_url.scheme:
        url = urljoin(domain, url)
    if url.startswith(domain):
        return url
    return None
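# For example, with domain = 'https://go.drugbank.com/':
#   normalize_url('/drugs/DB00002', domain)        -> 'https://go.drugbank.com/drugs/DB00002'
#   normalize_url('https://example.com/x', domain) -> None (outside the domain)
# Relative links are resolved against the domain root, which assumes the site
# uses root-relative hrefs.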
# Function to recursively crawl pages and collect links. All fetching (and the
# associated retry/rate limiting) happens inside get_links(); adjust calls and
# period there to match the website's rate limits.
def crawl_site(base_url, domain, depth=0, max_depth=10):  # Increased max_depth to 10
    if depth > max_depth:
        return []
    # visited_urls tracks links that have been *discovered* (see get_links), so a
    # separate set is used to ensure each page is fetched and parsed only once
    with lock:
        if base_url in crawled_pages:
            return []
        crawled_pages.add(base_url)
    links = get_links(base_url, domain)
    print(f"Crawled {len(links)} links from {base_url} at depth {depth}.")  # Debugging info
    try:
        # get_links() has already fetched and parsed base_url, so recurse directly
        # into the newly discovered links instead of requesting the page again.
        # Note that each recursion level opens its own thread pool, so the total
        # number of threads can grow well beyond max_workers.
        with ThreadPoolExecutor(max_workers=500) as executor:
            results = executor.map(lambda url: crawl_site(url, domain, depth + 1, max_depth), list(links))
            for result in results:
                links.extend(result)
    except Exception as e:
        print(f"Error while crawling links from {base_url}: {e}")
    return links
# Function to convert text to BERT embeddings
def convert_text_to_bert_embeddings(text, tokenizer, model):
    inputs = tokenizer(text, return_tensors='pt', max_length=512, truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()  # Average pool the last layer's output
    return embeddings
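# For reference, the pooled output is a 1-D float32 vector of length 768,
# matching the `dimension` used for the FAISS index above, e.g.:
#
#   vec = convert_text_to_bert_embeddings("aspirin", tokenizer, model)
#   assert vec.shape == (768,)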
# Main process
def main():
    global total_urls_crawled
    domain = 'https://go.drugbank.com/'  # Replace with your target domain
    start_url = 'https://go.drugbank.com/drugs/DB00001'  # Replace with your starting URL
    try:
        # Save the FAISS index at the beginning of the execution
        faiss.write_index(index, index_file)
        print("Initial FAISS index saved.")
        urls = crawl_site(start_url, domain)
        print(f"\n\nFound {total_urls_crawled} URLs.")
        # Save the FAISS index at the end of execution
        faiss.write_index(index, index_file)
        print("Final FAISS index saved.")
    except Exception as e:
        print(f"Exception encountered: {e}")
if __name__ == "__main__":
    main()