import inspect

import pandas as pd

from .config import QUESTION2FILTERARGS, TEXTER_PREFIX, HELPER_PREFIX

# Utils to filter convo according to a phase
from .ta_filter_utils import filter_convo


def _require_message_columns(grp: pd.DataFrame) -> None:
    """Raise if `grp` lacks the columns needed to render a conversation.

    Args:
        grp (pd.DataFrame): conversation with one row per message.

    Raises:
        Exception: If the 'actor_role' or 'message' column is missing.
    """
    if "actor_role" not in grp:
        raise Exception("Column 'actor_role' not in DataFrame")
    if "message" not in grp:
        raise Exception("Column 'message' not in DataFrame")


def join_messages(
    grp: pd.DataFrame, texter_prefix: str = "texter", helper_prefix: str = "helper"
) -> str:
    """Join the messages of a conversation into one prefixed, newline-separated string.

    Each row is rendered as "<prefix>: <message>", with both parts stripped of
    surrounding whitespace. The roles "texter" map to `texter_prefix`, and both
    "counselor" and "helper" map to `helper_prefix`; any other role value is
    kept as-is.

    Args:
        grp (pd.DataFrame): conversation in DataFrame with each row
            corresponding to each **message**. Must have the following columns:
            - actor_role
            - message
        texter_prefix (str, optional): prefix to use as the texter.
            Defaults to "texter".
        helper_prefix (str, optional): prefix to use as the counselor (helper).
            Defaults to "helper".

    Raises:
        Exception: If the 'actor_role' or 'message' column is missing.

    Returns:
        str: joined messages string separated by prefixes
    """
    _require_message_columns(grp)

    role_to_prefix = {
        "texter": texter_prefix,
        "counselor": helper_prefix,
        "helper": helper_prefix,
    }
    prefixes = grp["actor_role"].replace(role_to_prefix)
    lines = prefixes.str.strip() + ": " + grp["message"].str.strip()
    return "\n".join(lines)


def _get_context(grp: pd.DataFrame, **kwargs) -> str:
    """Render a conversation as a single context string.

    Only the keyword arguments that `join_messages` accepts are forwarded to
    it. Any other keyword argument is ignored: the marker-based context step
    that used to consume them is deprecated.

    Args:
        grp (pd.DataFrame): conversation in DataFrame with each row
            corresponding to each **message**. Must have the following columns:
            - actor_role
            - message
        **kwargs: optional `texter_prefix` / `helper_prefix` overrides passed
            on to `join_messages`.

    Raises:
        Exception: If the 'actor_role' or 'message' column is missing.

    Returns:
        str: joined messages string separated by prefixes
    """
    _require_message_columns(grp)

    # Select by signature so the forwarded set follows join_messages' parameters.
    accepted = inspect.signature(join_messages).parameters
    join_kwargs = {key: value for key, value in kwargs.items() if key in accepted}
    return join_messages(grp, **join_kwargs)


def load_context(
    messages: pd.DataFrame,
    question: str,
    message_col: str,
    col_type: str,
    inference: bool = False,
    **kwargs,
) -> pd.DataFrame:
    """Load and filter conversation from messages given a question
    (with configured parameters of what phase that question is answered)

    Args:
        messages (pd.DataFrame): Messages dataframe with conversation_id,
            actor_role, `message_col` and phase prediction. It is copied, so
            the caller's frame is never modified.
        question (str): Question to get context to. Must be a key of
            QUESTION2FILTERARGS.
        message_col (str): Column where messages are. Only read when
            `col_type` is "joined"; the "individual" path always joins the
            'message' column.
        col_type (str): type of message_col, can be "individual" or "joined"
        inference (bool, optional): currently unused. Defaults to False.
        **kwargs: currently unused.

    Raises:
        Exception: If question is not supported

    Returns:
        The context per conversation, indexed by conversation_id:
        - "individual": a Series named "q_context" holding the filtered and
          joined messages of each conversation.
        - "joined": a DataFrame with a single "q_context" column holding the
          per-conversation maximum of `message_col`.
        - any other `col_type`: an unmodified copy of `messages`.
    """
    if question not in QUESTION2FILTERARGS:
        raise Exception(f"Question {question} not supported")

    context_data = messages.copy()

    def convo_cpc_get_context(grp, **join_kwargs):
        """Filter convo according to Convo Phase Classifier (CPC) predictions"""
        filtered = filter_convo(grp, **QUESTION2FILTERARGS[question])
        return _get_context(filtered, **join_kwargs)

    if col_type == "individual":
        if "actor_role" in context_data:
            context_data.dropna(subset=["actor_role"], inplace=True)
        if "delete_message" in context_data:
            # Assign the cleaned column back to the frame. Calling
            # replace/fillna with inplace=True on `context_data.delete_message`
            # is chained assignment: under pandas Copy-on-Write it leaves the
            # frame untouched, and recent pandas versions warn about it.
            context_data["delete_message"] = (
                context_data["delete_message"].replace({1: True}).fillna(False)
            )
        context_data = (
            context_data.groupby("conversation_id")
            .apply(
                convo_cpc_get_context,
                helper_prefix=HELPER_PREFIX,
                texter_prefix=TEXTER_PREFIX,
            )
            .rename("q_context")
        )
    elif col_type == "joined":
        context_data = (
            context_data.groupby("conversation_id")[[message_col]]
            .max()
            .rename(columns={message_col: "q_context"})
        )

    return context_data