Spaces:
Running
Running
from enum import StrEnum | |
import pandas as pd | |
from langchain_core.output_parsers import PydanticOutputParser | |
from langchain_core.prompts import ( | |
ChatPromptTemplate, | |
HumanMessagePromptTemplate, | |
SystemMessagePromptTemplate, | |
) | |
from langchain_core.runnables import RunnablePassthrough | |
from pydantic import BaseModel | |
from src.config import logger | |
from src.prompts import CharacterVoicePropertiesPrompt | |
from src.utils import GPTModels, get_chat_llm | |
class Property(StrEnum): | |
gender = "gender" | |
age_group = "age_group" | |
class CharacterProperties(BaseModel):
    """Gender and age group of a single character, both as plain strings.

    ``__hash__`` is defined explicitly so that instances can serve as
    dictionary keys or set members.
    """

    gender: str
    age_group: str

    def __hash__(self):
        """Hash the instance on its ``(gender, age_group)`` pair."""
        identity = (self.gender, self.age_group)
        return hash(identity)
class AllCharactersProperties(BaseModel):
    """Container mapping a character name to its ``CharacterProperties``."""

    # key: character name, value: that character's properties
    character2props: dict[str, CharacterProperties]
class CharacterPropertiesNullable(BaseModel):
    """Gender and age group of a character, where either value may be ``None``.

    ``__hash__`` is defined explicitly so that instances can serve as
    dictionary keys or set members.
    """

    gender: str | None
    age_group: str | None

    def __hash__(self):
        """Hash the instance on its ``(gender, age_group)`` pair."""
        identity = (self.gender, self.age_group)
        return hash(identity)
class AllCharactersPropertiesNullable(BaseModel):
    """Container mapping a character name to its ``CharacterPropertiesNullable``."""

    # key: character name, value: that character's (possibly None) properties
    character2props: dict[str, CharacterPropertiesNullable]
class SelectVoiceChainOutput(BaseModel):
    """Result bundle holding per-character properties and per-character voices."""

    # key: character name, value: that character's (possibly None) properties
    character2props: dict[str, CharacterPropertiesNullable]
    # key: character name, value: voice id string
    character2voice: dict[str, str]
