import asyncio
from typing import Callable, Optional, Union
import huggingface_hub
import semchunk
import tiktoken
import tokenizers
from datasets import Dataset, concatenate_datasets, load_dataset
from rich.progress import track
from transformers import PreTrainedTokenizer

TOKENIZER_OR_TOKEN_COUNTER = Union[
str,
tiktoken.Encoding,
PreTrainedTokenizer,
tokenizers.Tokenizer,
Callable[[str], int],
]


class SemanticChunker:
"""
SemanticChunker is a class that chunks documents into smaller segments and
publishes them as datasets.
This class uses the `semchunk` library to break down large documents into
smaller, manageable chunks based on a specified tokenizer or token counter.
This is particularly useful for processing large text datasets where
smaller segments are needed for analysis or other operations.
!!! example "Example Usage"
```python
from medrag_multi_modal.semantic_chunking import SemanticChunker
chunker = SemanticChunker(chunk_size=256)
chunker.chunk(
document_dataset="geekyrakshit/grays-anatomy-test",
chunk_dataset_repo_id="geekyrakshit/grays-anatomy-chunks-test",
)
```
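
    Any callable that maps a string to a token count can also be passed in
    place of a named tokenizer; the whitespace-based counter below is a
    hypothetical stand-in, shown only to illustrate the interface.

    !!! example "Example: custom token counter (illustrative sketch)"
        ```python
        from medrag_multi_modal.semantic_chunking import SemanticChunker

        # Hypothetical counter: approximates token counts by whitespace-split words.
        chunker = SemanticChunker(
            tokenizer_or_token_counter=lambda text: len(text.split()),
            chunk_size=256,
        )
        ```
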
Args:
tokenizer_or_token_counter (TOKENIZER_OR_TOKEN_COUNTER): The tokenizer or
token counter to be used for chunking.
chunk_size (Optional[int]): The size of each chunk. If not specified, the
default chunk size from `semchunk` will be used.
max_token_chars (Optional[int]): The maximum number of characters per token.
If not specified, the default value from `semchunk` will be used.
memoize (bool): Whether to memoize the chunking process for efficiency.
Default is True.
"""

    def __init__(
self,
tokenizer_or_token_counter: TOKENIZER_OR_TOKEN_COUNTER = "o200k_base",
chunk_size: Optional[int] = None,
max_token_chars: Optional[int] = None,
memoize: bool = True,
) -> None:
self.chunker = semchunk.chunkerify(
tokenizer_or_token_counter,
chunk_size=chunk_size,
max_token_chars=max_token_chars,
memoize=memoize,
)

    def chunk(
self,
document_dataset: Union[Dataset, str],
chunk_dataset_repo_id: Optional[str] = None,
overwrite_dataset: bool = False,
) -> Dataset:
"""
Chunks a document dataset into smaller segments and publishes them as a new dataset.
This function takes a document dataset, either as a HuggingFace Dataset object or a string
representing the dataset repository ID, and chunks the documents into smaller segments using
the specified chunker. The resulting chunks are then optionally published to a HuggingFace
dataset repository.
Args:
            document_dataset (Union[Dataset, str]): The document dataset to be chunked. It can be either
                a HuggingFace Dataset object or a string naming a dataset repository, in which case the
                `corpus` split of that repository is loaded.
            chunk_dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish
                the chunks to as the `chunks` split, if provided. Defaults to None.
            overwrite_dataset (bool): Whether to overwrite an existing chunk dataset. If False, the new
                chunks are appended to the existing `chunks` split before publishing. Defaults to False.
Returns:
Dataset: A HuggingFace Dataset object containing the chunks.
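
        The sketch below assumes an in-memory corpus in which every row has a
        `text` column; the sample text and chunk size are placeholders.

        !!! example "Example Usage (illustrative sketch)"
            ```python
            from datasets import Dataset

            from medrag_multi_modal.semantic_chunking import SemanticChunker

            corpus = Dataset.from_list(
                [{"text": "The brachial plexus is a network of nerves ..."}]
            )
            chunker = SemanticChunker(chunk_size=256)
            chunk_dataset = chunker.chunk(document_dataset=corpus)
            ```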
"""
document_dataset = (
load_dataset(document_dataset, split="corpus")
if isinstance(document_dataset, str)
else document_dataset
).to_list()
chunks = []

        async def process_document(idx, document):
            # Split one document's text into semantic chunks, carrying every
            # other column of the source row over onto each chunk record.
            document_chunks = self.chunker.chunk(str(document["text"]))
            for chunk in document_chunks:
                chunk_dict = {"document_idx": idx, "text": chunk}
                for key, value in document.items():
                    if key not in chunk_dict:
                        chunk_dict[key] = value
                chunks.append(chunk_dict)

        async def process_all_documents():
            # Schedule one chunking coroutine per document and wait for all of
            # them to finish.
            tasks = []
            for idx, document in track(
                enumerate(document_dataset),
                total=len(document_dataset),
                description="Chunking documents",
            ):
                tasks.append(process_document(idx, document))
            await asyncio.gather(*tasks)

        asyncio.run(process_all_documents())
chunks.sort(key=lambda x: x["document_idx"])
dataset = Dataset.from_list(chunks)

        if chunk_dataset_repo_id:
            if huggingface_hub.repo_exists(chunk_dataset_repo_id, repo_type="dataset"):
                # Unless overwriting was requested, append the new chunks to the
                # existing `chunks` split before publishing.
                if not overwrite_dataset:
                    dataset = concatenate_datasets(
                        [
                            dataset,
                            load_dataset(chunk_dataset_repo_id, split="chunks"),
                        ]
                    )
            dataset.push_to_hub(repo_id=chunk_dataset_repo_id, split="chunks")
return dataset