File size: 503 Bytes
5953ef9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import os

from rex.utils.logging import logger

from src.task import MrcTaggingTask

# Placeholder inputs used when no texts are given on the command line.
DEFAULT_CASES = ["123123", "123123"]


def parse_cases(argv):
    """Return the input texts to run prediction on.

    Args:
        argv: Command-line arguments after the script name; each one is
            treated as a separate input text.

    Returns:
        A new list holding the texts from ``argv``, or a copy of
        ``DEFAULT_CASES`` when ``argv`` is empty.
    """
    # Copy in both branches so callers can never mutate DEFAULT_CASES.
    return list(argv) if argv else list(DEFAULT_CASES)


def main(argv=None):
    """Load the saved MRC tagging task and log its predictions for some texts.

    Args:
        argv: Input texts; defaults to ``sys.argv[1:]``. When empty, the
            placeholder ``DEFAULT_CASES`` are used, matching the script's
            original hard-coded behaviour.
    """
    import sys

    # An empty CUDA_VISIBLE_DEVICES hides all GPUs, so prediction runs on CPU.
    # NOTE(review): this is set after `src.task` has been imported; it only
    # takes effect if CUDA has not been initialised yet — confirm, or move it
    # above the imports.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    task = MrcTaggingTask.from_taskdir(
        "outputs/bert_mrc_ner",
        load_best_model=True,
        # Presumably these overrides stop the reloaded task from training and
        # turn off debug behaviour — confirm against MrcTaggingTask.
        update_config={
            "skip_train": True,
            "debug_mode": False,
        },
    )

    cases = parse_cases(sys.argv[1:] if argv is None else argv)
    logger.info(f"Cases: {cases}")

    # Whatever `task.predict` returns is logged as-is.
    ents = task.predict(cases)
    logger.info(f"Results: {ents}")


if __name__ == "__main__":
    main()