class VoiceSelector:
    """Assign voices to characters from LLM-inferred gender and age group.

    A CSV table of voices is loaded once at construction time. The table is
    read through its ``gender``, ``age`` and ``voice_id`` columns.
    ``create_voice_mapping_chain`` builds a runnable chain that asks an LLM
    for each character's properties, discards values outside
    ``PROPERTY_VALUES`` and samples a voice id for every character.
    """

    # Closed sets of values the LLM may choose from, per property.
    PROPERTY_VALUES = {
        Property.gender: {"male", "female"},
        Property.age_group: {"young", "middle_aged", "old"},
    }

    def __init__(self, csv_table_fp: str):
        """Load the voice table located at ``csv_table_fp``."""
        self.df = self.read_data_table(csv_table_fp=csv_table_fp)

    def read_data_table(self, csv_table_fp: str) -> pd.DataFrame:
        """Read the voice CSV and normalise its ``age`` column.

        Spaces and hyphens in ``age`` are replaced by underscores
        (presumably so that values such as "middle aged" / "middle-aged"
        line up with the ``"middle_aged"`` entry of ``PROPERTY_VALUES`` -
        verify against the actual table).
        """
        logger.info(f'reading voice data from: "{csv_table_fp}"')
        df = pd.read_csv(csv_table_fp)
        df["age"] = df["age"].str.replace(" ", "_").str.replace("-", "_")
        return df

    def get_available_properties_str(self, prop: Property) -> str:
        """Return the allowed values of ``prop`` as a quoted, comma-separated list.

        The values are sorted: iterating the underlying ``set`` directly would
        yield an order that changes between interpreter runs (string hash
        randomisation), making the prompt non-reproducible.
        """
        allowed = sorted(self.PROPERTY_VALUES[prop])
        return ", ".join(f'"{v}"' for v in allowed)

    def _get_voices_single_props(
        self, character_props: CharacterPropertiesNullable, n_characters: int
    ) -> list:
        """Sample ``n_characters`` distinct voice ids matching ``character_props``.

        A property that is ``None`` (or otherwise falsy) does not restrict the
        candidate voices.

        Raises:
            ValueError: if ``n_characters`` is not positive, or if fewer than
                ``n_characters`` voices match the requested properties.
        """
        if n_characters <= 0:
            raise ValueError(f"n_characters must be positive, got {n_characters}")

        candidates = self.df
        if gender := character_props.gender:
            candidates = candidates[candidates["gender"] == gender]
        if age_group := character_props.age_group:
            candidates = candidates[candidates["age"] == age_group]

        # Fail with an actionable message instead of the generic pandas
        # "cannot take a larger sample than population" error.
        n_available = len(candidates)
        if n_characters > n_available:
            raise ValueError(
                f"not enough voices for {character_props!r}: "
                f"requested {n_characters}, available {n_available}"
            )

        # sample() draws without replacement, so the ids are distinct rows.
        return candidates.sample(n_characters)["voice_id"].to_list()

    def get_voices(self, inputs: dict) -> dict:
        """Map every character to a voice id.

        Args:
            inputs: chain state; must hold an ``AllCharactersPropertiesNullable``
                under the key ``"charater_props"`` (the misspelling is the
                established key and is kept for compatibility).

        Returns:
            Dict mapping character name to the sampled voice id.

        Raises:
            ValueError: if any character has a ``None`` gender or age group,
                or if there are not enough matching voices.
        """
        character_props: AllCharactersPropertiesNullable = inputs["charater_props"]

        # TODO: for simplicity we raise an error if the LLM failed to select a
        # valid property value. Otherwise we would need a clever mapping to
        # avoid overlapping voices between property groups.
        # Characters sharing identical properties are grouped so that their
        # voices are sampled together, without replacement.
        prop2character: dict = {}
        for character, props in character_props.character2props.items():
            if props.age_group is None or props.gender is None:
                raise ValueError(
                    f'character "{character}" has unresolved voice properties: '
                    f"{props!r}"
                )
            prop2character.setdefault(props, set()).add(character)

        character2voice = {}
        for props, characters in prop2character.items():
            voice_ids = self._get_voices_single_props(
                character_props=props, n_characters=len(characters)
            )
            character2voice.update(zip(characters, voice_ids))
        return character2voice

    def _remove_hallucinations_single_character(
        self, character_props: CharacterProperties
    ) -> CharacterPropertiesNullable:
        """Replace property values outside ``PROPERTY_VALUES`` with ``None``."""

        def _validated(prop: Property, value: str):
            # Known values pass through unchanged; anything else is logged
            # and dropped.
            if value in self.PROPERTY_VALUES[prop]:
                return value
            logger.warning(
                f'LLM selected non-available {prop} value: "{value}". defaulting to None'
            )
            return None

        return CharacterPropertiesNullable(
            gender=_validated(prop=Property.gender, value=character_props.gender),
            age_group=_validated(
                prop=Property.age_group, value=character_props.age_group
            ),
        )

    def remove_hallucinations(
        self, props: AllCharactersProperties
    ) -> AllCharactersPropertiesNullable:
        """Apply ``_remove_hallucinations_single_character`` to every character."""
        cleaned = {
            name: self._remove_hallucinations_single_character(character_props=p)
            for name, p in props.character2props.items()
        }
        return AllCharactersPropertiesNullable(character2props=cleaned)

    def pack_results(self, inputs: dict) -> SelectVoiceChainOutput:
        """Bundle the chain state into a ``SelectVoiceChainOutput``.

        Reads ``inputs["charater_props"]`` (sic) and ``inputs["character2voice"]``.
        """
        character_props: AllCharactersPropertiesNullable = inputs["charater_props"]
        character2voice: dict[str, str] = inputs["character2voice"]
        return SelectVoiceChainOutput(
            character2props=character_props.character2props,
            character2voice=character2voice,
        )

    def create_voice_mapping_chain(self, llm_model: GPTModels):
        """Build the runnable chain: prompt -> LLM -> validation -> voice sampling.

        The chain adds ``charater_props`` (sic) and ``character2voice`` to its
        input mapping and finally returns the result of ``pack_results``.
        """
        llm = get_chat_llm(llm_model=llm_model, temperature=0.0)
        llm = llm.with_structured_output(AllCharactersProperties, method="json_mode")

        # The parser is only used to render format instructions for the
        # prompt; the actual parsing is done by the structured-output LLM.
        output_parser = PydanticOutputParser(pydantic_object=AllCharactersProperties)
        format_instructions = output_parser.get_format_instructions()

        prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(
                    CharacterVoicePropertiesPrompt.SYSTEM
                ),
                HumanMessagePromptTemplate.from_template(
                    CharacterVoicePropertiesPrompt.USER
                ),
            ]
        )
        prompt = prompt.partial(
            available_genders=self.get_available_properties_str(Property.gender),
            available_age_groups=self.get_available_properties_str(
                Property.age_group
            ),
            format_instructions=format_instructions,
        )

        chain = (
            RunnablePassthrough.assign(
                charater_props=prompt | llm | self.remove_hallucinations
            )
            | RunnablePassthrough.assign(character2voice=self.get_voices)
            | self.pack_results
        )
        return chain