import re

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.runnables import RunnablePassthrough
from pydantic import BaseModel

from src.prompts import SplitTextPromptV1, SplitTextPromptV2
from src.utils import GPTModels, get_chat_llm


class CharacterPhrase(BaseModel):
    character: str
    text: str


class SplitTextOutput(BaseModel):
    text_raw: str
    text_annotated: str
    _phrases: list[CharacterPhrase]
    _characters: list[str]

    @staticmethod
    def _parse_phrases_from_xml_tags(text: str) -> list[CharacterPhrase]:
        """
        We rely on the LLM to format the response correctly,
        so we don't check that opening XML tags match closing ones.
        """
        # Captures (character, phrase) pairs from annotations like
        # <Alice>Hello</Alice>; text outside matching tag pairs is dropped.
        pattern = re.compile(r"(?:<([^<>]+)>)(.*?)(?:</\1>)")
        res = pattern.findall(text)
        res = [CharacterPhrase(character=x[0], text=x[1]) for x in res]
        return res

    def __init__(self, **data):
        super().__init__(**data)
        self._phrases = self._parse_phrases_from_xml_tags(self.text_annotated)
        self._characters = list(set(phrase.character for phrase in self.phrases))
        # TODO: can apply post-processing to merge same adjacent xml tags

    @property
    def phrases(self) -> list[CharacterPhrase]:
        return self._phrases

    @property
    def characters(self) -> list[str]:
        return self._characters

    def to_pretty_text(self):
        lines = []
        lines.append(f"characters: {self.characters}")
        lines.append("-" * 20)
        lines.extend(f"[{phrase.character}] {phrase.text}" for phrase in self.phrases)
        res = "\n".join(lines)
        return res
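
# Example of how SplitTextOutput parses an annotated string (a minimal sketch;
# the annotated text below is made up for illustration, not produced here):
#
#     out = SplitTextOutput(
#         text_raw="Hello. Hi.",
#         text_annotated="<narrator>Hello.</narrator> <Alice>Hi.</Alice>",
#     )
#     print(out.to_pretty_text())
#     # characters: ['narrator', 'Alice']   (set order may vary)
#     # --------------------
#     # [narrator] Hello.
#     # [Alice] Hi.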


def create_split_text_chain(llm_model: GPTModels):
    llm = get_chat_llm(llm_model=llm_model, temperature=0.0)
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(SplitTextPromptV2.SYSTEM),
            HumanMessagePromptTemplate.from_template(SplitTextPromptV2.USER),
        ]
    )
    # Annotate the raw text with the LLM, then wrap the raw and annotated
    # versions into a SplitTextOutput.
    chain = RunnablePassthrough.assign(
        text_annotated=prompt | llm | StrOutputParser()
    ) | (
        lambda inputs: SplitTextOutput(
            text_raw=inputs["text"], text_annotated=inputs["text_annotated"]
        )
    )
    return chain
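

# A minimal invocation sketch. It assumes SplitTextPromptV2.USER exposes a
# `{text}` variable and that GPTModels defines a member such as GPT_4o; both
# are assumptions about code outside this file:
#
#     chain = create_split_text_chain(llm_model=GPTModels.GPT_4o)
#     result = chain.invoke({"text": "Once upon a time..."})
#     print(result.to_pretty_text())
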
###### old code ######


class CharacterAnnotatedText(BaseModel):
    phrases: list[CharacterPhrase]
    _characters: list[str]

    def __init__(self, **data):
        super().__init__(**data)
        self._characters = list(set(phrase.character for phrase in self.phrases))

    @property
    def characters(self):
        return self._characters

    def to_pretty_text(self):
        lines = []
        lines.append(f"characters: {self.characters}")
        lines.append("-" * 20)
        lines.extend(f"[{phrase.character}] {phrase.text}" for phrase in self.phrases)
        res = "\n".join(lines)
        return res


class SplitTextOutputOld(BaseModel):
    characters: list[str]
    parts: list[CharacterPhrase]

    def to_character_annotated_text(self):
        return CharacterAnnotatedText(phrases=self.parts)


def create_split_text_chain_old(llm_model: GPTModels):
    llm = get_chat_llm(llm_model=llm_model, temperature=0.0)
    llm = llm.with_structured_output(SplitTextOutputOld, method="json_mode")
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(SplitTextPromptV1.SYSTEM),
            HumanMessagePromptTemplate.from_template(SplitTextPromptV1.USER),
        ]
    )
    chain = prompt | llm
    return chain
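
# Unlike create_split_text_chain, the old chain relies on structured JSON output
# and yields a SplitTextOutputOld instance directly. Sketch, assuming the V1 user
# prompt also takes a `text` variable and GPT_4o exists on GPTModels:
#
#     chain = create_split_text_chain_old(llm_model=GPTModels.GPT_4o)
#     result = chain.invoke({"text": "Once upon a time..."})
#     pretty = result.to_character_annotated_text().to_pretty_text()
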
## end of old code ##