convosim-ui-dev / models /ta_models /ta_prompt_utils.py
ivnban27-ctl's picture
training-adherence-features (#1)
f3e0ba5 verified
raw
history blame
4.6 kB
import inspect
import pandas as pd
from .config import QUESTION2FILTERARGS, TEXTER_PREFIX, HELPER_PREFIX
# Utils to filter convo according to a phase
from .ta_filter_utils import filter_convo
def join_messages(
    grp: pd.DataFrame, texter_prefix: str = "texter", helper_prefix: str = "helper"
) -> str:
    """Render a conversation as one "<prefix>: <message>" line per row.

    Args:
        grp (pd.DataFrame): conversation with one row per **message**.
            Must have the following columns:
            - actor_role
            - message
        texter_prefix (str, optional): label substituted for the "texter" role.
            Defaults to "texter".
        helper_prefix (str, optional): label substituted for both the "counselor"
            and the "helper" roles. Defaults to "helper".

    Raises:
        Exception: if either required column is missing.

    Returns:
        str: the prefixed messages, in row order, separated by newlines
    """
    if "actor_role" not in grp:
        raise Exception("Column 'actor_role' not in DataFrame")
    if "message" not in grp:
        raise Exception("Column 'message' not in DataFrame")

    # "counselor" and "helper" both name the helper side; any role value not
    # listed here is kept as-is.
    prefix_by_role = {
        "texter": texter_prefix,
        "counselor": helper_prefix,
        "helper": helper_prefix,
    }
    speakers = grp["actor_role"].replace(prefix_by_role).str.strip()
    texts = grp["message"].str.strip()
    lines = speakers + ": " + texts
    return "\n".join(lines)
def _get_context(grp: pd.DataFrame, **kwargs) -> str:
    """Join one conversation into text, forwarding only the options join_messages accepts.

    Callers may pass a mixed bag of keyword arguments; those whose names match a
    parameter of ``join_messages`` are forwarded, the rest are ignored (the
    marker-based context extraction they used to feed is deprecated).

    Args:
        grp (pd.DataFrame): conversation with one row per **message**.
            Must have the following columns:
            - actor_role
            - message
        **kwargs: optional settings; only names that are parameters of
            ``join_messages`` (e.g. texter_prefix, helper_prefix) are used.

    Raises:
        Exception: if either required column is missing.

    Returns:
        str: joined messages string separated by prefixes
    """
    if "actor_role" not in grp:
        raise Exception("Column 'actor_role' not in DataFrame")
    if "message" not in grp:
        raise Exception("Column 'message' not in DataFrame")

    # Introspect join_messages so this stays in sync with its signature.
    accepted = inspect.signature(join_messages).parameters
    forwarded = {name: value for name, value in kwargs.items() if name in accepted}
    return join_messages(grp, **forwarded)
def load_context(
    messages: pd.DataFrame,
    question: str,
    message_col: str,
    col_type: str,
    inference: bool = False,
    **kwargs,
) -> pd.DataFrame:
    """Load and filter conversations from messages for a given question.

    The phase filter applied to each conversation is looked up in
    ``QUESTION2FILTERARGS[question]``.

    Args:
        messages (pd.DataFrame): Messages dataframe with conversation_id, actor_role,
            `message_col` and phase prediction. It is copied, never modified.
        question (str): Question to get context to
        message_col (str): Column where messages are (only read when
            col_type is "joined")
        col_type (str): type of message_col, can be "individual" or "joined"
        inference (bool, optional): accepted for interface compatibility; not used
            by this function. Defaults to False.
        **kwargs: accepted for interface compatibility; not used by this function.

    Raises:
        Exception: If question is not supported

    Returns:
        pd.DataFrame: context per conversation_id.
            - "individual": the phase-filtered, prefix-joined text of each
              conversation, named "q_context" (the result of groupby/apply on a
              str-returning function, i.e. normally a Series rather than a frame).
            - "joined": a one-column DataFrame "q_context" holding the max of
              `message_col` per conversation.
            - any other col_type: the unmodified copy of `messages`.
    """
    if question not in QUESTION2FILTERARGS:
        raise Exception(f"Question {question} not supported")

    texter_prefix = TEXTER_PREFIX
    helper_prefix = HELPER_PREFIX
    context_data = messages.copy()

    def convo_cpc_get_context(grp, **kwargs):
        """Filter convo according to Convo Phase Classifier (CPC) predictions"""
        context_ = filter_convo(grp, **QUESTION2FILTERARGS[question])
        return _get_context(context_, **kwargs)

    if col_type == "individual":
        if "actor_role" in context_data:
            context_data.dropna(subset=["actor_role"], inplace=True)

        if "delete_message" in context_data:
            # Assign the normalised column back explicitly. Calling
            # replace/fillna with inplace=True on `context_data.delete_message`
            # is chained assignment: with pandas Copy-on-Write it updates a
            # temporary and leaves the DataFrame unchanged.
            context_data["delete_message"] = (
                context_data["delete_message"].replace({1: True}).fillna(False)
            )

        context_data = (
            context_data.groupby("conversation_id")
            .apply(
                convo_cpc_get_context,
                helper_prefix=helper_prefix,
                texter_prefix=texter_prefix,
            )
            .rename("q_context")
        )
    elif col_type == "joined":
        context_data = context_data.groupby("conversation_id")[[message_col]].max()
        context_data.rename(columns={message_col: "q_context"}, inplace=True)

    return context_data