from typing import Any, Dict, Iterator, List, Optional

import requests
from huggingface_hub import add_collection_item, create_collection
from tqdm.auto import tqdm


class DatasetSearchClient:
    """Client for the librarian-bots dataset column search API.

    Wraps the ``/search`` endpoint and transparently walks its paginated
    results.
    """

    # Keys every page returned by ``/search`` must contain.
    _REQUIRED_RESPONSE_KEYS = frozenset({"total", "page", "page_size", "results"})

    def __init__(
        self,
        base_url: str = "https://librarian-bots-dataset-column-search-api.hf.space",
        timeout: Optional[float] = 60.0,
    ):
        """Create a client.

        Args:
            base_url (str, optional): Root URL of the search API; ``/search`` is
                appended to it for each request.
            timeout (float, optional): Seconds to wait for each HTTP request
                before giving up. ``None`` waits indefinitely. Defaults to 60.
        """
        self.base_url = base_url
        self.timeout = timeout

    def search(
        self, columns: List[str], match_all: bool = False, page_size: int = 100
    ) -> Iterator[Dict[str, Any]]:
        """
        Search datasets using the provided API, automatically handling pagination.

        Args:
            columns (List[str]): List of column names to search for.
            match_all (bool, optional): If True, match all columns. If False, match any column. Defaults to False.
            page_size (int, optional): Number of results per page. Defaults to 100.

        Yields:
            Dict[str, Any]: Each dataset result from all pages.

        Raises:
            requests.RequestException: If there's an error with the HTTP request.
            ValueError: If the API returns an unexpected response format.
        """
        page = 1
        total_results = None
        while True:
            data = self._fetch_page(columns, match_all, page, page_size)
            if total_results is None:
                # The total is taken from the first page only.
                total_results = data["total"]
            results = data["results"]
            yield from results
            # The server reports the page size it actually used (it may cap the
            # requested value); page numbers are relative to that size.
            served_page_size = data["page_size"] or page_size
            # An empty page means nothing more is coming, whatever ``total``
            # says; stopping here also rules out an endless request loop.
            if not results or page * served_page_size >= total_results:
                return
            page += 1

    def _fetch_page(
        self, columns: List[str], match_all: bool, page: int, page_size: int
    ) -> Dict[str, Any]:
        """Fetch and validate a single page of ``/search`` results.

        Raises:
            requests.RequestException: On connection failure or HTTP error status.
            ValueError: If the body is not a JSON object with the expected keys.
        """
        params = {
            "columns": columns,
            "match_all": str(match_all).lower(),
            "page": page,
            "page_size": page_size,
        }
        try:
            response = requests.get(
                f"{self.base_url}/search", params=params, timeout=self.timeout
            )
            response.raise_for_status()
        except requests.RequestException as e:
            raise requests.RequestException(
                f"Error connecting to the API: {str(e)}"
            ) from e

        try:
            data = response.json()
        except ValueError as e:
            # Newer requests versions raise requests.JSONDecodeError, which is
            # both a ValueError and a RequestException; decoding it separately
            # reports it as a response-format problem, as documented.
            raise ValueError(f"Error processing API response: {str(e)}") from e

        # A JSON array/scalar has no .keys(); treat it as a format error too.
        if not isinstance(data, dict) or not self._REQUIRED_RESPONSE_KEYS.issubset(
            data
        ):
            raise ValueError(
                "Error processing API response: "
                "Unexpected response format from the API"
            )
        return data


# Create an instance of the client
client = DatasetSearchClient()


def update_collection_for_dataset(
    collection_name: Optional[str] = None,
    dataset_columns: Optional[List[str]] = None,
    collection_description: Optional[str] = None,
    collection_namespace: Optional[str] = None,
) -> str:
    """Create (or reuse) a Hub collection and add every dataset with all given columns.

    Args:
        collection_name: Title of the collection. Required.
        dataset_columns: Column names a dataset must have *all* of to be added.
            Required.
        collection_description: Description passed to ``create_collection``.
        collection_namespace: Owner of the collection. ``None`` defers to
            ``create_collection``'s default namespace.

    Returns:
        The URL of the collection on huggingface.co.

    Raises:
        ValueError: If ``collection_name`` or ``dataset_columns`` is empty.
    """
    if not collection_name:
        raise ValueError("collection_name must be a non-empty string")
    if not dataset_columns:
        raise ValueError("dataset_columns must contain at least one column name")

    # namespace=None is create_collection's own default, so one call covers both
    # the default-namespace and explicit-namespace cases.
    collection = create_collection(
        collection_name,
        exists_ok=True,
        description=collection_description,
        namespace=collection_namespace,
    )
    # Materialise the search first so the second progress bar knows its total.
    results = list(
        tqdm(
            client.search(dataset_columns, match_all=True),
            desc="Searching datasets...",
            leave=False,
        )
    )
    for result in tqdm(results, desc="Adding datasets to collection...", leave=False):
        try:
            add_collection_item(
                collection.slug, result["hub_id"], item_type="dataset", exists_ok=True
            )
        except Exception as e:
            # Best-effort: one failing dataset should not abort the whole update.
            print(
                f"Error adding dataset {result['hub_id']} to collection {collection_name}: {str(e)}"
            )
    return f"https://huggingface.co/collections/{collection.slug}"


collections = [
    {
        "dataset_columns": ["chosen", "rejected", "prompt"],
        "collection_description": "Datasets suitable for DPO based on having 'chosen', 'rejected', and 'prompt' columns. Created using librarian-bots/dataset-column-search-api",
        "collection_name": "Direct Preference Optimization Datasets",
    },
    {
        "dataset_columns": ["image", "chosen", "rejected"],
        "collection_description": "Datasets suitable for Image Preference Optimization based on having 'image','chosen', and 'rejected' columns",
        "collection_name": "Image Preference Optimization Datasets",
    },
    {
        "collection_name": "Alpaca Style Datasets",
        "dataset_columns": ["instruction", "input", "output"],
        "collection_description": "Datasets which follow the Alpaca Style format based on having 'instruction', 'input', and 'output' columns",
    },
]

# results = [
#     update_collection_for_dataset(**collection, collection_namespace="librarian-bots")
#     for collection in collections
# ]
# print(results)