Gosse Minnema
Re-enable LOME
2890e34
from abc import ABC, abstractmethod
from typing import *
import torch
from allennlp.common import Registrable
from allennlp.modules.span_extractors import SpanExtractor
class SpanFinder(Registrable, ABC, torch.nn.Module):
    """
    Abstract interface for modules that model p(child_span | parent_span [, parent_label]).

    Conditioning on the parent label is optional. Several tasks may share similar span
    semantics while using different label spaces; in that case it can be preferable to
    share parameters across those tasks and leave the label out.
    """

    def __init__(self, no_label: bool = True):
        """
        :param no_label: If True, implementations should not use the input labels as
            features and should feed an all-zero vector instead. The flag is stored on
            ``self._no_label`` for subclasses to consult.
        """
        super(SpanFinder, self).__init__()
        self._no_label = no_label

    @abstractmethod
    def forward(
            self,
            token_vec: torch.Tensor,
            token_mask: torch.Tensor,
            span_vec: torch.Tensor,
            span_mask: Optional[torch.Tensor] = None,
            span_labels: Optional[torch.Tensor] = None,
            parent_indices: Optional[torch.Tensor] = None,
            parent_mask: Optional[torch.Tensor] = None,
            bio_seqs: Optional[torch.Tensor] = None,
            prediction: bool = False,
            **extra,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the training loss together with the span predictions.

        :param token_vec: Token representations. Shape [batch, token, token_dim].
        :param token_mask: True for non-padding tokens.
        :param span_vec: Span representations. Shape [batch, span, token_dim].
        :param span_mask: Optional; callers may omit it. True for non-padding spans.
            Shape [batch, span].
        :param span_labels: Optional; callers may omit it. Label of every span.
            Shape [batch, span].
        :param parent_indices: Optional; callers may omit it. Parent index of every span.
            Shape [batch, span].
        :param parent_mask: Optional. True for spans that are parents. Shape [batch, span].
        :param bio_seqs: Optional. BIO sequences. Shape [batch, parent, token, 3].
        :param prediction: If True, no loss is returned and no metrics are updated.
        :param extra: Any further keyword arguments; the base class does not inspect them.
        :return: A dict holding
            loss: the training loss;
            prediction: shape [batch, span], True for positive predictions.
        """
        raise NotImplementedError

    @abstractmethod
    def inference_forward_handler(
            self,
            token_vec: torch.Tensor,
            token_mask: torch.Tensor,
            span_extractor: SpanExtractor,
            **auxiliaries,
    ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]:
        """
        Do any pre-processing up front, then hand back a callable implementing
        p(child_span | parent_span [, parent_label]) for inference.

        :param token_vec: Token representations. Shape [batch, token, token_dim].
        :param token_mask: True for non-padding tokens.
        :param span_extractor: The same span extractor module the model uses.
        :param auxiliaries: Environment variables. Extra entries may be passed freely,
            because unused ones are ignored.
        :return: A closure that takes, in this order:
            - span_boundary: Shape [batch, span, 2]
            - span_labels: Shape [batch, span]
            - parent_mask: Shape [batch, span]
            - parent_indices: Shape [batch, span]
            - cursor: Shape [batch]
            It returns nothing; all updates must happen in place.
            The span index space differs from training. No gold span list exists at
            inference time, so "span" here refers to the predicted spans.
        """
        raise NotImplementedError

    @abstractmethod
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Report the metrics of this span finder as a name -> value mapping.

        :param reset: Defined by subclasses. Presumably it clears accumulated metric
            state after reporting, following the AllenNLP convention (TODO confirm in
            implementations).
        """
        raise NotImplementedError