File size: 6,227 Bytes
c2fa877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from enum import StrEnum

import pandas as pd
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.runnables import RunnablePassthrough
from pydantic import BaseModel

from src.config import logger
from src.prompts import CharacterVoicePropertiesPrompt
from src.utils import GPTModels, get_chat_llm


class Property(StrEnum):
    """Names of the character voice properties.

    The members are used as the keys of ``VoiceSelector.PROPERTY_VALUES``,
    which lists the values allowed for each property.
    """

    # Key for the set of allowed gender values.
    gender = "gender"
    # Key for the set of allowed age-group values.
    age_group = "age_group"


# Voice-related properties of a single character, with both fields required.
# This is the element type of `AllCharactersProperties.character2props`.
# NOTE: documented with comments instead of a class docstring on purpose:
# pydantic includes a model's docstring in its generated JSON schema, and this
# file builds LLM format instructions from the `AllCharactersProperties` schema.
class CharacterProperties(BaseModel):
    gender: str
    age_group: str

    def __hash__(self):
        """Hash on the ``(gender, age_group)`` pair."""
        identity = (self.gender, self.age_group)
        return hash(identity)


# Properties of all characters as requested from the LLM: this model is passed
# to `with_structured_output` and to `PydanticOutputParser` in
# `VoiceSelector.create_voice_mapping_chain`.
# NOTE: kept as comments rather than a class docstring, since pydantic copies a
# model docstring into the JSON schema used for the format instructions.
class AllCharactersProperties(BaseModel):
    # Maps each character to its (required) gender / age-group properties.
    character2props: dict[str, CharacterProperties]


# Voice-related properties of a single character where either field may be
# None. `VoiceSelector._remove_hallucinations_single_character` produces this
# type, setting a field to None when the value is not in the allowed set.
# Instances are used as dict keys in `VoiceSelector.get_voices`, hence __hash__.
# NOTE: documented with comments instead of a class docstring, because pydantic
# includes a model's docstring in its generated JSON schema.
class CharacterPropertiesNullable(BaseModel):
    gender: str | None
    age_group: str | None

    def __hash__(self):
        """Hash on the ``(gender, age_group)`` pair; ``None`` fields are allowed."""
        identity = (self.gender, self.age_group)
        return hash(identity)


# Properties of all characters after validation against the allowed values;
# this is the return type of `VoiceSelector.remove_hallucinations`.
# NOTE: kept as comments rather than a class docstring, because pydantic
# includes a model's docstring in its generated JSON schema.
class AllCharactersPropertiesNullable(BaseModel):
    # Maps each character to its properties; individual fields may be None.
    character2props: dict[str, CharacterPropertiesNullable]


# Final result of the voice-selection chain, built by
# `VoiceSelector.pack_results`.
# NOTE: kept as comments rather than a class docstring, because pydantic
# includes a model's docstring in its generated JSON schema.
class SelectVoiceChainOutput(BaseModel):
    # Maps each character to its validated (possibly None) properties.
    character2props: dict[str, CharacterPropertiesNullable]
    # Maps each character to the `voice_id` value sampled for it.
    character2voice: dict[str, str]


