File size: 3,674 Bytes
1e32511
 
 
 
 
 
 
 
c2fa877
1e32511
 
 
 
 
 
 
 
 
 
 
c2fa877
 
 
 
1e32511
 
c2fa877
 
 
 
 
 
 
 
 
 
 
1e32511
 
c2fa877
1e32511
c2fa877
1e32511
 
c2fa877
 
 
 
 
1e32511
 
 
 
 
 
 
 
 
 
 
c2fa877
1e32511
 
 
 
c2fa877
 
1e32511
 
 
c2fa877
 
 
 
 
 
 
1e32511
 
 
c2fa877
1e32511
c2fa877
 
 
 
1e32511
 
 
c2fa877
1e32511
 
c2fa877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e32511
 
c2fa877
1e32511
 
c2fa877
1e32511
c2fa877
1e32511
 
 
c2fa877
 
1e32511
 
 
c2fa877
1e32511
c2fa877
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import re

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.runnables import RunnablePassthrough
from pydantic import BaseModel

from src.prompts import SplitTextPromptV1, SplitTextPromptV2
from src.utils import GPTModels, get_chat_llm


class CharacterPhrase(BaseModel):
    """A single phrase of text attributed to one character (speaker)."""

    character: str  # name of the character speaking the phrase
    text: str  # the phrase text itself


class SplitTextOutput(BaseModel):
    """Result of splitting raw text into character-attributed phrases.

    `text_annotated` is expected to wrap each phrase in XML-like tags naming
    the speaking character, e.g. ``<narrator>Once upon a time</narrator>``.
    Phrases and the character roster are derived from it at construction time.
    """

    text_raw: str  # original input text, unmodified
    text_annotated: str  # LLM response containing <character>...</character> tags
    # underscore-prefixed annotated attributes are treated by pydantic as
    # private attrs; they are populated in __init__, never by callers
    _phrases: list[CharacterPhrase]
    _characters: list[str]

    @staticmethod
    def _parse_phrases_from_xml_tags(text: str) -> list[CharacterPhrase]:
        """
        we rely on LLM to format response correctly.
        so we don't check that opening xml tags match closing ones
        """
        # the backreference \1 requires the closing tag to repeat the opening
        # tag's name; non-greedy body keeps adjacent phrases separate
        pattern = re.compile(r"(?:<([^<>]+)>)(.*?)(?:</\1>)")
        return [
            CharacterPhrase(character=character, text=phrase_text)
            for character, phrase_text in pattern.findall(text)
        ]

    def __init__(self, **data):
        super().__init__(**data)
        self._phrases = self._parse_phrases_from_xml_tags(self.text_annotated)
        # fix: dict.fromkeys deduplicates while preserving first-appearance
        # order, making `characters` deterministic (set iteration order is not)
        self._characters = list(
            dict.fromkeys(phrase.character for phrase in self._phrases)
        )
        # TODO: can apply post-processing to merge same adjacent xml tags

    @property
    def phrases(self) -> list[CharacterPhrase]:
        return self._phrases

    @property
    def characters(self) -> list[str]:
        return self._characters

    def to_pretty_text(self) -> str:
        """Render a 'characters' header plus one '[character] text' line per phrase."""
        lines = [f"characters: {self.characters}", "-" * 20]
        lines.extend(f"[{phrase.character}] {phrase.text}" for phrase in self.phrases)
        return "\n".join(lines)


def create_split_text_chain(llm_model: GPTModels):
    """Build a runnable that annotates raw text with character XML tags.

    The chain keeps the original input available under 'text', adds the LLM
    response as 'text_annotated', and packs both into a SplitTextOutput.
    """
    llm = get_chat_llm(llm_model=llm_model, temperature=0.0)

    messages = [
        SystemMessagePromptTemplate.from_template(SplitTextPromptV2.SYSTEM),
        HumanMessagePromptTemplate.from_template(SplitTextPromptV2.USER),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)

    # prompt -> LLM -> plain string with the annotated text
    annotate = prompt | llm | StrOutputParser()

    def _pack(inputs):
        # combine the untouched input text with the annotated LLM output
        return SplitTextOutput(
            text_raw=inputs["text"], text_annotated=inputs["text_annotated"]
        )

    return RunnablePassthrough.assign(text_annotated=annotate) | _pack


###### old code ######


class CharacterAnnotatedText(BaseModel):
    """Legacy container: a list of phrases plus the derived character roster."""

    phrases: list[CharacterPhrase]
    # private attr (underscore prefix); filled in __init__, not by callers
    _characters: list[str]

    def __init__(self, **data):
        super().__init__(**data)
        # fix: dict.fromkeys deduplicates while preserving first-appearance
        # order, making `characters` deterministic (set iteration order is not)
        self._characters = list(
            dict.fromkeys(phrase.character for phrase in self.phrases)
        )

    @property
    def characters(self) -> list[str]:
        return self._characters

    def to_pretty_text(self) -> str:
        """Render a 'characters' header plus one '[character] text' line per phrase."""
        lines = [f"characters: {self.characters}", "-" * 20]
        lines.extend(f"[{phrase.character}] {phrase.text}" for phrase in self.phrases)
        return "\n".join(lines)


class SplitTextOutputOld(BaseModel):
    """Legacy structured-output schema: character roster plus ordered parts."""

    characters: list[str]  # names of all characters found in the text
    parts: list[CharacterPhrase]  # phrases in original text order

    def to_character_annotated_text(self) -> "CharacterAnnotatedText":
        """Repackage the parts into the legacy annotated-text container."""
        return CharacterAnnotatedText(phrases=self.parts)


def create_split_text_chain_old(llm_model: GPTModels):
    """Legacy chain: prompt the LLM for JSON parsed into SplitTextOutputOld."""
    # json_mode structured output parses the model response directly
    # into the SplitTextOutputOld schema
    structured_llm = get_chat_llm(
        llm_model=llm_model, temperature=0.0
    ).with_structured_output(SplitTextOutputOld, method="json_mode")

    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(SplitTextPromptV1.SYSTEM),
            HumanMessagePromptTemplate.from_template(SplitTextPromptV1.USER),
        ]
    )

    return prompt | structured_llm


## end of old code ##