from functools import lru_cache
from pathlib import Path
import subprocess
import os
import shutil
import tarfile
from .model_loader import *
import argparse
import urllib.request
from crawl4ai.config import MODEL_REPO_BRANCH

__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))

def get_available_memory(device):
    import torch
    if device.type == 'cuda':
        return torch.cuda.get_device_properties(device).total_memory
    elif device.type == 'mps':
        return 48 * 1024 ** 3  # Assume up to 48GB of unified memory for MPS (Apple Silicon)
    else:
        return 0

def calculate_batch_size(device):
    available_memory = get_available_memory(device)

    if device.type == 'cpu':
        return 16
    elif device.type in ['cuda', 'mps']:
        # Adjust these thresholds based on your model size and available memory
        if available_memory >= 31 * 1024 ** 3:  # roughly 32GB or more
            return 256
        elif available_memory >= 15 * 1024 ** 3:  # roughly 16GB to 32GB
            return 128
        elif available_memory >= 8 * 1024 ** 3:  # roughly 8GB to 16GB
            return 64
        else:
            return 32
    else:
        return 16  # Default batch size

def get_device():
    import torch
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    return device

def set_model_device(model):
    device = get_device()
    model.to(device)
    return model, device
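
# A minimal usage sketch (illustrative, not part of the original module):
# pick a device via set_model_device() and size inference batches with
# calculate_batch_size(). The model name below is only an example.
def _example_device_and_batch_size():
    from transformers import AutoModel
    model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5")
    model, device = set_model_device(model)  # moves the model to cuda/mps/cpu
    batch_size = calculate_batch_size(device)  # memory-based heuristic above
    print(f"Running on {device} with batch size {batch_size}")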

def get_home_folder():
    home_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
    os.makedirs(home_folder, exist_ok=True)
    os.makedirs(f"{home_folder}/cache", exist_ok=True)
    os.makedirs(f"{home_folder}/models", exist_ok=True)
    return home_folder

def load_bert_base_uncased():
    from transformers import BertTokenizer, BertModel
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
    model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
    model.eval()
    model, device = set_model_device(model)
    return tokenizer, model

def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
    """Load a Hugging Face embedding model.

    Args:
        model_name (str, optional): The model name to load. Defaults to "BAAI/bge-small-en-v1.5".

    Returns:
        tuple: The tokenizer and model.
    """
    from transformers import AutoTokenizer, AutoModel
    tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None)
    model = AutoModel.from_pretrained(model_name, resume_download=None)
    model.eval()
    model, device = set_model_device(model)
    return tokenizer, model
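
# Hedged sketch of how the returned (tokenizer, model) pair might be used to
# produce sentence embeddings. Mean pooling over non-padding tokens is an
# assumption for illustration; BGE models are also commonly used with CLS
# pooling.
def _example_embed_texts(texts):
    import torch
    tokenizer, model = load_HF_embedding_model()
    device = next(model.parameters()).device
    tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    tokens = {k: v.to(device) for k, v in tokens.items()}
    with torch.no_grad():
        output = model(**tokens)
    mask = tokens['attention_mask'].unsqueeze(-1)  # (batch, seq_len, 1)
    summed = (output.last_hidden_state * mask).sum(dim=1)
    return summed / mask.sum(dim=1)  # mean over non-padding tokens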

def load_text_classifier():
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from transformers import pipeline

    tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
    model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
    model.eval()
    model, device = set_model_device(model)
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
    return pipe
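
# Illustrative call (not part of the original module): the pipeline returns a
# list with one dict per input, each holding 'label' and 'score' keys for the
# top predicted topic.
def _example_classify_topic(text):
    pipe = load_text_classifier()
    return pipe(text)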

def load_text_multilabel_classifier():
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from scipy.special import expit
    import torch
    MODEL = "cardiffnlp/tweet-topic-21-multi"
    tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
    model.eval()
    model, device = set_model_device(model)
    class_mapping = model.config.id2label

    def _classifier(texts, threshold=0.5, max_length=64):
        tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
        tokens = {key: val.to(device) for key, val in tokens.items()}  # Move tokens to the selected device

        with torch.no_grad():
            output = model(**tokens)

        scores = output.logits.detach().cpu().numpy()
        scores = expit(scores)
        predictions = (scores >= threshold) * 1

        batch_labels = []
        for prediction in predictions:
            labels = [class_mapping[i] for i, value in enumerate(prediction) if value == 1]
            batch_labels.append(labels)
        return batch_labels

    return _classifier, device
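
# Hedged sketch of calling the closure returned above: it accepts a list of
# strings and returns one list of topic labels per input; raising `threshold`
# makes the labeling more conservative.
def _example_multilabel_topics(texts):
    classifier, device = load_text_multilabel_classifier()
    return classifier(texts, threshold=0.5)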

def load_nltk_punkt():
    import nltk
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt')
    return nltk.data.find('tokenizers/punkt')
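
# Small sketch: once the punkt data is available, it backs NLTK's standard
# sentence tokenizer.
def _example_split_sentences(text):
    import nltk
    load_nltk_punkt()
    return nltk.sent_tokenize(text)  # -> list of sentence strings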

def load_spacy_model():
    import spacy
    name = "models/reuters"
    home_folder = get_home_folder()
    model_folder = Path(home_folder) / name

    # Check whether the model directory already exists and is non-empty
    if not (model_folder.exists() and any(model_folder.iterdir())):
        repo_url = "https://github.com/unclecode/crawl4ai.git"
        branch = MODEL_REPO_BRANCH
        repo_folder = Path(home_folder) / "crawl4ai"

        print("[LOG] ⏬ Downloading spaCy model for the first time...")

        # Remove the existing repo folder if present
        if repo_folder.exists():
            try:
                shutil.rmtree(repo_folder)
                if model_folder.exists():
                    shutil.rmtree(model_folder)
            except PermissionError:
                print("[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:")
                print(f"- {repo_folder}")
                print(f"- {model_folder}")
                return None

        try:
            # Clone the repository
            subprocess.run(
                ["git", "clone", "-b", branch, repo_url, str(repo_folder)],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=True
            )

            # Create the models directory if it doesn't exist
            models_folder = Path(home_folder) / "models"
            models_folder.mkdir(parents=True, exist_ok=True)

            # Copy the reuters model folder into the models directory
            source_folder = repo_folder / "models" / "reuters"
            shutil.copytree(source_folder, model_folder)

            # Remove the cloned repository
            shutil.rmtree(repo_folder)

            print("[LOG] ✅ spaCy model downloaded successfully")
        except subprocess.CalledProcessError as e:
            print(f"An error occurred while cloning the repository: {e}")
            return None
        except Exception as e:
            print(f"An error occurred: {e}")
            return None

    try:
        return spacy.load(str(model_folder))
    except Exception as e:
        print(f"Error loading spaCy model: {e}")
        return None
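
# Hedged sketch: the return value is a regular spaCy Language object, so it is
# called directly on text. "models/reuters" is a custom crawl4ai model, so
# which Doc attributes are populated (e.g. doc.cats) depends on its pipeline
# components.
def _example_spacy_categories(text):
    nlp = load_spacy_model()
    if nlp is None:
        return None
    doc = nlp(text)
    return doc.cats  # category scores, if a textcat component is present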

def download_all_models(remove_existing=False):
    """Download all models required for Crawl4AI."""
    if remove_existing:
        print("[LOG] Removing existing models...")
        home_folder = get_home_folder()
        model_folders = [
            os.path.join(home_folder, "models/reuters"),
            os.path.join(home_folder, "models"),
        ]
        for folder in model_folders:
            if Path(folder).exists():
                shutil.rmtree(folder)
        print("[LOG] Existing models removed.")

    # Load each model to trigger its download
# print("[LOG] Downloading BERT Base Uncased...") | |
# load_bert_base_uncased() | |
# print("[LOG] Downloading BGE Small EN v1.5...") | |
# load_bge_small_en_v1_5() | |
# print("[LOG] Downloading ONNX model...") | |
# load_onnx_all_MiniLM_l6_v2() | |
print("[LOG] Downloading text classifier...") | |
_, device = load_text_multilabel_classifier() | |
print(f"[LOG] Text classifier loaded on {device}") | |
print("[LOG] Downloading custom NLTK Punkt model...") | |
load_nltk_punkt() | |
print("[LOG] ✅ All models downloaded successfully.") | |

def main():
    print("[LOG] Welcome to the Crawl4AI Model Downloader!")
    print("[LOG] This script will download all the models required for Crawl4AI.")
    parser = argparse.ArgumentParser(description="Crawl4AI Model Downloader")
    parser.add_argument('--remove-existing', action='store_true', help="Remove existing models before downloading")
    args = parser.parse_args()
    download_all_models(remove_existing=args.remove_existing)

if __name__ == "__main__":
    main()
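
# Example invocations (assuming this file is installed as part of the
# crawl4ai package; the module path below is illustrative):
#
#   python -m crawl4ai.model_loader
#   python -m crawl4ai.model_loader --remove-existing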