File size: 503 Bytes
5953ef9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import os
from rex.utils.logging import logger
from src.task import MrcTaggingTask
if __name__ == "__main__":
    # Manual prediction check: restore a saved MRC tagging task, run it on a
    # couple of hard-coded inputs, and log both the inputs and the outputs.

    # CUDA_VISIBLE_DEVICES is blanked before the task is built.
    # NOTE(review): presumably this keeps inference on CPU -- confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    task_dir = "outputs/bert_mrc_ner"
    # Overrides applied on top of the configuration stored in the task dir.
    config_overrides = {
        "skip_train": True,
        "debug_mode": False,
    }
    task = MrcTaggingTask.from_taskdir(
        task_dir,
        load_best_model=True,
        update_config=config_overrides,
    )

    # Placeholder inputs; the same string is deliberately passed twice.
    cases = ["123123", "123123"]
    logger.info(f"Cases: {cases}")

    ents = task.predict(cases)
    logger.info(f"Results: {ents}")
|