Commit 24e3585 by Tom Aarsen
1 Parent(s): 4b92f32

Heavily simplify app, rely on gliner on PyPI

GLiNER/README.md DELETED
@@ -1,74 +0,0 @@
- # Model Card for GLiNER-base
-
- GLiNER is a Named Entity Recognition (NER) model capable of identifying any entity type using a bidirectional transformer encoder (BERT-like). It provides a practical alternative to traditional NER models, which are limited to predefined entities, and Large Language Models (LLMs) that, despite their flexibility, are costly and large for resource-constrained scenarios.
-
- ## Links
-
- * Paper: https://arxiv.org/abs/2311.08526
- * Repository: https://github.com/urchade/GLiNER
-
- ## Installation
- To use this model, you must download the GLiNER repository and install its dependencies:
- ```
- !git clone https://github.com/urchade/GLiNER.git
- %cd GLiNER
- !pip install -r requirements.txt
- ```
-
- ## Usage
- Once you've downloaded the GLiNER repository, you can import the GLiNER class from the `model` file. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.
-
- ```python
- from model import GLiNER
-
- model = GLiNER.from_pretrained("urchade/gliner_base")
-
- text = """
- Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
- """
-
- labels = ["person", "award", "date", "competitions", "teams"]
-
- entities = model.predict_entities(text, labels)
-
- for entity in entities:
-     print(entity["text"], "=>", entity["label"])
- ```
-
- ```
- Cristiano Ronaldo dos Santos Aveiro => person
- 5 February 1985 => date
- Al Nassr => teams
- Portugal national team => teams
- Ballon d'Or => award
- UEFA Men's Player of the Year Awards => award
- European Golden Shoes => award
- UEFA Champions Leagues => competitions
- UEFA European Championship => competitions
- UEFA Nations League => competitions
- Champions League => competitions
- European Championship => competitions
- ```
-
- ## Named Entity Recognition benchmark result
-
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/6317233cc92fd6fee317e030/Y5f7tK8lonGqeeO6L6bVI.png)
-
- ## Model Authors
- The model authors are:
- * [Urchade Zaratiana](https://huggingface.co/urchade)
- * Nadi Tomeh
- * Pierre Holat
- * Thierry Charnois
-
- ## Citation
- ```bibtex
- @misc{zaratiana2023gliner,
-       title={GLiNER: Generalist Model for Named Entity Recognition using Bidirectional Transformer},
-       author={Urchade Zaratiana and Nadi Tomeh and Pierre Holat and Thierry Charnois},
-       year={2023},
-       eprint={2311.08526},
-       archivePrefix={arXiv},
-       primaryClass={cs.CL}
- }
- ```
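
As a rough illustration of how the `predict_entities` output described in the deleted README might be consumed: each entity is a dict with at least `text` and `label` keys. The sketch below groups such dicts by label. The `entities` list is a hand-copied subset of the README's sample output rather than a live model call, and `group_by_label` is a hypothetical helper that is not part of GLiNER.

```python
from collections import defaultdict

# Hand-copied subset of the README's sample output (not produced by running the model here).
entities = [
    {"text": "Cristiano Ronaldo dos Santos Aveiro", "label": "person"},
    {"text": "5 February 1985", "label": "date"},
    {"text": "Al Nassr", "label": "teams"},
    {"text": "Portugal national team", "label": "teams"},
    {"text": "Ballon d'Or", "label": "award"},
]


def group_by_label(entities):
    """Hypothetical helper: collect entity texts under their predicted label."""
    grouped = defaultdict(list)
    for entity in entities:
        grouped[entity["label"]].append(entity["text"])
    return dict(grouped)


grouped = group_by_label(entities)
```

Grouping keeps the order in which the entities were listed, which for `predict_entities` appears to follow their position in the text.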
GLiNER/model.py DELETED
@@ -1,412 +0,0 @@
- import argparse
- import json
- from pathlib import Path
- import re
- from typing import Dict, Optional, Union
- import torch
- import torch.nn.functional as F
- from modules.layers import LstmSeq2SeqEncoder
- from modules.base import InstructBase
- from modules.evaluator import Evaluator, greedy_search
- from modules.span_rep import SpanRepLayer
- from modules.token_rep import TokenRepLayer
- from torch import nn
- from torch.nn.utils.rnn import pad_sequence
- from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
- from huggingface_hub.utils import HfHubHTTPError
-
-
-
- class GLiNER(InstructBase, PyTorchModelHubMixin):
-     def __init__(self, config):
-         super().__init__(config)
-
-         self.config = config
-
-         # [ENT] token
-         self.entity_token = "<<ENT>>"
-         self.sep_token = "<<SEP>>"
-
-         # usually a pretrained bidirectional transformer, returns first subtoken representation
-         self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune,
-                                              subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size,
-                                              add_tokens=[self.entity_token, self.sep_token])
-
-         # hierarchical representation of tokens
-         self.rnn = LstmSeq2SeqEncoder(
-             input_size=config.hidden_size,
-             hidden_size=config.hidden_size // 2,
-             num_layers=1,
-             bidirectional=True,
-         )
-
-         # span representation
-         self.span_rep_layer = SpanRepLayer(
-             span_mode=config.span_mode,
-             hidden_size=config.hidden_size,
-             max_width=config.max_width,
-             dropout=config.dropout,
-         )
-
-         # prompt representation (FFN)
-         self.prompt_rep_layer = nn.Sequential(
-             nn.Linear(config.hidden_size, config.hidden_size * 4),
-             nn.Dropout(config.dropout),
-             nn.ReLU(),
-             nn.Linear(config.hidden_size * 4, config.hidden_size)
-         )
-
-     def compute_score_train(self, x):
-         span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)
-
-         new_length = x['seq_length'].clone()
-         new_tokens = []
-         all_len_prompt = []
-         num_classes_all = []
-
-         # add prompt to the tokens
-         for i in range(len(x['tokens'])):
-             all_types_i = list(x['classes_to_id'][i].keys())
-             # multiple entity types in all_types. Prompt is appended at the start of tokens
-             entity_prompt = []
-             num_classes_all.append(len(all_types_i))
-             # add enity types to prompt
-             for entity_type in all_types_i:
-                 entity_prompt.append(self.entity_token)  # [ENT] token
-                 entity_prompt.append(entity_type)  # entity type
-             entity_prompt.append(self.sep_token)  # [SEP] token
-
-             # prompt format:
-             # [ENT] entity_type [ENT] entity_type ... [ENT] entity_type [SEP]
-
-             # add prompt to the tokens
-             tokens_p = entity_prompt + x['tokens'][i]
-
-             # input format:
-             # [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP] token_1 token_2 ... token_n
-
-             # update length of the sequence (add prompt length to the original length)
-             new_length[i] = new_length[i] + len(entity_prompt)
-             # update tokens
-             new_tokens.append(tokens_p)
-             # store prompt length
-             all_len_prompt.append(len(entity_prompt))
-
-         # create a mask using num_classes_all (0, if it exceeds the number of classes, 1 otherwise)
-         max_num_classes = max(num_classes_all)
-         entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(
-             x['span_mask'].device)
-         entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(
-             x['span_mask'].device)  # [batch_size, max_num_classes]
-
-         # compute all token representations
-         bert_output = self.token_rep_layer(new_tokens, new_length)
-         word_rep_w_prompt = bert_output["embeddings"]  # embeddings for all tokens (with prompt)
-         mask_w_prompt = bert_output["mask"]  # mask for all tokens (with prompt)
-
-         # get word representation (after [SEP]), mask (after [SEP]) and entity type representation (before [SEP])
-         word_rep = []  # word representation (after [SEP])
-         mask = []  # mask (after [SEP])
-         entity_type_rep = []  # entity type representation (before [SEP])
-         for i in range(len(x['tokens'])):
-             prompt_entity_length = all_len_prompt[i]  # length of prompt for this example
-             # get word representation (after [SEP])
-             word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
-             # get mask (after [SEP])
-             mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
-
-             # get entity type representation (before [SEP])
-             entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1]  # remove [SEP]
-             entity_rep = entity_rep[0::2]  # it means that we take every second element starting from the second one
-             entity_type_rep.append(entity_rep)
-
-         # padding for word_rep, mask and entity_type_rep
-         word_rep = pad_sequence(word_rep, batch_first=True)  # [batch_size, seq_len, hidden_size]
-         mask = pad_sequence(mask, batch_first=True)  # [batch_size, seq_len]
-         entity_type_rep = pad_sequence(entity_type_rep, batch_first=True)  # [batch_size, len_types, hidden_size]
-
-         # compute span representation
-         word_rep = self.rnn(word_rep, mask)
-         span_rep = self.span_rep_layer(word_rep, span_idx)
-
-         # compute final entity type representation (FFN)
-         entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)
-         num_classes = entity_type_rep.shape[1]  # number of entity types
-
-         # similarity score
-         scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
-
-         return scores, num_classes, entity_type_mask
-
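
The prompt layout that `compute_score_train` builds can be sketched without any model. The snippet below is a minimal illustration, assuming only the `<<ENT>>` and `<<SEP>>` marker strings from the deleted code; `build_prompt` and `ent_marker_positions` are hypothetical names. It shows that, once the trailing separator is dropped, the `[0::2]` slice selects the marker positions, which start at index 0.

```python
ENTITY_TOKEN = "<<ENT>>"
SEP_TOKEN = "<<SEP>>"


def build_prompt(entity_types):
    """Return [ENT] type_1 [ENT] type_2 ... [SEP] as a flat token list."""
    prompt = []
    for entity_type in entity_types:
        prompt.append(ENTITY_TOKEN)
        prompt.append(entity_type)
    prompt.append(SEP_TOKEN)
    return prompt


def ent_marker_positions(prompt):
    """Indices of the [ENT] markers: drop the trailing [SEP], then take every second slot from 0."""
    return list(range(len(prompt) - 1))[0::2]


prompt = build_prompt(["person", "date"])
tokens_with_prompt = prompt + ["Ronaldo", "was", "born", "in", "1985"]
positions = ent_marker_positions(prompt)
```

With this layout the first word of the text sits at index `len(prompt)`, which is the offset the deleted code uses when it slices the word representations back out.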
-     def forward(self, x):
-         # compute span representation
-         scores, num_classes, entity_type_mask = self.compute_score_train(x)
-         batch_size = scores.shape[0]
-
-         # loss for filtering classifier
-         logits_label = scores.view(-1, num_classes)
-         labels = x["span_label"].view(-1)  # (batch_size * num_spans)
-         mask_label = labels != -1  # (batch_size * num_spans)
-         labels.masked_fill_(~mask_label, 0)  # Set the labels of padding tokens to 0
-
-         # one-hot encoding
-         labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device)
-         labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)  # Set the corresponding index to 1
-         labels_one_hot = labels_one_hot[:, 1:]  # Remove the first column
-         # Shape of labels_one_hot: (batch_size * num_spans, num_classes)
-
-         # compute loss (without reduction)
-         all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
-                                                         reduction='none')
-         # mask loss using entity_type_mask (B, C)
-         masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
-         all_losses = masked_loss.view(-1, num_classes)
-         # expand mask_label to all_losses
-         mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)
-         # put lower loss for in label_one_hot (2 for positive, 1 for negative)
-         weight_c = labels_one_hot + 1
-         # apply mask
-         all_losses = all_losses * mask_label.float() * weight_c
-         return all_losses.sum()
-
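
The loss in `forward` can be worked through for a single span in plain Python. This is a sketch under simplifying assumptions: it ignores the padding mask and the per-example entity-type mask, and the function names are hypothetical. It keeps the two ideas visible in the deleted code: label id 0 means "no entity" and its one-hot column is dropped, and positive targets are weighted 2 while negatives are weighted 1.

```python
import math


def one_hot_without_null(label, num_classes):
    """Label ids: 0 = no entity, 1..num_classes = entity types. Column 0 is dropped."""
    row = [0.0] * (num_classes + 1)
    row[label] = 1.0
    return row[1:]


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def span_loss(logits, label):
    """Weighted BCE for one span: positives weigh 2, negatives weigh 1."""
    targets = one_hot_without_null(label, len(logits))
    return sum(bce_with_logits(x, y) * (y + 1.0) for x, y in zip(logits, targets))


loss_positive = span_loss([0.0, 0.0, 0.0], label=2)  # one positive class, two negatives
loss_null = span_loss([0.0, 0.0, 0.0], label=0)      # no entity: three negatives
```

At a logit of 0 every term equals log 2, so the positive span costs 4·log 2 (weights 1 + 2 + 1) and the null span costs 3·log 2.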
-     def compute_score_eval(self, x, device):
-         # check if classes_to_id is dict
-         assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"
-
-         span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)
-
-         all_types = list(x['classes_to_id'].keys())
-         # multiple entity types in all_types. Prompt is appended at the start of tokens
-         entity_prompt = []
-
-         # add enity types to prompt
-         for entity_type in all_types:
-             entity_prompt.append(self.entity_token)
-             entity_prompt.append(entity_type)
-
-         entity_prompt.append(self.sep_token)
-
-         prompt_entity_length = len(entity_prompt)
-
-         # add prompt
-         tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
-         seq_length_p = x['seq_length'] + prompt_entity_length
-
-         out = self.token_rep_layer(tokens_p, seq_length_p)
-
-         word_rep_w_prompt = out["embeddings"]
-         mask_w_prompt = out["mask"]
-
-         # remove prompt
-         word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
-         mask = mask_w_prompt[:, prompt_entity_length:]
-
-         # get_entity_type_rep
-         entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
-         # extract [ENT] tokens (which are at even positions in entity_type_rep)
-         entity_type_rep = entity_type_rep[:, 0::2, :]
-
-         entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)
-
-         word_rep = self.rnn(word_rep, mask)
-
-         span_rep = self.span_rep_layer(word_rep, span_idx)
-
-         local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
-
-         return local_scores
-
-     @torch.no_grad()
-     def predict(self, x, flat_ner=False, threshold=0.5):
-         self.eval()
-         local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
-         spans = []
-         for i, _ in enumerate(x["tokens"]):
-             local_i = local_scores[i]
-             wh_i = [i.tolist() for i in torch.where(torch.sigmoid(local_i) > threshold)]
-             span_i = []
-             for s, k, c in zip(*wh_i):
-                 if s + k < len(x["tokens"][i]):
-                     span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
-             span_i = greedy_search(span_i, flat_ner)
-             spans.append(span_i)
-         return spans
-
-     def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
-         tokens = []
-         start_token_idx_to_text_idx = []
-         end_token_idx_to_text_idx = []
-         for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
-             tokens.append(match.group())
-             start_token_idx_to_text_idx.append(match.start())
-             end_token_idx_to_text_idx.append(match.end())
-
-         input_x = {"tokenized_text": tokens, "ner": None}
-         x = self.collate_fn([input_x], labels)
-         output = self.predict(x, flat_ner=flat_ner, threshold=threshold)
-
-         entities = []
-         for start_token_idx, end_token_idx, ent_type in output[0]:
-             start_text_idx = start_token_idx_to_text_idx[start_token_idx]
-             end_text_idx = end_token_idx_to_text_idx[end_token_idx]
-             entities.append({
-                 "start": start_token_idx_to_text_idx[start_token_idx],
-                 "end": end_token_idx_to_text_idx[end_token_idx],
-                 "text": text[start_text_idx:end_text_idx],
-                 "label": ent_type,
-             })
-         return entities
-
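
The tokenisation step at the top of `predict_entities` is self-contained and can be tried on its own. The sketch below reuses the same regular expression and records character offsets, so that an inclusive token span can be mapped back to a substring of the input. The helper names and the sample sentence are illustrative assumptions.

```python
import re

TOKEN_PATTERN = re.compile(r"\w+(?:[-_]\w+)*|\S")


def tokenize_with_offsets(text):
    """Split text into word/punctuation tokens and record each token's character span."""
    tokens, starts, ends = [], [], []
    for match in TOKEN_PATTERN.finditer(text):
        tokens.append(match.group())
        starts.append(match.start())
        ends.append(match.end())
    return tokens, starts, ends


def token_span_to_text(text, starts, ends, start_token, end_token):
    """Map an inclusive token span back to the original substring."""
    return text[starts[start_token]:ends[end_token]]


text = "Al-Nassr signed Ronaldo in 2023."
tokens, starts, ends = tokenize_with_offsets(text)
span_text = token_span_to_text(text, starts, ends, 2, 4)
```

Hyphenated or underscored words stay in one token, while any other non-space character becomes a token of its own.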
-     def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
-         self.eval()
-         data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
-         device = next(self.parameters()).device
-         all_preds = []
-         all_trues = []
-         for x in data_loader:
-             for k, v in x.items():
-                 if isinstance(v, torch.Tensor):
-                     x[k] = v.to(device)
-             batch_predictions = self.predict(x, flat_ner, threshold)
-             all_preds.extend(batch_predictions)
-             all_trues.extend(x["entities"])
-         evaluator = Evaluator(all_trues, all_preds)
-         out, f1 = evaluator.evaluate()
-         return out, f1
-
-     @classmethod
-     def _from_pretrained(
-         cls,
-         *,
-         model_id: str,
-         revision: Optional[str],
-         cache_dir: Optional[Union[str, Path]],
-         force_download: bool,
-         proxies: Optional[Dict],
-         resume_download: bool,
-         local_files_only: bool,
-         token: Union[str, bool, None],
-         map_location: str = "cpu",
-         strict: bool = False,
-         **model_kwargs,
-     ):
-         # 1. Backwards compatibility: Use "gliner_base.pt" and "gliner_multi.pt" with all data
-         filenames = ["gliner_base.pt", "gliner_multi.pt"]
-         for filename in filenames:
-             model_file = Path(model_id) / filename
-             if not model_file.exists():
-                 try:
-                     model_file = hf_hub_download(
-                         repo_id=model_id,
-                         filename=filename,
-                         revision=revision,
-                         cache_dir=cache_dir,
-                         force_download=force_download,
-                         proxies=proxies,
-                         resume_download=resume_download,
-                         token=token,
-                         local_files_only=local_files_only,
-                     )
-                 except HfHubHTTPError:
-                     continue
-             dict_load = torch.load(model_file, map_location=torch.device(map_location))
-             config = dict_load["config"]
-             state_dict = dict_load["model_weights"]
-             config.model_name = "microsoft/deberta-v3-base" if filename == "gliner_base.pt" else "microsoft/mdeberta-v3-base"
-             model = cls(config)
-             model.load_state_dict(state_dict, strict=strict, assign=True)
-             # Required to update flair's internals as well:
-             model.to(map_location)
-             return model
-
-         # 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json"
-         from train import load_config_as_namespace
-
-         model_file = Path(model_id) / "pytorch_model.bin"
-         if not model_file.exists():
-             model_file = hf_hub_download(
-                 repo_id=model_id,
-                 filename="pytorch_model.bin",
-                 revision=revision,
-                 cache_dir=cache_dir,
-                 force_download=force_download,
-                 proxies=proxies,
-                 resume_download=resume_download,
-                 token=token,
-                 local_files_only=local_files_only,
-             )
-         config_file = Path(model_id) / "gliner_config.json"
-         if not config_file.exists():
-             config_file = hf_hub_download(
-                 repo_id=model_id,
-                 filename="gliner_config.json",
-                 revision=revision,
-                 cache_dir=cache_dir,
-                 force_download=force_download,
-                 proxies=proxies,
-                 resume_download=resume_download,
-                 token=token,
-                 local_files_only=local_files_only,
-             )
-         config = load_config_as_namespace(config_file)
-         model = cls(config)
-         state_dict = torch.load(model_file, map_location=torch.device(map_location))
-         model.load_state_dict(state_dict, strict=strict, assign=True)
-         model.to(map_location)
-         return model
-
-     def save_pretrained(
-         self,
-         save_directory: Union[str, Path],
-         *,
-         config: Optional[Union[dict, "DataclassInstance"]] = None,
-         repo_id: Optional[str] = None,
-         push_to_hub: bool = False,
-         **push_to_hub_kwargs,
-     ) -> Optional[str]:
-         """
-         Save weights in local directory.
-
-         Args:
-             save_directory (`str` or `Path`):
-                 Path to directory in which the model weights and configuration will be saved.
-             config (`dict` or `DataclassInstance`, *optional*):
-                 Model configuration specified as a key/value dictionary or a dataclass instance.
-             push_to_hub (`bool`, *optional*, defaults to `False`):
-                 Whether or not to push your model to the Huggingface Hub after saving it.
-             repo_id (`str`, *optional*):
-                 ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
-                 not provided.
-             kwargs:
-                 Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
-         """
-         save_directory = Path(save_directory)
-         save_directory.mkdir(parents=True, exist_ok=True)
-
-         # save model weights/files
-         torch.save(self.state_dict(), save_directory / "pytorch_model.bin")
-
-         # save config (if provided)
-         if config is None:
-             config = self.config
-         if config is not None:
-             if isinstance(config, argparse.Namespace):
-                 config = vars(config)
-             (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))
-
-         # push to the Hub if required
-         if push_to_hub:
-             kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating input
-             if config is not None:  # kwarg for `push_to_hub`
-                 kwargs["config"] = config
-             if repo_id is None:
-                 repo_id = save_directory.name  # Defaults to `save_directory` name
-             return self.push_to_hub(repo_id=repo_id, **kwargs)
-         return None
-
-     def to(self, device):
-         super().to(device)
-         import flair
-
-         flair.device = device
-         return self
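
The config handling in `save_pretrained` (a `Namespace` turned into a dict and written as `gliner_config.json`) can be exercised in isolation. The following is a sketch only: `load_config` is a hypothetical stand-in for the `load_config_as_namespace` helper imported from `train`, whose implementation is not shown in this diff, and the `max_width` and `dropout` values are made-up placeholders.

```python
import argparse
import json
import tempfile
from pathlib import Path


def save_config(config, directory):
    """Write a Namespace (or dict) as gliner_config.json, mirroring the save path above."""
    if isinstance(config, argparse.Namespace):
        config = vars(config)
    path = Path(directory) / "gliner_config.json"
    path.write_text(json.dumps(config, indent=2))
    return path


def load_config(path):
    """Hypothetical loader: read the JSON back into a Namespace."""
    return argparse.Namespace(**json.loads(Path(path).read_text()))


with tempfile.TemporaryDirectory() as tmp:
    # Placeholder values for illustration; only model_name appears in the diff.
    original = argparse.Namespace(model_name="microsoft/deberta-v3-base", max_width=12, dropout=0.4)
    restored = load_config(save_config(original, tmp))
```

The round trip only works for JSON-serialisable values, which is presumably why the deleted code converts the `Namespace` with `vars` before dumping it.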
GLiNER/modules/base.py DELETED
@@ -1,150 +0,0 @@
- from collections import defaultdict
- from typing import List, Tuple, Dict
-
- import torch
- from torch import nn
- from torch.nn.utils.rnn import pad_sequence
- from torch.utils.data import DataLoader
- import random
-
-
- class InstructBase(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.max_width = config.max_width
-         self.base_config = config
-
-     def get_dict(self, spans, classes_to_id):
-         dict_tag = defaultdict(int)
-         for span in spans:
-             if span[2] in classes_to_id:
-                 dict_tag[(span[0], span[1])] = classes_to_id[span[2]]
-         return dict_tag
-
-     def preprocess_spans(self, tokens, ner, classes_to_id):
-
-         max_len = self.base_config.max_len
-
-         if len(tokens) > max_len:
-             length = max_len
-             tokens = tokens[:max_len]
-         else:
-             length = len(tokens)
-
-         spans_idx = []
-         for i in range(length):
-             spans_idx.extend([(i, i + j) for j in range(self.max_width)])
-
-         dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)
-
-         # 0 for null labels
-         span_label = torch.LongTensor([dict_lab[i] for i in spans_idx])
-         spans_idx = torch.LongTensor(spans_idx)
-
-         # mask for valid spans
-         valid_span_mask = spans_idx[:, 1] > length - 1
-
-         # mask invalid positions
-         span_label = span_label.masked_fill(valid_span_mask, -1)
-
-         return {
-             'tokens': tokens,
-             'span_idx': spans_idx,
-             'span_label': span_label,
-             'seq_length': length,
-             'entities': ner,
-         }
-
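
The span enumeration in `preprocess_spans` can be reproduced with plain lists. The sketch below uses hypothetical helper names and a toy length. It also makes explicit that the comparison `end > length - 1` is true for spans that run past the last token, so the tensor the deleted code calls `valid_span_mask` actually flags the out-of-bounds spans that are then filled with -1.

```python
def enumerate_spans(length, max_width):
    """All (start, end) candidates with inclusive end, up to max_width tokens wide."""
    return [(i, i + j) for i in range(length) for j in range(max_width)]


def out_of_bounds(spans, length):
    """True where the span's end index falls past the last token."""
    return [end > length - 1 for _, end in spans]


spans = enumerate_spans(length=3, max_width=2)
invalid = out_of_bounds(spans, length=3)
```

Every token contributes exactly `max_width` candidates, so the span tensor has a fixed `length * max_width` rows and only the labels need masking.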
-     def collate_fn(self, batch_list, entity_types=None):
-         # batch_list: list of dict containing tokens, ner
-         if entity_types is None:
-             negs = self.get_negatives(batch_list, 100)
-             class_to_ids = []
-             id_to_classes = []
-             for b in batch_list:
-                 # negs = b["negative"]
-                 random.shuffle(negs)
-
-                 # negs = negs[:sampled_neg]
-                 max_neg_type_ratio = int(self.base_config.max_neg_type_ratio)
-
-                 if max_neg_type_ratio == 0:
-                     # no negatives
-                     neg_type_ratio = 0
-                 else:
-                     neg_type_ratio = random.randint(0, max_neg_type_ratio)
-
-                 if neg_type_ratio == 0:
-                     # no negatives
-                     negs_i = []
-                 else:
-                     negs_i = negs[:len(b['ner']) * neg_type_ratio]
-
-                 # this is the list of all possible entity types (positive and negative)
-                 types = list(set([el[-1] for el in b['ner']] + negs_i))
-
-                 # shuffle (every epoch)
-                 random.shuffle(types)
-
-                 if len(types) != 0:
-                     # prob of higher number shoul
-                     # random drop
-                     if self.base_config.random_drop:
-                         num_ents = random.randint(1, len(types))
-                         types = types[:num_ents]
-
-                 # maximum number of entities types
-                 types = types[:int(self.base_config.max_types)]
-
-                 # supervised training
-                 if "label" in b:
-                     types = sorted(b["label"])
-
-                 class_to_id = {k: v for v, k in enumerate(types, start=1)}
-                 id_to_class = {k: v for v, k in class_to_id.items()}
-                 class_to_ids.append(class_to_id)
-                 id_to_classes.append(id_to_class)
-
-             batch = [
-                 self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids[i]) for i, b in enumerate(batch_list)
-             ]
-
-         else:
-             class_to_ids = {k: v for v, k in enumerate(entity_types, start=1)}
-             id_to_classes = {k: v for v, k in class_to_ids.items()}
-             batch = [
-                 self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids) for b in batch_list
-             ]
-
-         span_idx = pad_sequence(
-             [b['span_idx'] for b in batch], batch_first=True, padding_value=0
-         )
-
-         span_label = pad_sequence(
-             [el['span_label'] for el in batch], batch_first=True, padding_value=-1
-         )
-
-         return {
-             'seq_length': torch.LongTensor([el['seq_length'] for el in batch]),
-             'span_idx': span_idx,
-             'tokens': [el['tokens'] for el in batch],
-             'span_mask': span_label != -1,
-             'span_label': span_label,
-             'entities': [el['entities'] for el in batch],
-             'classes_to_id': class_to_ids,
-             'id_to_classes': id_to_classes,
-         }
-
-     @staticmethod
-     def get_negatives(batch_list, sampled_neg=5):
-         ent_types = []
-         for b in batch_list:
-             types = set([el[-1] for el in b['ner']])
-             ent_types.extend(list(types))
-         ent_types = list(set(ent_types))
-         # sample negatives
-         random.shuffle(ent_types)
-         return ent_types[:sampled_neg]
-
-     def create_dataloader(self, data, entity_types=None, **kwargs):
-         return DataLoader(data, collate_fn=lambda x: self.collate_fn(x, entity_types), **kwargs)
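
The label bookkeeping in `collate_fn` and `get_dict` can be summarised in a few lines. In this sketch (function names and sample spans are assumptions for illustration), class ids start at 1 so that 0 is free to mean "no entity", and a `defaultdict(int)` makes every unlabelled span fall back to that null id.

```python
from collections import defaultdict


def build_label_maps(entity_types):
    """Ids start at 1 so that 0 can mean 'no entity'."""
    class_to_id = {name: idx for idx, name in enumerate(entity_types, start=1)}
    id_to_class = {idx: name for name, idx in class_to_id.items()}
    return class_to_id, id_to_class


def span_labels(gold_spans, class_to_id):
    """Map (start, end) -> class id; unknown spans and unknown types fall back to 0."""
    labels = defaultdict(int)
    for start, end, name in gold_spans:
        if name in class_to_id:
            labels[(start, end)] = class_to_id[name]
    return labels


class_to_id, id_to_class = build_label_maps(["person", "date"])
labels = span_labels([(0, 1, "person"), (4, 4, "date"), (6, 6, "team")], class_to_id)
```

The offset of 1 is also why `predict` looks up `id_to_classes[c + 1]` when converting a score column back into a label name.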
GLiNER/modules/data_proc.py DELETED
@@ -1,73 +0,0 @@
- import json
- from tqdm import tqdm
- # ast.literal_eval
- import ast, re
-
- path = 'train.json'
-
- with open(path, 'r') as f:
-     data = json.load(f)
-
- def tokenize_text(text):
-     return re.findall(r'\w+(?:[-_]\w+)*|\S', text)
-
- def extract_entity_spans(entry):
-     text = ""
-     len_start = len("What describes ")
-     len_end = len(" in the text?")
-     entity_types = []
-     entity_texts = []
-
-     for c in entry['conversations']:
-         if c['from'] == 'human' and c['value'].startswith('Text: '):
-             text = c['value'][len('Text: '):]
-             tokenized_text = tokenize_text(text)
-
-         if c['from'] == 'human' and c['value'].startswith('What describes '):
-
-             c_type = c['value'][len_start:-len_end]
-             c_type = c_type.replace(' ', '_')
-             entity_types.append(c_type)
-
-         elif c['from'] == 'gpt' and c['value'].startswith('['):
-             if c['value'] == '[]':
-                 entity_types = entity_types[:-1]
-                 continue
-
-             texts_ents = ast.literal_eval(c['value'])
-             # replace space to _ in texts_ents
-             entity_texts.extend(texts_ents)
-             num_repeat = len(texts_ents) - 1
-             entity_types.extend([entity_types[-1]] * num_repeat)
-
-     entity_spans = []
-     for j, entity_text in enumerate(entity_texts):
-         entity_tokens = tokenize_text(entity_text)
-         matches = []
-         for i in range(len(tokenized_text) - len(entity_tokens) + 1):
-             if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
-                 matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
-         if matches:
-             entity_spans.extend(matches)
-
-     return entity_spans, tokenized_text
-
- # Usage:
- # Replace 'entry' with the specific entry from your JSON data
- entry = data[17818]  # For example, taking the first entry
- entity_spans, tokenized_text = extract_entity_spans(entry)
- print("Entity Spans:", entity_spans)
- #print("Tokenized Text:", tokenized_text)
-
- # create a dict: {"tokenized_text": tokenized_text, "entity_spans": entity_spans}
-
- all_data = []
-
- for entry in tqdm(data):
-     entity_spans, tokenized_text = extract_entity_spans(entry)
-     all_data.append({"tokenized_text": tokenized_text, "ner": entity_spans})
-
-
- with open('train_instruct.json', 'w') as f:
-     json.dump(all_data, f)
-
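
The span-matching loop in `extract_entity_spans` slides a window over the tokenised text and compares it, case-insensitively, with the tokenised entity string. A minimal standalone sketch follows; `find_entity_spans` and the sample sentence are hypothetical, and the windows are compared as lowercased token lists rather than as space-joined strings, which should be equivalent for this tokeniser.

```python
import re


def tokenize_text(text):
    return re.findall(r"\w+(?:[-_]\w+)*|\S", text)


def find_entity_spans(tokenized_text, entity_text, entity_type):
    """Return every (start, end, type) where the entity's tokens occur, case-insensitively."""
    entity_tokens = [t.lower() for t in tokenize_text(entity_text)]
    width = len(entity_tokens)
    spans = []
    for i in range(len(tokenized_text) - width + 1):
        window = [t.lower() for t in tokenized_text[i:i + width]]
        if window == entity_tokens:
            spans.append((i, i + width - 1, entity_type))
    return spans


tokens = tokenize_text("Ronaldo plays for Al Nassr. Al Nassr is a Saudi club.")
spans = find_entity_spans(tokens, "al nassr", "team")
```

Every occurrence is labelled, not only the first, so a mention repeated in the text yields several spans with inclusive end indices.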
GLiNER/modules/evaluator.py DELETED
@@ -1,152 +0,0 @@
- from collections import defaultdict
-
- import numpy as np
- import torch
- from seqeval.metrics.v1 import _prf_divide
-
-
- def extract_tp_actual_correct(y_true, y_pred):
-     entities_true = defaultdict(set)
-     entities_pred = defaultdict(set)
-
-     for type_name, (start, end), idx in y_true:
-         entities_true[type_name].add((start, end, idx))
-     for type_name, (start, end), idx in y_pred:
-         entities_pred[type_name].add((start, end, idx))
-
-     target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))
-
-     tp_sum = np.array([], dtype=np.int32)
-     pred_sum = np.array([], dtype=np.int32)
-     true_sum = np.array([], dtype=np.int32)
-     for type_name in target_names:
-         entities_true_type = entities_true.get(type_name, set())
-         entities_pred_type = entities_pred.get(type_name, set())
-         tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
-         pred_sum = np.append(pred_sum, len(entities_pred_type))
-         true_sum = np.append(true_sum, len(entities_true_type))
-
-     return pred_sum, tp_sum, true_sum, target_names
-
-
- def flatten_for_eval(y_true, y_pred):
-     all_true = []
-     all_pred = []
-
-     for i, (true, pred) in enumerate(zip(y_true, y_pred)):
-         all_true.extend([t + [i] for t in true])
-         all_pred.extend([p + [i] for p in pred])
-
-     return all_true, all_pred
-
-
- def compute_prf(y_true, y_pred, average='micro'):
-     y_true, y_pred = flatten_for_eval(y_true, y_pred)
-
-     pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred)
-
-     if average == 'micro':
-         tp_sum = np.array([tp_sum.sum()])
-         pred_sum = np.array([pred_sum.sum()])
-         true_sum = np.array([true_sum.sum()])
-
-     precision = _prf_divide(
-         numerator=tp_sum,
-         denominator=pred_sum,
-         metric='precision',
-         modifier='predicted',
-         average=average,
-         warn_for=('precision', 'recall', 'f-score'),
-         zero_division='warn'
-     )
-
-     recall = _prf_divide(
-         numerator=tp_sum,
-         denominator=true_sum,
-         metric='recall',
-         modifier='true',
-         average=average,
-         warn_for=('precision', 'recall', 'f-score'),
-         zero_division='warn'
-     )
-
-     denominator = precision + recall
-     denominator[denominator == 0.] = 1
-     f_score = 2 * (precision * recall) / denominator
-
-     return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}
-
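
The micro-averaged metric computed by `compute_prf` reduces to set arithmetic over exact-match entities. The sketch below is a simplified stand-in: it swaps seqeval's private `_prf_divide` helper for plain division that returns 0.0 on an empty denominator (without warnings), and it assumes entities are already flattened to `(label, start, end, sentence_idx)` tuples. `micro_prf` and the sample data are hypothetical.

```python
def micro_prf(true_entities, pred_entities):
    """Micro-averaged precision/recall/F1 over sets of (label, start, end, sentence_idx)."""
    true_set, pred_set = set(true_entities), set(pred_entities)
    tp = len(true_set & pred_set)
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(true_set) if true_set else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


gold = [("person", 0, 1, 0), ("date", 4, 6, 0), ("team", 2, 3, 1)]
pred = [("person", 0, 1, 0), ("date", 4, 5, 0)]  # the date span has a wrong end index
precision, recall, f1 = micro_prf(gold, pred)
```

Only exact boundary and label matches count as true positives, so the partially overlapping date prediction contributes to neither precision nor recall.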
-
- class Evaluator:
-     def __init__(self, all_true, all_outs):
-         self.all_true = all_true
-         self.all_outs = all_outs
-
-     def get_entities_fr(self, ents):
-         all_ents = []
-         for s, e, lab in ents:
-             all_ents.append([lab, (s, e)])
-         return all_ents
-
-     def transform_data(self):
-         all_true_ent = []
-         all_outs_ent = []
-         for i, j in zip(self.all_true, self.all_outs):
-             e = self.get_entities_fr(i)
-             all_true_ent.append(e)
-             e = self.get_entities_fr(j)
-             all_outs_ent.append(e)
-         return all_true_ent, all_outs_ent
-
-     @torch.no_grad()
-     def evaluate(self):
-         all_true_typed, all_outs_typed = self.transform_data()
-         precision, recall, f1 = compute_prf(all_true_typed, all_outs_typed).values()
-         output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
-         return output_str, f1
-
108
-
109
- def is_nested(idx1, idx2):
110
- # Return True if idx2 is nested inside idx1 or vice versa
111
- return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (idx2[0] <= idx1[0] and idx2[1] >= idx1[1])
112
-
113
-
114
- def has_overlapping(idx1, idx2):
115
- overlapping = True
116
- if idx1[:2] == idx2[:2]:
117
- return overlapping
118
- if (idx1[0] > idx2[1] or idx2[0] > idx1[1]):
119
- overlapping = False
120
- return overlapping
121
-
122
-
123
- def has_overlapping_nested(idx1, idx2):
124
- # Return True if idx1 and idx2 overlap, but neither is nested inside the other
125
- if idx1[:2] == idx2[:2]:
126
- return True
127
- if ((idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2)) and idx1 != idx2:
128
- return False
129
- else:
130
- return True
131
-
132
-
133
- def greedy_search(spans, flat_ner=True): # start, end, class, score
134
-
135
- if flat_ner:
136
- has_ov = has_overlapping
137
- else:
138
- has_ov = has_overlapping_nested
139
-
140
- new_list = []
141
- span_prob = sorted(spans, key=lambda x: -x[-1])
142
- for i in range(len(spans)):
143
- b = span_prob[i]
144
- flag = False
145
- for new in new_list:
146
- if has_ov(b[:-1], new):
147
- flag = True
148
- break
149
- if not flag:
150
- new_list.append(b[:-1])
151
- new_list = sorted(new_list, key=lambda x: x[0])
152
- return new_list
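The flat-NER branch of `greedy_search` can be sketched as follows. This is a simplified re-statement for illustration only: the name `greedy_select` and the toy spans are hypothetical, and the nested-NER variant (`has_overlapping_nested`) is left out.

```python
def has_overlapping(a, b):
    # Spans are (start, end, label) with inclusive ends; two spans clash
    # unless one of them starts after the other one ends.
    return not (a[0] > b[1] or b[0] > a[1])


def greedy_select(spans):
    """Keep the highest-scoring spans first and drop any span that overlaps a kept one."""
    kept = []
    for span in sorted(spans, key=lambda s: -s[-1]):
        candidate = span[:-1]  # strip the score, as the original does
        if not any(has_overlapping(candidate, k) for k in kept):
            kept.append(candidate)
    return sorted(kept, key=lambda s: s[0])


# (start, end, label, score)
spans = [
    (0, 1, "person", 0.90),
    (1, 2, "team", 0.60),
    (4, 5, "location", 0.80),
]
selected = greedy_select(spans)
print(selected)
```

Here the `team` span shares token 1 with the higher-scoring `person` span, so it is discarded, while the non-overlapping `location` span is kept.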
GLiNER/modules/layers.py DELETED
@@ -1,28 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
- from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5
-
6
-
7
- class LstmSeq2SeqEncoder(nn.Module):
8
- def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
9
- super(LstmSeq2SeqEncoder, self).__init__()
10
- self.lstm = nn.LSTM(input_size=input_size,
11
- hidden_size=hidden_size,
12
- num_layers=num_layers,
13
- dropout=dropout,
14
- bidirectional=bidirectional,
15
- batch_first=True)
16
-
17
- def forward(self, x, mask, hidden=None):
18
- # Packing the input sequence
19
- lengths = mask.sum(dim=1).cpu()
20
- packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
21
-
22
- # Passing packed sequence through LSTM
23
- packed_output, hidden = self.lstm(packed_x, hidden)
24
-
25
- # Unpacking the output sequence
26
- output, _ = pad_packed_sequence(packed_output, batch_first=True)
27
-
28
- return output
GLiNER/modules/run_evaluation.py DELETED
@@ -1,188 +0,0 @@
1
- import glob
2
- import json
3
- import os
4
- import os
5
-
6
- import torch
7
- from tqdm import tqdm
8
- import random
9
-
10
-
11
- def open_content(path):
12
- paths = glob.glob(os.path.join(path, "*.json"))
13
- train, dev, test, labels = None, None, None, None
14
- for p in paths:
15
- if "train" in p:
16
- with open(p, "r") as f:
17
- train = json.load(f)
18
- elif "dev" in p:
19
- with open(p, "r") as f:
20
- dev = json.load(f)
21
- elif "test" in p:
22
- with open(p, "r") as f:
23
- test = json.load(f)
24
- elif "labels" in p:
25
- with open(p, "r") as f:
26
- labels = json.load(f)
27
- return train, dev, test, labels
28
-
29
-
30
- def process(data):
31
- words = data['sentence'].split()
32
- entities = [] # List of entities (start, end, type)
33
-
34
- for entity in data['entities']:
35
- start_char, end_char = entity['pos']
36
-
37
- # Initialize variables to keep track of word positions
38
- start_word = None
39
- end_word = None
40
-
41
- # Iterate through words and find the word positions
42
- char_count = 0
43
- for i, word in enumerate(words):
44
- word_length = len(word)
45
- if char_count == start_char:
46
- start_word = i
47
- if char_count + word_length == end_char:
48
- end_word = i
49
- break
50
- char_count += word_length + 1 # Add 1 for the space
51
-
52
- # Append the word positions to the list
53
- entities.append((start_word, end_word, entity['type']))
54
-
55
- # Bundle the tokens and the word-level entity spans into one sample
56
- sample = {
57
- "tokenized_text": words,
58
- "ner": entities
59
- }
60
-
61
- return sample
62
-
63
-
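The character-to-word alignment inside `process` can be sketched in isolation. The helper name `char_span_to_word_span` and the example sentence are hypothetical; like the original, the sketch assumes words are separated by exactly one space and that `end_char` is exclusive.

```python
def char_span_to_word_span(sentence, start_char, end_char):
    """Map a character span (end exclusive) to inclusive word indices."""
    start_word = end_word = None
    char_count = 0
    for i, word in enumerate(sentence.split()):
        if char_count == start_char:
            start_word = i
        if char_count + len(word) == end_char:
            end_word = i
            break
        char_count += len(word) + 1  # +1 for the separating space
    return start_word, end_word


sentence = "Cristiano Ronaldo plays for Al Nassr"
span = char_span_to_word_span(sentence, 28, 36)  # characters of "Al Nassr"
print(span)
```

If an offset does not fall on a word boundary, the corresponding index stays `None`. The deleted `process` function appends the tuple without checking for this, so misaligned annotations would reach the dataset as `None` positions.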
64
- # create dataset
65
- def create_dataset(path):
66
- train, dev, test, labels = open_content(path)
67
- train_dataset = []
68
- dev_dataset = []
69
- test_dataset = []
70
- for data in train:
71
- train_dataset.append(process(data))
72
- for data in dev:
73
- dev_dataset.append(process(data))
74
- for data in test:
75
- test_dataset.append(process(data))
76
- return train_dataset, dev_dataset, test_dataset, labels
77
-
78
-
79
- @torch.no_grad()
80
- def get_for_one_path(path, model):
81
- # load the dataset
82
- _, _, test_dataset, entity_types = create_dataset(path)
83
-
84
- data_name = path.split("/")[-1] # get the name of the dataset
85
-
86
- # check if the dataset is flat_ner
87
- flat_ner = True
88
- if any([i in data_name for i in ["ACE", "GENIA", "Corpus"]]):
89
- flat_ner = False
90
-
91
- # evaluate the model
92
- results, f1 = model.evaluate(test_dataset, flat_ner=flat_ner, threshold=0.5, batch_size=12,
93
- entity_types=entity_types)
94
- return data_name, results, f1
95
-
96
-
97
- def get_for_all_path(model, steps, log_dir, data_paths):
98
- all_paths = glob.glob(f"{data_paths}/*")
99
-
100
- all_paths = sorted(all_paths)
101
-
102
- # keep the model on the device its parameters already live on (effectively a no-op)
103
- device = next(model.parameters()).device
104
- model.to(device)
105
- # set the model to eval mode
106
- model.eval()
107
-
108
- # log the results
109
- save_path = os.path.join(log_dir, "results.txt")
110
-
111
- with open(save_path, "a") as f:
112
- f.write("##############################################\n")
113
- # write step
114
- f.write("step: " + str(steps) + "\n")
115
-
116
- zero_shot_benc = ["mit-movie", "mit-restaurant", "CrossNER_AI", "CrossNER_literature", "CrossNER_music",
117
- "CrossNER_politics", "CrossNER_science"]
118
-
119
- zero_shot_benc_results = {}
120
- all_results = {} # without crossNER
121
-
122
- for p in tqdm(all_paths):
123
- if "sample_" not in p:
124
- data_name, results, f1 = get_for_one_path(p, model)
125
- # write to file
126
- with open(save_path, "a") as f:
127
- f.write(data_name + "\n")
128
- f.write(str(results) + "\n")
129
-
130
- if data_name in zero_shot_benc:
131
- zero_shot_benc_results[data_name] = f1
132
- else:
133
- all_results[data_name] = f1
134
-
135
- avg_all = sum(all_results.values()) / len(all_results)
136
- avg_zs = sum(zero_shot_benc_results.values()) / len(zero_shot_benc_results)
137
-
138
- save_path_table = os.path.join(log_dir, "tables.txt")
139
-
140
- # results for all datasets except crossNER
141
- table_bench_all = ""
142
- for k, v in all_results.items():
143
- table_bench_all += f"{k:20}: {v:.1%}\n"
144
- # (pad 'Average' to width 20 as well, i.e. :20)
145
- table_bench_all += f"{'Average':20}: {avg_all:.1%}"
146
-
147
- # results for zero-shot benchmark
148
- table_bench_zeroshot = ""
149
- for k, v in zero_shot_benc_results.items():
150
- table_bench_zeroshot += f"{k:20}: {v:.1%}\n"
151
- table_bench_zeroshot += f"{'Average':20}: {avg_zs:.1%}"
152
-
153
- # write to file
154
- with open(save_path_table, "a") as f:
155
- f.write("##############################################\n")
156
- f.write("step: " + str(steps) + "\n")
157
- f.write("Table for all datasets except crossNER\n")
158
- f.write(table_bench_all + "\n\n")
159
- f.write("Table for zero-shot benchmark\n")
160
- f.write(table_bench_zeroshot + "\n")
161
- f.write("##############################################\n\n")
162
-
163
-
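The table layout written by `get_for_all_path` relies on two format specifiers: `:20` pads the dataset name to 20 characters and `:.1%` renders the F1 score as a percentage with one decimal. A minimal sketch, with the hypothetical helper `format_table` and made-up scores:

```python
def format_table(results):
    """Render one 'name: score' row per dataset, followed by an average row."""
    lines = [f"{name:20}: {f1:.1%}" for name, f1 in results.items()]
    average = sum(results.values()) / len(results)
    lines.append(f"{'Average':20}: {average:.1%}")
    return "\n".join(lines)


# Made-up scores, for illustration only.
table = format_table({"mit-movie": 0.5, "mit-restaurant": 0.25})
print(table)
```

As in the original, the average divides by the number of datasets, so an empty result group would raise `ZeroDivisionError`; a caller may want to guard against that.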
164
- def sample_train_data(data_paths, sample_size=10000):
165
- all_paths = glob.glob(f"{data_paths}/*")
166
-
167
- all_paths = sorted(all_paths)
168
-
169
- # to exclude the zero-shot benchmark datasets
170
- zero_shot_benc = ["CrossNER_AI", "CrossNER_literature", "CrossNER_music",
171
- "CrossNER_politics", "CrossNER_science", "ACE 2004"]
172
-
173
- new_train = []
174
- # take up to sample_size examples (10k by default) from each dataset
175
- for p in tqdm(all_paths):
176
- if any([i in p for i in zero_shot_benc]):
177
- continue
178
- train, dev, test, labels = create_dataset(p)
179
-
180
- # add label key to the train data
181
- for i in range(len(train)):
182
- train[i]["label"] = labels
183
-
184
- random.shuffle(train)
185
- train = train[:sample_size]
186
- new_train.extend(train)
187
-
188
- return new_train
GLiNER/modules/span_rep.py DELETED
@@ -1,326 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
-
5
-
6
- class SpanQuery(nn.Module):
7
-
8
- def __init__(self, hidden_size, max_width, trainable=True):
9
- super().__init__()
10
-
11
- self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width))
12
-
13
- nn.init.uniform_(self.query_seg, a=-1, b=1)
14
-
15
- if not trainable:
16
- self.query_seg.requires_grad = False
17
-
18
- self.project = nn.Sequential(
19
- nn.Linear(hidden_size, hidden_size),
20
- nn.ReLU()
21
- )
22
-
23
- def forward(self, h, *args):
24
- # h of shape [B, L, D]
25
- # query_seg of shape [D, max_width]
26
-
27
- span_rep = torch.einsum('bld, ds->blsd', h, self.query_seg)
28
-
29
- return self.project(span_rep)
30
-
31
-
32
- class SpanMLP(nn.Module):
33
-
34
- def __init__(self, hidden_size, max_width):
35
- super().__init__()
36
-
37
- self.mlp = nn.Linear(hidden_size, hidden_size * max_width)
38
-
39
- def forward(self, h, *args):
40
- # h of shape [B, L, D]
41
- # self.mlp maps D -> D * max_width; the output is reshaped to [B, L, max_width, D]
42
-
43
- B, L, D = h.size()
44
-
45
- span_rep = self.mlp(h)
46
-
47
- span_rep = span_rep.view(B, L, -1, D)
48
-
49
- return span_rep.relu()
50
-
51
-
52
- class SpanCAT(nn.Module):
53
-
54
- def __init__(self, hidden_size, max_width):
55
- super().__init__()
56
-
57
- self.max_width = max_width
58
-
59
- self.query_seg = nn.Parameter(torch.randn(128, max_width))
60
-
61
- self.project = nn.Sequential(
62
- nn.Linear(hidden_size + 128, hidden_size),
63
- nn.ReLU()
64
- )
65
-
66
- def forward(self, h, *args):
67
- # h of shape [B, L, D]
68
- # query_seg of shape [128, max_width]
69
-
70
- B, L, D = h.size()
71
-
72
- h = h.view(B, L, 1, D).repeat(1, 1, self.max_width, 1)
73
-
74
- q = self.query_seg.view(1, 1, self.max_width, -1).repeat(B, L, 1, 1)
75
-
76
- span_rep = torch.cat([h, q], dim=-1)
77
-
78
- span_rep = self.project(span_rep)
79
-
80
- return span_rep
81
-
82
-
83
- class SpanConvBlock(nn.Module):
84
- def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'):
85
- super().__init__()
86
-
87
- if span_mode == 'conv_conv':
88
- self.conv = nn.Conv1d(hidden_size, hidden_size,
89
- kernel_size=kernel_size)
90
-
91
- # initialize the weights
92
- nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu')
93
-
94
- elif span_mode == 'conv_max':
95
- self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
96
- elif span_mode == 'conv_mean' or span_mode == 'conv_sum':
97
- self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
98
-
99
- self.span_mode = span_mode
100
-
101
- self.pad = kernel_size - 1
102
-
103
- def forward(self, x):
104
-
105
- x = torch.einsum('bld->bdl', x)
106
-
107
- if self.pad > 0:
108
- x = F.pad(x, (0, self.pad), "constant", 0)
109
-
110
- x = self.conv(x)
111
-
112
- if self.span_mode == "conv_sum":
113
- x = x * (self.pad + 1)
114
-
115
- return torch.einsum('bdl->bld', x)
116
-
117
-
118
- class SpanConv(nn.Module):
119
- def __init__(self, hidden_size, max_width, span_mode):
120
- super().__init__()
121
-
122
- kernels = [i + 2 for i in range(max_width - 1)]
123
-
124
- self.convs = nn.ModuleList()
125
-
126
- for kernel in kernels:
127
- self.convs.append(SpanConvBlock(hidden_size, kernel, span_mode))
128
-
129
- self.project = nn.Sequential(
130
- nn.ReLU(),
131
- nn.Linear(hidden_size, hidden_size)
132
- )
133
-
134
- def forward(self, x, *args):
135
-
136
- span_reps = [x]
137
-
138
- for conv in self.convs:
139
- h = conv(x)
140
- span_reps.append(h)
141
-
142
- span_reps = torch.stack(span_reps, dim=-2)
143
-
144
- return self.project(span_reps)
145
-
146
-
147
- class SpanEndpointsBlock(nn.Module):
148
- def __init__(self, kernel_size):
149
- super().__init__()
150
-
151
- self.kernel_size = kernel_size
152
-
153
- def forward(self, x):
154
- B, L, D = x.size()
155
-
156
- span_idx = torch.LongTensor(
157
- [[i, i + self.kernel_size - 1] for i in range(L)]).to(x.device)
158
-
159
- x = F.pad(x, (0, 0, 0, self.kernel_size - 1), "constant", 0)
160
-
161
- # endrep
162
- start_end_rep = torch.index_select(x, dim=1, index=span_idx.view(-1))
163
-
164
- start_end_rep = start_end_rep.view(B, L, 2, D)
165
-
166
- return start_end_rep
167
-
168
-
169
- class ConvShare(nn.Module):
170
- def __init__(self, hidden_size, max_width):
171
- super().__init__()
172
-
173
- self.max_width = max_width
174
-
175
- self.conv_weigth = nn.Parameter(
176
- torch.randn(hidden_size, hidden_size, max_width))
177
-
178
- nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity='relu')
179
-
180
- self.project = nn.Sequential(
181
- nn.ReLU(),
182
- nn.Linear(hidden_size, hidden_size)
183
- )
184
-
185
- def forward(self, x, *args):
186
- span_reps = []
187
-
188
- x = torch.einsum('bld->bdl', x)
189
-
190
- for i in range(self.max_width):
191
- pad = i
192
- x_i = F.pad(x, (0, pad), "constant", 0)
193
- conv_w = self.conv_weigth[:, :, :i + 1]
194
- out_i = F.conv1d(x_i, conv_w)
195
- span_reps.append(out_i.transpose(-1, -2))
196
-
197
- out = torch.stack(span_reps, dim=-2)
198
-
199
- return self.project(out)
200
-
201
-
202
- def extract_elements(sequence, indices):
203
- B, L, D = sequence.shape
204
- K = indices.shape[1]
205
-
206
- # Expand indices to [B, K, D]
207
- expanded_indices = indices.unsqueeze(2).expand(-1, -1, D)
208
-
209
- # Gather the elements
210
- extracted_elements = torch.gather(sequence, 1, expanded_indices)
211
-
212
- return extracted_elements
213
-
214
-
215
- class SpanMarker(nn.Module):
216
-
217
- def __init__(self, hidden_size, max_width, dropout=0.4):
218
- super().__init__()
219
-
220
- self.max_width = max_width
221
-
222
- self.project_start = nn.Sequential(
223
- nn.Linear(hidden_size, hidden_size * 2, bias=True),
224
- nn.ReLU(),
225
- nn.Dropout(dropout),
226
- nn.Linear(hidden_size * 2, hidden_size, bias=True),
227
- )
228
-
229
- self.project_end = nn.Sequential(
230
- nn.Linear(hidden_size, hidden_size * 2, bias=True),
231
- nn.ReLU(),
232
- nn.Dropout(dropout),
233
- nn.Linear(hidden_size * 2, hidden_size, bias=True),
234
- )
235
-
236
- self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)
237
-
238
- def forward(self, h, span_idx):
239
- # h of shape [B, L, D]
240
- # span_idx of shape [B, L * max_width, 2] (start and end index per span)
241
-
242
- B, L, D = h.size()
243
-
244
- # project start and end
245
- start_rep = self.project_start(h)
246
- end_rep = self.project_end(h)
247
-
248
- start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
249
- end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
250
-
251
- # concat start and end
252
- cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
253
-
254
- # project
255
- cat = self.out_project(cat)
256
-
257
- # reshape
258
- return cat.view(B, L, self.max_width, D)
259
-
260
-
261
- class ConvShareV2(nn.Module):
262
- def __init__(self, hidden_size, max_width):
263
- super().__init__()
264
-
265
- self.max_width = max_width
266
-
267
- self.conv_weigth = nn.Parameter(
268
- torch.randn(hidden_size, hidden_size, max_width)
269
- )
270
-
271
- nn.init.xavier_normal_(self.conv_weigth)
272
-
273
- def forward(self, x, *args):
274
- span_reps = []
275
-
276
- x = torch.einsum('bld->bdl', x)
277
-
278
- for i in range(self.max_width):
279
- pad = i
280
- x_i = F.pad(x, (0, pad), "constant", 0)
281
- conv_w = self.conv_weigth[:, :, :i + 1]
282
- out_i = F.conv1d(x_i, conv_w)
283
- span_reps.append(out_i.transpose(-1, -2))
284
-
285
- out = torch.stack(span_reps, dim=-2)
286
-
287
- return out
288
-
289
-
290
- class SpanRepLayer(nn.Module):
291
- """
292
- Various span representation approaches
293
- """
294
-
295
- def __init__(self, hidden_size, max_width, span_mode, **kwargs):
296
- super().__init__()
297
-
298
- if span_mode == 'marker':
299
- self.span_rep_layer = SpanMarker(hidden_size, max_width, **kwargs)
300
- elif span_mode == 'query':
301
- self.span_rep_layer = SpanQuery(
302
- hidden_size, max_width, trainable=True)
303
- elif span_mode == 'mlp':
304
- self.span_rep_layer = SpanMLP(hidden_size, max_width)
305
- elif span_mode == 'cat':
306
- self.span_rep_layer = SpanCAT(hidden_size, max_width)
307
- elif span_mode == 'conv_conv':
308
- self.span_rep_layer = SpanConv(
309
- hidden_size, max_width, span_mode='conv_conv')
310
- elif span_mode == 'conv_max':
311
- self.span_rep_layer = SpanConv(
312
- hidden_size, max_width, span_mode='conv_max')
313
- elif span_mode == 'conv_mean':
314
- self.span_rep_layer = SpanConv(
315
- hidden_size, max_width, span_mode='conv_mean')
316
- elif span_mode == 'conv_sum':
317
- self.span_rep_layer = SpanConv(
318
- hidden_size, max_width, span_mode='conv_sum')
319
- elif span_mode == 'conv_share':
320
- self.span_rep_layer = ConvShare(hidden_size, max_width)
321
- else:
322
- raise ValueError(f'Unknown span mode {span_mode}')
323
-
324
- def forward(self, x, *args):
325
-
326
- return self.span_rep_layer(x, *args)
GLiNER/modules/token_rep.py DELETED
@@ -1,54 +0,0 @@
1
- from typing import List
2
-
3
- import torch
4
- from flair.data import Sentence
5
- from flair.embeddings import TransformerWordEmbeddings
6
- from torch import nn
7
- from torch.nn.utils.rnn import pad_sequence
8
-
9
-
10
- # flair.cache_root = '/gpfswork/rech/pds/upa43yu/.cache'
11
-
12
-
13
- class TokenRepLayer(nn.Module):
14
- def __init__(self, model_name: str = "bert-base-cased", fine_tune: bool = True, subtoken_pooling: str = "first",
15
- hidden_size: int = 768,
16
- add_tokens=["[SEP]", "[ENT]"]
17
- ):
18
- super().__init__()
19
-
20
- self.bert_layer = TransformerWordEmbeddings(
21
- model_name,
22
- fine_tune=fine_tune,
23
- subtoken_pooling=subtoken_pooling,
24
- allow_long_sentences=True
25
- )
26
-
27
- # add tokens to vocabulary
28
- self.bert_layer.tokenizer.add_tokens(add_tokens)
29
-
30
- # resize token embeddings
31
- self.bert_layer.model.resize_token_embeddings(len(self.bert_layer.tokenizer))
32
-
33
- bert_hidden_size = self.bert_layer.embedding_length
34
-
35
- if hidden_size != bert_hidden_size:
36
- self.projection = nn.Linear(bert_hidden_size, hidden_size)
37
-
38
- def forward(self, tokens: List[List[str]], lengths: torch.Tensor):
39
- token_embeddings = self.compute_word_embedding(tokens)
40
-
41
- if hasattr(self, "projection"):
42
- token_embeddings = self.projection(token_embeddings)
43
-
44
- B = len(lengths)
45
- max_length = lengths.max()
46
- mask = (torch.arange(max_length).view(1, -1).repeat(B, 1) < lengths.cpu().unsqueeze(1)).to(
47
- token_embeddings.device).long()
48
- return {"embeddings": token_embeddings, "mask": mask}
49
-
50
- def compute_word_embedding(self, tokens):
51
- sentences = [Sentence(i) for i in tokens]
52
- self.bert_layer.embed(sentences)
53
- token_embeddings = pad_sequence([torch.stack([t.embedding for t in k]) for k in sentences], batch_first=True)
54
- return token_embeddings
GLiNER/requirements.txt DELETED
@@ -1,6 +0,0 @@
1
- torch
2
- transformers
3
- huggingface_hub
4
- flair
5
- seqeval
6
- tqdm
GLiNER/save_load.py DELETED
@@ -1,20 +0,0 @@
1
- import torch
2
- from model import GLiNER
3
-
4
-
5
- def save_model(current_model, path):
6
- config = current_model.config
7
- dict_save = {"model_weights": current_model.state_dict(), "config": config}
8
- torch.save(dict_save, path)
9
-
10
-
11
- def load_model(path, model_name=None, device=None):
12
- dict_load = torch.load(path, map_location=torch.device('cpu'))
13
- config = dict_load["config"]
14
-
15
- if model_name is not None:
16
- config.model_name = model_name
17
-
18
- loaded_model = GLiNER(config)
19
- loaded_model.load_state_dict(dict_load["model_weights"])
20
- return loaded_model.to(device) if device is not None else loaded_model
GLiNER/train.py DELETED
@@ -1,131 +0,0 @@
1
- import argparse
2
- import os
3
-
4
- import torch
5
- import yaml
6
- from tqdm import tqdm
7
- from transformers import get_cosine_schedule_with_warmup
8
-
9
- # from model_nested import NerFilteredSemiCRF
10
- from model import GLiNER
11
- from modules.run_evaluation import get_for_all_path, sample_train_data
12
- from save_load import save_model, load_model
13
- import json
14
-
15
-
16
- # train function
17
- def train(model, optimizer, train_data, num_steps=1000, eval_every=100, log_dir="logs", warmup_ratio=0.1,
18
- train_batch_size=8, device='cuda'):
19
- model.train()
20
-
21
- # initialize data loaders
22
- train_loader = model.create_dataloader(train_data, batch_size=train_batch_size, shuffle=True)
23
-
24
- pbar = tqdm(range(num_steps))
25
-
26
- if warmup_ratio < 1:
27
- num_warmup_steps = int(num_steps * warmup_ratio)
28
- else:
29
- num_warmup_steps = int(warmup_ratio)
30
-
31
- scheduler = get_cosine_schedule_with_warmup(
32
- optimizer,
33
- num_warmup_steps=num_warmup_steps,
34
- num_training_steps=num_steps
35
- )
36
-
37
- iter_train_loader = iter(train_loader)
38
-
39
- for step in pbar:
40
- try:
41
- x = next(iter_train_loader)
42
- except StopIteration:
43
- iter_train_loader = iter(train_loader)
44
- x = next(iter_train_loader)
45
-
46
- for k, v in x.items():
47
- if isinstance(v, torch.Tensor):
48
- x[k] = v.to(device)
49
-
50
- try:
51
- loss = model(x) # Forward pass
52
- except:
53
- continue
54
-
55
- # check if loss is nan
56
- if torch.isnan(loss):
57
- continue
58
-
59
- loss.backward() # Compute gradients
60
- optimizer.step() # Update parameters
61
- scheduler.step() # Update learning rate schedule
62
- optimizer.zero_grad() # Reset gradients
63
-
64
- description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
65
-
66
- if (step + 1) % eval_every == 0:
67
- current_path = os.path.join(log_dir, f'model_{step + 1}')
68
- save_model(model, current_path)
69
- #val_data_dir = "/gpfswork/rech/ohy/upa43yu/NER_datasets" # can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
70
- #get_for_all_path(model, step, log_dir, val_data_dir) # you can remove this comment if you want to evaluate the model
71
-
72
- model.train()
73
-
74
- pbar.set_description(description)
75
-
76
-
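The `warmup_ratio` handling in `train` accepts either a fraction of the total steps or an absolute step count. A small sketch of that rule, using the hypothetical helper name `resolve_warmup_steps`:

```python
def resolve_warmup_steps(num_steps, warmup_ratio):
    """Values below 1 are a fraction of num_steps; values >= 1 are an absolute step count."""
    if warmup_ratio < 1:
        return int(num_steps * warmup_ratio)
    return int(warmup_ratio)


fractional = resolve_warmup_steps(num_steps=1000, warmup_ratio=0.25)
absolute = resolve_warmup_steps(num_steps=1000, warmup_ratio=50)
print(fractional, absolute)
```

One consequence of this convention is that a value of exactly 1 means a single warmup step, not "warm up for the whole run".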
77
- def create_parser():
78
- parser = argparse.ArgumentParser(description="Span-based NER")
79
- parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
80
- parser.add_argument('--log_dir', type=str, default='logs', help='Path to the log directory')
81
- return parser
82
-
83
-
84
- def load_config_as_namespace(config_file):
85
- with open(config_file, 'r') as f:
86
- config_dict = yaml.safe_load(f)
87
- return argparse.Namespace(**config_dict)
88
-
89
-
90
- if __name__ == "__main__":
91
- # parse args
92
- parser = create_parser()
93
- args = parser.parse_args()
94
-
95
- # load config
96
- config = load_config_as_namespace(args.config)
97
-
98
- config.log_dir = args.log_dir
99
-
100
- try:
101
- with open(config.train_data, 'r') as f:
102
- data = json.load(f)
103
- except:
104
- data = sample_train_data(config.train_data, 10000)
105
-
106
- if config.prev_path != "none":
107
- model = load_model(config.prev_path)
108
- model.config = config
109
- else:
110
- model = GLiNER(config)
111
-
112
- if torch.cuda.is_available():
113
- model = model.cuda()
114
-
115
- lr_encoder = float(config.lr_encoder)
116
- lr_others = float(config.lr_others)
117
-
118
- optimizer = torch.optim.AdamW([
119
- # encoder
120
- {'params': model.token_rep_layer.parameters(), 'lr': lr_encoder},
121
- {'params': model.rnn.parameters(), 'lr': lr_others},
122
- # projection layers
123
- {'params': model.span_rep_layer.parameters(), 'lr': lr_others},
124
- {'params': model.prompt_rep_layer.parameters(), 'lr': lr_others},
125
- ])
126
-
127
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
128
-
129
- train(model, optimizer, data, num_steps=config.num_steps, eval_every=config.eval_every,
130
- log_dir=config.log_dir, warmup_ratio=config.warmup_ratio, train_batch_size=config.train_batch_size,
131
- device=device)
app.py CHANGED
@@ -1,8 +1,5 @@
1
  from typing import Dict, Union
2
- import sys
3
-
4
- sys.path.extend(["./GLiNER"])
5
- from GLiNER.model import GLiNER
6
  import gradio as gr
7
 
8
  model = GLiNER.from_pretrained("urchade/gliner_base")
@@ -110,20 +107,18 @@ with gr.Blocks(title="GLiNER-base") as demo:
110
  gr.Markdown(
111
  """
112
  ## Installation
113
- To use this model, you must download the GLiNER repository and install its dependencies:
114
  ```
115
- !git clone https://github.com/urchade/GLiNER.git
116
- %cd GLiNER
117
- !pip install -r requirements.txt
118
  ```
119
 
120
  ## Usage
121
- Once you've downloaded the GLiNER repository, you can import the GLiNER class from the `model` file. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.
122
  """
123
  )
124
  gr.Code(
125
  '''
126
- from model import GLiNER
127
 
128
  model = GLiNER.from_pretrained("urchade/gliner_base")
129
 
 
1
  from typing import Dict, Union
2
+ from gliner import GLiNER
 
 
 
3
  import gradio as gr
4
 
5
  model = GLiNER.from_pretrained("urchade/gliner_base")
 
107
  gr.Markdown(
108
  """
109
  ## Installation
110
+ To use this model, you must install the GLiNER Python library:
111
  ```
112
+ !pip install gliner
 
 
113
  ```
114
 
115
  ## Usage
116
+ Once you've installed the GLiNER library, you can import the `GLiNER` class from the `gliner` package. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.
117
  """
118
  )
119
  gr.Code(
120
  '''
121
+ from gliner import GLiNER
122
 
123
  model = GLiNER.from_pretrained("urchade/gliner_base")
124
 
requirements.txt CHANGED
@@ -1,6 +1 @@
1
- torch
2
- transformers
3
- huggingface_hub
4
- flair
5
- seqeval
6
- tqdm
 
1
+ gliner