import copy from functools import partial from typing import Callable, Iterable, List, Optional, Tuple, Union, Dict, Any import murmurhash from spacy.language import Language from spacy.tokens.doc import SetEntsDefault # type: ignore from spacy.training import Example from spacy.util import filter_spans from prodigy.components.db import connect from prodigy.components.decorators import support_both_streams from prodigy.components.filters import filter_seen_before from prodigy.components.preprocess import ( add_annot_name, add_tokens, add_view_id, make_ner_suggestions, make_raw_doc, resolve_labels, split_sentences, ) from prodigy.components.sorters import prefer_uncertain from prodigy.components.source import GeneratorSource from prodigy.components.stream import Stream, get_stream, load_noop from prodigy.core import Arg, recipe from prodigy.errors import RecipeError from prodigy.models.matcher import PatternMatcher from prodigy.models.ner import EntityRecognizerModel, ensure_sentencizer from prodigy.protocols import ControllerComponentsDict from prodigy.types import ( ExistingFilePath, LabelsType, SourceType, StreamType, TaskType, ) from prodigy.util import ( ANNOTATOR_ID_ATTR, BINARY_ATTR, INPUT_HASH_ATTR, TASK_HASH_ATTR, combine_models, copy_nlp, get_pipe_labels, log, msg, set_hashes, ) def modify_spans(document): # Modify the 'spans' key to be an empty list document['spans'] = [] return document def spans_equal(s1: Dict[str, Any], s2: Dict[str, Any]) -> bool: """Checks if two spans are equal""" return s1["start"] == s2["start"] and s1["end"] == s2["end"] def labels_equal(s1: Dict[str, Any], s2: Dict[str, Any]) -> bool: """Checks if two spans have the same label""" return s1["label"] == s2["label"] def ensure_span_text(eg: TaskType) -> TaskType: """Ensure that all spans have a text attribute""" for span in eg.get("spans", []): if "text" not in span: span["text"] = eg["text"][span["start"] : span["end"]] return eg def validate_answer(answer: TaskType, *, known_answers_map: Dict[int, TaskType]): """Validate the answer against the known answers""" known_answer = known_answers_map.get(answer[INPUT_HASH_ATTR]) if known_answer is None: print(f"Skipping validation for answer {answer[INPUT_HASH_ATTR]}, no known answer found to validate against.") return known_answer = ensure_span_text(known_answer) errors = [] known_spans = known_answer.get("spans", []) answer_spans = answer.get("spans", []) explanation_label = known_answer.get("meta", {}).get("explanation_label") explanation_boundaries = known_answer.get("meta", {}).get( "explanation_boundaries" ) if not explanation_boundaries: explanation_boundaries = ( "No explanation boundaries" ) if len(known_spans) > len(answer_spans): errors.append( "You noted fewer entities than expected for this answer. All mentions must be annotated" ) elif len(known_spans) < len(answer_spans): errors.append( "You noted more entities than expected for this answer." ) if not known_spans: # For cases where no annotations are expected errors.append(explanation_label) for known_span, span in zip(known_spans, answer_spans): if not labels_equal(known_span, span): # label error errors.append(explanation_label) continue if not spans_equal(known_span, span): # boundary error errors.append(explanation_boundaries) continue if len(errors) > 0: error_msg = "\n".join(errors) error_msg += "\n\nExpected annotations:" if known_spans: expected_spans = [ f'[{s["text"]}]: {s["label"]}' for s in known_spans ] if expected_spans: error_msg += "\n" for span_msg in expected_spans: error_msg += span_msg + "\n" else: error_msg += "\n\nNone." raise ValueError(error_msg) @recipe( "ner.qa.manual", # fmt: off dataset=Arg(help="Dataset to save annotations to"), nlp=Arg(help="Loadable spaCy pipeline for tokenization or blank:lang (e.g. blank:en)"), source=Arg(help="Data to annotate (file path or '-' to read from standard input)"), loader=Arg("--loader", "-lo", help="Loader (guessed from file extension if not set)"), label=Arg("--label", "-l", help="Comma-separated label(s) to annotate or text file with one label per line"), patterns=Arg("--patterns", "-pt", help="Path to match patterns file"), exclude=Arg("--exclude", "-e", help="Comma-separated list of dataset IDs whose annotations to exclude"), highlight_chars=Arg("--highlight-chars", "-C", help="Allow highlighting individual characters instead of tokens"), # fmt: on ) def manual( dataset: str, nlp: Language, source: SourceType, loader: Optional[str] = None, label: Optional[LabelsType] = None, patterns: Optional[ExistingFilePath] = None, exclude: List[str] = [], highlight_chars: bool = False, ) -> ControllerComponentsDict: """ Mark spans by token. Requires only a tokenizer and no entity recognizer, and doesn't do any active learning. If patterns are provided, their matches are highlighted in the example, if available. The recipe will present all examples in order, so even examples without matches are shown. If character highlighting is enabled, no "tokens" are saved to the database. """ log("RECIPE: Starting recipe ner.manual", locals()) labels = get_pipe_labels(label, nlp.pipe_labels.get("ner", [])) stream = get_stream( source, loader=loader, rehash=True, dedup=True, input_key="text", is_binary=False, ) if patterns is not None: pattern_matcher = PatternMatcher(nlp, combine_matches=True, all_examples=True) pattern_matcher = pattern_matcher.from_disk(patterns) stream.apply(lambda examples: (eg for _, eg in pattern_matcher(examples))) # Add "tokens" key to the tasks, either with words or characters stream.apply(lambda examples: (modify_spans(eg) for eg in examples)) exclude_names = [ds.name for ds in exclude] if exclude is not None else None known_answers = get_stream( source, loader=loader, rehash=True, dedup=True, input_key="text", is_binary=False, ) known_answers_map = {eg[INPUT_HASH_ATTR]: eg for eg in known_answers} return { "view_id": "ner_manual", "dataset": dataset, "stream": [_ for _ in stream], "exclude": exclude_names, "validate_answer": partial(validate_answer, known_answers_map=known_answers_map), "config": { "lang": nlp.lang, "labels": labels, "exclude_by": "input", "ner_manual_highlight_chars": highlight_chars, }, } @support_both_streams(stream_arg="stream") def preprocess_stream( stream: StreamType, nlp: Language, *, labels: Optional[List[str]], unsegmented: bool, set_annotations: bool = True, ) -> StreamType: if not unsegmented: stream = split_sentences(nlp, stream) # type: ignore stream = add_tokens(nlp, stream) # type: ignore if set_annotations: spacy_model = f"{nlp.meta['lang']}_{nlp.meta['name']}" # Add a 'spans' key to each example, with predicted entities texts = ((eg["text"], eg) for eg in stream) for doc, eg in nlp.pipe(texts, as_tuples=True, batch_size=10): task = copy.deepcopy(eg) spans = [] for ent in doc.ents: if labels and ent.label_ not in labels: continue spans.append(ent) for span in eg.get("spans", []): spans.append(doc.char_span(span["start"], span["end"], span["label"])) spans = filter_spans(spans) span_dicts = [] for ent in spans: span_dicts.append( { "token_start": ent.start, "token_end": ent.end - 1, "start": ent.start_char, "end": ent.end_char, "text": ent.text, "label": ent.label_, "source": spacy_model, "input_hash": eg[INPUT_HASH_ATTR], } ) task["spans"] = span_dicts task[BINARY_ATTR] = False task = set_hashes(task) yield task else: yield from stream def get_ner_labels( nlp: Language, *, label: Optional[List[str]], component: str = "ner" ) -> Tuple[List[str], bool]: model_labels = nlp.pipe_labels.get(component, []) labels = get_pipe_labels(label, model_labels) # Check if we're annotating all labels present in the model or a subset no_missing = len(set(labels).intersection(set(model_labels))) == len(model_labels) return labels, no_missing def get_update(nlp: Language, *, no_missing: bool) -> Callable[[List[TaskType]], None]: def update(answers: List[TaskType]) -> None: log(f"RECIPE: Updating model with {len(answers)} answers") examples = [] for eg in answers: if eg["answer"] == "accept": doc = make_raw_doc(nlp, eg) ref = make_raw_doc(nlp, eg) spans = [ doc.char_span(span["start"], span["end"], label=span["label"]) for span in eg.get("spans", []) ] value = SetEntsDefault.outside if no_missing else SetEntsDefault.missing ref.set_ents(spans, default=value) examples.append(Example(doc, ref)) nlp.update(examples) return update