amirhosseinkarami committed
Commit abed76f · 1 Parent(s): bd201c2

Create Preprocess.py

Files changed (1)
  1. Preprocess.py +188 -0
Preprocess.py ADDED
@@ -0,0 +1,188 @@
+ from typing import Optional
+
+ import json
+ import os
+
+ import torch
+ from torch import Tensor
+ from transformers import AutoModel, AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase
+
+ class ModelUtils:
+     def __init__(self, model_root):
+         self.model_root = model_root
+         self.model_path = os.path.join(model_root, "model")
+         self.tokenizer_path = os.path.join(model_root, "tokenizer")
+
+     def download_model(self):
+         # Fetch the base model and tokenizer once and cache them under model_root.
+         BASE_MODEL = "HooshvareLab/bert-fa-zwnj-base"
+         tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+         model = AutoModel.from_pretrained(BASE_MODEL)
+
+         tokenizer.save_pretrained(self.tokenizer_path)
+         model.save_pretrained(self.model_path)
+
+     def make_dirs(self):
+         # Create the root, model, and tokenizer directories if they are missing.
+         for path in (self.model_root, self.model_path, self.tokenizer_path):
+             if not os.path.isdir(path):
+                 os.mkdir(path)
+
+ class Preprocess:
+     def __init__(self, model_root):
+         self.model_root = model_root
+         self.model_path = os.path.join(model_root, "model")
+         self.tokenizer_path = os.path.join(model_root, "tokenizer")
+         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+     def vectorize(self, text):
+         # Split the text into 510-token chunks, run all chunks through the
+         # model in one batch, then mean-pool the hidden states per chunk.
+         model = AutoModel.from_pretrained(self.model_path).to(self.device)
+         tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
+         ids, masks = self.transform_single_text(
+             text, tokenizer, 510, stride=510, minimal_chunk_length=0, maximal_text_length=None
+         )
+         tokens = {"input_ids": ids.to(self.device), "attention_mask": masks.to(self.device)}
+
+         with torch.no_grad():
+             output = model(**tokens)
+         last_hidden_states = output.last_hidden_state
+
+         # mean-pooled embedding of shape <num_chunks, hidden_size>
+         mean_pooled_embedding = last_hidden_states.mean(dim=1)
+
+         # flatten() concatenates the per-chunk vectors, so the result has
+         # num_chunks * hidden_size entries and grows with the input length
+         result = mean_pooled_embedding.flatten().cpu().numpy()
+         # serialize the embedding as a JSON list of floats
+         return json.dumps(result.tolist())
+
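One caveat on the pooling above: last_hidden_states.mean(dim=1) averages over [PAD] positions too. A mask-aware mean is the usual alternative; a minimal sketch (not part of this commit), reusing the tokens and last_hidden_states names from vectorize:

    mask = tokens["attention_mask"].unsqueeze(-1)    # <num_chunks, 512, 1>
    summed = (last_hidden_states * mask).sum(dim=1)  # sum over real tokens only
    counts = mask.sum(dim=1).clamp(min=1)            # real-token count per chunk
    mean_pooled_embedding = summed / counts          # <num_chunks, hidden_size>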
+     def transform_list_of_texts(
+         self,
+         texts: list[str],
+         tokenizer: PreTrainedTokenizerBase,
+         chunk_size: int,
+         stride: int,
+         minimal_chunk_length: int,
+         maximal_text_length: Optional[int] = None,
+     ) -> tuple[list[Tensor], list[Tensor]]:
+         """Applies transform_single_text to each text. Returns per-text lists rather
+         than stacked tensors, since texts may yield different numbers of chunks."""
+         model_inputs = [
+             self.transform_single_text(text, tokenizer, chunk_size, stride, minimal_chunk_length, maximal_text_length)
+             for text in texts
+         ]
+         input_ids = [model_input[0] for model_input in model_inputs]
+         attention_mask = [model_input[1] for model_input in model_inputs]
+         return input_ids, attention_mask
+
+     def transform_single_text(
+         self,
+         text: str,
+         tokenizer: PreTrainedTokenizerBase,
+         chunk_size: int,
+         stride: int,
+         minimal_chunk_length: int,
+         maximal_text_length: Optional[int],
+     ) -> tuple[Tensor, Tensor]:
+         """Transforms (the entire) text to model input of BERT model."""
+         if maximal_text_length:
+             tokens = self.tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
+         else:
+             tokens = self.tokenize_whole_text(text, tokenizer)
+         input_id_chunks, mask_chunks = self.split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
+         self.add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks)
+         self.add_padding_tokens(input_id_chunks, mask_chunks)
+         input_ids, attention_mask = self.stack_tokens_from_all_chunks(input_id_chunks, mask_chunks)
+         return input_ids, attention_mask
+
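For intuition: a 1200-token text with chunk_size=510 and stride=510 (the values vectorize passes) is split into three chunks of 510, 510, and 180 tokens; after [CLS]/[SEP] insertion and padding, each chunk is exactly 512 tokens, so transform_single_text returns two tensors of shape <3, 512>.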
+     def tokenize_whole_text(self, text: str, tokenizer: PreTrainedTokenizerBase) -> BatchEncoding:
+         """Tokenizes the entire text without truncation and without special tokens."""
+         tokens = tokenizer(text, add_special_tokens=False, truncation=False, return_tensors="pt")
+         return tokens
+
+     def tokenize_text_with_truncation(
+         self, text: str, tokenizer: PreTrainedTokenizerBase, maximal_text_length: int
+     ) -> BatchEncoding:
+         """Tokenizes the text with truncation to maximal_text_length and without special tokens."""
+         tokens = tokenizer(
+             text, add_special_tokens=False, max_length=maximal_text_length, truncation=True, return_tensors="pt"
+         )
+         return tokens
+
+     def split_tokens_into_smaller_chunks(
+         self,
+         tokens: BatchEncoding,
+         chunk_size: int,
+         stride: int,
+         minimal_chunk_length: int,
+     ) -> tuple[list[Tensor], list[Tensor]]:
+         """Splits tokens into overlapping chunks with given size and stride."""
+         input_id_chunks = self.split_overlapping(tokens["input_ids"][0], chunk_size, stride, minimal_chunk_length)
+         mask_chunks = self.split_overlapping(tokens["attention_mask"][0], chunk_size, stride, minimal_chunk_length)
+         return input_id_chunks, mask_chunks
+
+     def add_special_tokens_at_beginning_and_end(self, input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
+         """
+         Adds a CLS token (token id = 101) at the beginning and a SEP token
+         (token id = 102) at the end of each chunk, each with attention mask 1.
+         """
+         for i in range(len(input_id_chunks)):
+             # 101/102 are the [CLS]/[SEP] ids of the original BERT vocabulary; for
+             # other checkpoints, prefer tokenizer.cls_token_id / tokenizer.sep_token_id.
+             input_id_chunks[i] = torch.cat([Tensor([101]), input_id_chunks[i], Tensor([102])])
+             # attention masks for the two added special tokens
+             mask_chunks[i] = torch.cat([Tensor([1]), mask_chunks[i], Tensor([1])])
+
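Since the hard-coded 101/102 match the original BERT vocabulary but not necessarily this checkpoint's, a safer variant (hypothetical, assuming the tokenizer were passed into this method) would look the ids up, which also keeps the chunks in integer dtype:

    cls = torch.tensor([tokenizer.cls_token_id])  # hypothetical: tokenizer is not in scope here
    sep = torch.tensor([tokenizer.sep_token_id])
    input_id_chunks[i] = torch.cat([cls, input_id_chunks[i], sep])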
+     def add_padding_tokens(self, input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
+         """Adds padding tokens (token id = 0) at the end so that every chunk has exactly 512 tokens."""
+         for i in range(len(input_id_chunks)):
+             # 512 = 510 (maximal chunk_size) + [CLS] + [SEP]
+             pad_len = 512 - input_id_chunks[i].shape[0]
+             if pad_len > 0:
+                 input_id_chunks[i] = torch.cat([input_id_chunks[i], Tensor([0] * pad_len)])
+                 mask_chunks[i] = torch.cat([mask_chunks[i], Tensor([0] * pad_len)])
+
+     def stack_tokens_from_all_chunks(self, input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> tuple[Tensor, Tensor]:
+         """Stacks the chunks into <num_chunks, 512> tensors compatible with BERT model input."""
+         input_ids = torch.stack(input_id_chunks)
+         attention_mask = torch.stack(mask_chunks)
+
+         # The Tensor([...]) constructors above produce float32; cast back to integer dtypes.
+         return input_ids.long(), attention_mask.int()
+
+     def split_overlapping(self, tensor: Tensor, chunk_size: int, stride: int, minimal_chunk_length: int) -> list[Tensor]:
+         """Helper function for dividing 1-dimensional tensors into overlapping chunks."""
+         self.check_split_parameters_consistency(chunk_size, stride, minimal_chunk_length)
+         result = [tensor[i : i + chunk_size] for i in range(0, len(tensor), stride)]
+         if len(result) > 1:
+             # ignore trailing chunks with fewer than minimal_chunk_length tokens
+             result = [x for x in result if len(x) >= minimal_chunk_length]
+         return result
+
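A quick trace of split_overlapping: for a tensor of length 7 with chunk_size=5, stride=2, and minimal_chunk_length=5, the comprehension yields slices [0:5], [2:7], [4:7], [6:7] with lengths 5, 5, 3, 1, and the filter then keeps only the first two.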
+     def check_split_parameters_consistency(self, chunk_size: int, stride: int, minimal_chunk_length: int) -> None:
+         if chunk_size > 510:
+             raise RuntimeError("Size of each chunk cannot be bigger than 510!")
+         if minimal_chunk_length > chunk_size:
+             raise RuntimeError("Minimal length cannot be bigger than size!")
+         if stride > chunk_size:
+             raise RuntimeError(
+                 "Stride cannot be bigger than size! Chunks must overlap or be near each other!"
+             )
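For reference, a minimal end-to-end sketch of how the two classes fit together (the models directory name is an assumption; any writable path works):

    # Hypothetical driver script, not part of this commit.
    from Preprocess import ModelUtils, Preprocess

    MODEL_ROOT = "models"  # assumed location

    utils = ModelUtils(MODEL_ROOT)
    utils.make_dirs()       # creates models/, models/model/, models/tokenizer/
    utils.download_model()  # caches HooshvareLab/bert-fa-zwnj-base locally

    pre = Preprocess(MODEL_ROOT)
    embedding = pre.vectorize("متن فارسی برای آزمایش")  # any Persian text
    # embedding is a JSON array of floats with num_chunks * hidden_size
    # entries (hidden_size is 768 for a BERT-base checkpoint)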