import asyncio
from typing import Callable, Optional, Union

import huggingface_hub
import semchunk
import tiktoken
import tokenizers
from datasets import Dataset, concatenate_datasets, load_dataset
from rich.progress import track
from transformers import PreTrainedTokenizer

TOKENIZER_OR_TOKEN_COUNTER = Union[
    str,
    tiktoken.Encoding,
    PreTrainedTokenizer,
    tokenizers.Tokenizer,
    Callable[[str], int],
]


class SemanticChunker:
    """
    SemanticChunker chunks documents into smaller segments and optionally
    publishes them as HuggingFace datasets.

    This class uses the `semchunk` library to break down large documents into
    smaller, manageable chunks based on a specified tokenizer or token counter.
    This is particularly useful for processing large text datasets where
    smaller segments are needed for analysis or other operations.

    !!! example "Example Usage"
        ```python
        from medrag_multi_modal.semantic_chunking import SemanticChunker


        chunker = SemanticChunker(chunk_size=256)
        chunker.chunk(
            document_dataset="geekyrakshit/grays-anatomy-test",
            chunk_dataset_repo_id="geekyrakshit/grays-anatomy-chunks-test",
        )
        ```

    Args:
        tokenizer_or_token_counter (TOKENIZER_OR_TOKEN_COUNTER): The tokenizer or
            token counter to be used for chunking. Accepts a tokenizer name such
            as the default `"o200k_base"`, a `tiktoken.Encoding`, a HuggingFace
            `transformers` or `tokenizers` tokenizer, or a callable that returns
            the token count of a string (see the example below).
        chunk_size (Optional[int]): The maximum number of tokens per chunk. If
            not specified, `semchunk` attempts to infer it from the tokenizer.
        max_token_chars (Optional[int]): The maximum number of characters a token
            may contain, used by `semchunk` to speed up token counting on long
            inputs. If not specified, the default behaviour of `semchunk` is used.
        memoize (bool): Whether to memoize the token counter for efficiency.
            Default is True.
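
    !!! example "Using a custom token counter"
        A sketch that assumes a whitespace word count is an acceptable
        stand-in for a real tokenizer:

        ```python
        from medrag_multi_modal.semantic_chunking import SemanticChunker

        # hypothetical token counter: counts whitespace-separated words
        chunker = SemanticChunker(
            tokenizer_or_token_counter=lambda text: len(text.split()),
            chunk_size=256,
        )
        ```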
    """

    def __init__(
        self,
        tokenizer_or_token_counter: TOKENIZER_OR_TOKEN_COUNTER = "o200k_base",
        chunk_size: Optional[int] = None,
        max_token_chars: Optional[int] = None,
        memoize: bool = True,
    ) -> None:
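        # `semchunk.chunkerify` turns the tokenizer (or token counter) into a
        # chunker bound to the requested chunk size.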
        self.chunker = semchunk.chunkerify(
            tokenizer_or_token_counter,
            chunk_size=chunk_size,
            max_token_chars=max_token_chars,
            memoize=memoize,
        )

    def chunk(
        self,
        document_dataset: Union[Dataset, str],
        chunk_dataset_repo_id: Optional[str] = None,
        overwrite_dataset: bool = False,
    ) -> Dataset:
        """
        Chunks a document dataset into smaller segments and optionally publishes them as a new dataset.

        This function takes a document dataset, either as a HuggingFace Dataset object or a string
        representing the dataset repository ID, and chunks the documents into smaller segments using
        the specified chunker. The resulting chunks are then optionally published to a HuggingFace
        dataset repository.
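
        !!! example "Example Usage"
            A minimal sketch of chunking an in-memory dataset without
            publishing it (the document text is illustrative; any dataset
            with a `text` column works):

            ```python
            from datasets import Dataset

            from medrag_multi_modal.semantic_chunking import SemanticChunker

            document_dataset = Dataset.from_list(
                [{"text": "The human skeleton of an adult consists of 206 bones."}]
            )
            chunker = SemanticChunker(chunk_size=256)
            chunk_dataset = chunker.chunk(document_dataset=document_dataset)
            ```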

        Args:
            document_dataset (Union[Dataset, str]): The document dataset to be chunked. It can be either
                a HuggingFace Dataset object or a dataset repository ID, in which case the `corpus` split
                of that repository is loaded from the HuggingFace Hub.
            chunk_dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish
                the chunks to, if provided. Defaults to None.
            overwrite_dataset (bool): Whether to overwrite the existing chunk dataset if it exists. If False,
                the newly generated chunks are merged with the existing `chunks` split before being pushed.
                Defaults to False.

        Returns:
            Dataset: A HuggingFace Dataset object containing the chunks.
        """
        document_dataset = (
            load_dataset(document_dataset, split="corpus")
            if isinstance(document_dataset, str)
            else document_dataset
        ).to_list()

        chunks = []

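        # Chunk a single document; every chunk records the index of its source
        # document and carries over the document's remaining metadata columns.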
        async def process_document(idx, document):
            document_chunks = self.chunker.chunk(str(document["text"]))
            for chunk in document_chunks:
                chunk_dict = {"document_idx": idx, "text": chunk}
                for key, value in document.items():
                    if key not in chunk_dict:
                        chunk_dict[key] = value
                chunks.append(chunk_dict)

        async def process_all_documents():
            # The chunker itself is synchronous, so documents are processed one
            # after another; awaiting each document inside the tracked loop
            # keeps the progress bar in step with the actual chunking work.
            for idx, document in track(
                enumerate(document_dataset),
                total=len(document_dataset),
                description="Chunking documents",
            ):
                await process_document(idx, document)

        asyncio.run(process_all_documents())

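        # Keep chunks grouped in their original document order.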
        chunks.sort(key=lambda x: x["document_idx"])

        dataset = Dataset.from_list(chunks)
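        # Optionally publish to the HuggingFace Hub; unless `overwrite_dataset`
        # is set, the new chunks are merged with any existing `chunks` split
        # before being pushed.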
        if chunk_dataset_repo_id:
            if huggingface_hub.repo_exists(chunk_dataset_repo_id, repo_type="dataset"):
                if not overwrite_dataset:
                    dataset = concatenate_datasets(
                        [
                            dataset,
                            load_dataset(chunk_dataset_repo_id, split="chunks"),
                        ]
                    )
            dataset.push_to_hub(repo_id=chunk_dataset_repo_id, split="chunks")

        return dataset