class VoiceSelector:
    """Select a voice for every character from LLM-inferred properties.

    The selector loads a CSV table of voices (the columns ``gender``, ``age``
    and ``voice_id`` are read) and exposes a LangChain chain that:

    1. asks the LLM for each character's gender and age group,
    2. replaces values outside ``PROPERTY_VALUES`` with ``None``,
    3. samples voices from the table rows matching each character's properties.
    """

    # Closed vocabulary the LLM is asked to choose from, per property.
    PROPERTY_VALUES = {
        Property.gender: {"male", "female"},
        Property.age_group: {"young", "middle_aged", "old"},
    }

    def __init__(self, csv_table_fp: str):
        """Load the voice table found at ``csv_table_fp`` into ``self.df``."""
        self.df = self.read_data_table(csv_table_fp=csv_table_fp)

    def read_data_table(self, csv_table_fp: str) -> pd.DataFrame:
        """Read the voice CSV and normalize its ``age`` column.

        Spaces and hyphens in ``age`` are replaced with underscores so that
        values such as ``"middle aged"`` / ``"middle-aged"`` compare equal to
        the ``"middle_aged"`` spelling used in ``PROPERTY_VALUES``.

        Args:
            csv_table_fp: path of the CSV file to read.

        Returns:
            The loaded table with the normalized ``age`` column.
        """
        logger.info(f'reading voice data from: "{csv_table_fp}"')
        df = pd.read_csv(csv_table_fp)
        df["age"] = df["age"].str.replace(" ", "_").str.replace("-", "_")
        return df

    def get_available_properties_str(self, prop: Property) -> str:
        """Render the allowed values of ``prop`` as a quoted, comma-separated list.

        The values are sorted: ``PROPERTY_VALUES`` stores them in sets, whose
        iteration order for strings may differ from one interpreter run to the
        next, which would make the generated prompt non-reproducible.

        Args:
            prop: the property whose allowed values are rendered.

        Returns:
            A string such as ``'"female", "male"'``.
        """
        vals = sorted(self.PROPERTY_VALUES[prop])
        return ", ".join(f'"{v}"' for v in vals)

    def _get_voices_single_props(
        self, character_props: CharacterPropertiesNullable, n_characters: int
    ) -> list:
        """Sample ``n_characters`` distinct voices matching ``character_props``.

        A property that is ``None`` (or otherwise falsy) is not used as a
        filter. Sampling is random and not seeded here.

        Args:
            character_props: gender / age group the voices must match.
            n_characters: number of voices to draw; must be positive.

        Returns:
            List of ``voice_id`` values taken from the matching rows.

        Raises:
            ValueError: if ``n_characters`` is not positive, or if fewer than
                ``n_characters`` voices match the requested properties.
        """
        if n_characters <= 0:
            raise ValueError(
                f"n_characters must be a positive integer, got {n_characters}"
            )

        df_filtered = self.df
        if val := character_props.gender:
            df_filtered = df_filtered[df_filtered["gender"] == val]
        if val := character_props.age_group:
            df_filtered = df_filtered[df_filtered["age"] == val]

        # `DataFrame.sample` draws without replacement by default and fails with
        # a generic message when the pool is too small; fail early with context.
        n_available = len(df_filtered)
        if n_characters > n_available:
            raise ValueError(
                f"not enough voices for gender={character_props.gender!r}, "
                f"age_group={character_props.age_group!r}: "
                f"requested {n_characters}, available {n_available}"
            )

        voice_ids = df_filtered.sample(n_characters)["voice_id"].to_list()
        return voice_ids

    def get_voices(self, inputs: dict) -> dict:
        """Assign a voice to every character.

        Characters sharing the same properties are grouped, and each group gets
        as many distinct voices as it has characters.

        Args:
            inputs: chain state; the validated properties are read from the
                ``"charater_props"`` key (the misspelling is the key actually
                used throughout the chain and is kept for compatibility).

        Returns:
            Mapping from character to the selected ``voice_id``.

        Raises:
            ValueError: if any character has a ``None`` gender or age group, or
                if there are not enough matching voices for a group.
        """
        character_props: AllCharactersPropertiesNullable = inputs["charater_props"]

        # check for Nones.
        # TODO: for simplicity we raise error if LLM failed to select valid property value.
        # else, we would need to implement clever mapping to avoid overlapping between voices.
        for character, props in character_props.character2props.items():
            if props.age_group is None or props.gender is None:
                raise ValueError(
                    f'cannot select a voice for character "{character}": '
                    f"incomplete properties ({props!r})"
                )

        # Group characters by identical properties so that each group is
        # sampled once, without replacement.
        prop2character = {}
        for character, props in character_props.character2props.items():
            prop2character.setdefault(props, set()).add(character)

        character2voice = {}
        for props, characters in prop2character.items():
            voice_ids = self._get_voices_single_props(
                character_props=props, n_characters=len(characters)
            )
            character2voice.update(zip(characters, voice_ids))

        return character2voice

    def _remove_hallucinations_single_character(
        self, character_props: CharacterProperties
    ) -> CharacterPropertiesNullable:
        """Replace property values outside ``PROPERTY_VALUES`` with ``None``.

        A warning is logged for every value that is replaced.
        """

        def _process_prop(prop: Property, value: str):
            if value not in self.PROPERTY_VALUES[prop]:
                logger.warning(
                    f'LLM selected non-available {prop} value: "{value}". defaulting to None'
                )
                return None
            return value

        return CharacterPropertiesNullable(
            gender=_process_prop(prop=Property.gender, value=character_props.gender),
            age_group=_process_prop(
                prop=Property.age_group, value=character_props.age_group
            ),
        )

    def remove_hallucinations(
        self, props: AllCharactersProperties
    ) -> AllCharactersPropertiesNullable:
        """Validate every character's properties against ``PROPERTY_VALUES``.

        Args:
            props: properties as produced by the LLM step of the chain.

        Returns:
            The same mapping with every non-allowed value replaced by ``None``.
        """
        res = AllCharactersPropertiesNullable(
            character2props={
                k: self._remove_hallucinations_single_character(character_props=v)
                for k, v in props.character2props.items()
            }
        )
        return res

    def pack_results(self, inputs: dict) -> SelectVoiceChainOutput:
        """Bundle the chain state into a ``SelectVoiceChainOutput``.

        Args:
            inputs: chain state holding the ``"charater_props"`` (sic) and
                ``"character2voice"`` keys.
        """
        character_props: AllCharactersPropertiesNullable = inputs["charater_props"]
        character2voice: dict[str, str] = inputs["character2voice"]
        return SelectVoiceChainOutput(
            character2props=character_props.character2props,
            character2voice=character2voice,
        )

    def create_voice_mapping_chain(self, llm_model: GPTModels):
        """Build the chain that infers character properties and picks voices.

        The prompt is pre-filled with the allowed genders, the allowed age
        groups and the format instructions derived from
        ``AllCharactersProperties``; any other template variables must be
        supplied when the chain is invoked.

        Steps of the returned chain:
            1. ``prompt | llm | remove_hallucinations`` is stored under the
               ``"charater_props"`` key (sic).
            2. ``get_voices`` is stored under the ``"character2voice"`` key.
            3. ``pack_results`` turns the state into a ``SelectVoiceChainOutput``.

        Args:
            llm_model: model identifier forwarded to ``get_chat_llm``.

        Returns:
            The composed runnable chain.
        """
        llm = get_chat_llm(llm_model=llm_model, temperature=0.0)
        llm = llm.with_structured_output(AllCharactersProperties, method="json_mode")

        # The parser is only used to generate format instructions for the
        # prompt; the LLM output itself goes through `with_structured_output`.
        output_parser = PydanticOutputParser(pydantic_object=AllCharactersProperties)
        format_instructions = output_parser.get_format_instructions()

        prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(
                    CharacterVoicePropertiesPrompt.SYSTEM
                ),
                HumanMessagePromptTemplate.from_template(
                    CharacterVoicePropertiesPrompt.USER
                ),
            ]
        )
        prompt = prompt.partial(
            **{
                "available_genders": self.get_available_properties_str(Property.gender),
                "available_age_groups": self.get_available_properties_str(
                    Property.age_group
                ),
                "format_instructions": format_instructions,
            }
        )

        chain = (
            RunnablePassthrough.assign(
                charater_props=prompt | llm | self.remove_hallucinations
            )
            | RunnablePassthrough.assign(character2voice=self.get_voices)
            | self.pack_results
        )
        return chain