File size: 2,907 Bytes
1e32511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import re

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from pydantic import BaseModel

from src.prompts import SplitTextPromptV1, SplitTextPromptV2
from src.utils import GPTModels, get_chat_llm


class CharacterPhrase(BaseModel):
    """A single contiguous span of text attributed to one character (speaker)."""

    # Speaker label as produced by the LLM (e.g. a persona name or "narrator").
    character: str
    # The phrase text attributed to that character.
    text: str


class CharacterAnnotatedText(BaseModel):
    """A text split into per-character phrases.

    The list of distinct character names is derived once, at construction
    time, from ``phrases`` and exposed read-only via ``characters``.
    """

    phrases: list[CharacterPhrase]
    # Private (non-field) attribute holding the derived character list.
    _characters: list[str]

    def __init__(self, **data):
        super().__init__(**data)
        # dict.fromkeys() deduplicates while preserving first-appearance
        # order. The previous list(set(...)) produced a nondeterministic
        # order across runs (string hash randomization), which made
        # to_pretty_text() output unstable for identical input.
        self._characters = list(
            dict.fromkeys(phrase.character for phrase in self.phrases)
        )

    @property
    def characters(self):
        """Distinct character names, in order of first appearance."""
        return self._characters

    def to_pretty_text(self):
        """Render a human-readable summary: characters header, separator, then one line per phrase."""
        lines = []
        lines.append(f"characters: {self.characters}")
        lines.append("-" * 20)
        lines.extend(f"[{phrase.character}] {phrase.text}" for phrase in self.phrases)
        res = "\n".join(lines)
        return res


class SplitTextOutputV1(BaseModel):
    """Structured-output schema for the V1 split-text chain."""

    # Character names the LLM identified in the text.
    characters: list[str]
    # The text partitioned into per-character phrases, in order.
    parts: list[CharacterPhrase]

    def to_character_annotated_text(self):
        """Convert to the common CharacterAnnotatedText representation."""
        return CharacterAnnotatedText(phrases=self.parts)


def create_split_text_chain_v1(llm_model: GPTModels):
    """Build the V1 split-text chain: prompt -> LLM with structured output.

    The LLM is constrained (temperature 0.0) and configured to emit a
    SplitTextOutputV1 instance directly via structured output.
    """
    structured_llm = get_chat_llm(
        llm_model=llm_model, temperature=0.0
    ).with_structured_output(SplitTextOutputV1)

    messages = [
        SystemMessagePromptTemplate.from_template(SplitTextPromptV1.SYSTEM),
        HumanMessagePromptTemplate.from_template(SplitTextPromptV1.USER),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)

    return prompt | structured_llm


class SplitTextOutputV2(BaseModel):
    """Raw LLM response with XML-like speaker tags, parsed into phrases.

    The LLM wraps each phrase in ``<character>...</character>`` tags;
    parsing happens once at construction time.
    """

    # Raw tagged text exactly as returned by the LLM.
    text_raw: str
    # Private (non-field) attribute holding the parsed phrases.
    _phrases: list[CharacterPhrase]

    @staticmethod
    def _parse_phrases_from_xml_tags(text):
        """
        we rely on LLM to format response correctly.
        so we don't check that opening xml tags match closing ones
        """
        # re.DOTALL lets '.' match newlines, so a phrase whose text spans
        # multiple lines inside one tag pair is captured instead of being
        # silently dropped (the previous behavior without the flag).
        # The backreference \1 requires the closing tag name to match the
        # opening one.
        pattern = re.compile(r"(?:<([^<>]+)>)(.*?)(?:</\1>)", flags=re.DOTALL)
        res = pattern.findall(text)
        res = [CharacterPhrase(character=x[0], text=x[1]) for x in res]
        return res

    def __init__(self, **data):
        super().__init__(**data)
        self._phrases = self._parse_phrases_from_xml_tags(self.text_raw)

    @property
    def phrases(self):
        """Phrases parsed from ``text_raw``, in document order."""
        return self._phrases

    def to_character_annotated_text(self):
        """Convert to the common CharacterAnnotatedText representation."""
        return CharacterAnnotatedText(phrases=self.phrases)


def create_split_text_chain_v2(llm_model: GPTModels):
    """Build the V2 split-text chain: prompt -> LLM -> string -> SplitTextOutputV2.

    Unlike V1, the LLM returns free-form text with XML-like speaker tags;
    the final step wraps that raw string in SplitTextOutputV2, which parses
    the tags on construction.
    """
    llm = get_chat_llm(llm_model=llm_model, temperature=0.0)

    messages = [
        SystemMessagePromptTemplate.from_template(SplitTextPromptV2.SYSTEM),
        HumanMessagePromptTemplate.from_template(SplitTextPromptV2.USER),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)

    def _to_output(raw: str):
        # Wrap the raw LLM string; tag parsing happens in the model's __init__.
        return SplitTextOutputV2(text_raw=raw)

    return prompt | llm | StrOutputParser() | _to_output