Spaces:
Sleeping
Sleeping
import inspect | |
import pandas as pd | |
from .config import QUESTION2FILTERARGS, TEXTER_PREFIX, HELPER_PREFIX | |
# Utils to filter convo according to a phase | |
from .ta_filter_utils import filter_convo | |
def join_messages(
    grp: pd.DataFrame, texter_prefix: str = "texter", helper_prefix: str = "helper"
) -> str:
    """Render a conversation as a single newline-separated transcript.

    Each row becomes one line of the form ``"<prefix>: <message>"``, with
    surrounding whitespace stripped from both the prefix and the message.

    Args:
        grp (pd.DataFrame): conversation with one row per **message**.
            Must have the following columns:
            - actor_role
            - message
        texter_prefix (str, optional): label substituted for the "texter" role.
            Defaults to "texter".
        helper_prefix (str, optional): label substituted for the "counselor"
            and "helper" roles. Defaults to "helper".

    Raises:
        Exception: if either required column is missing from ``grp``.

    Returns:
        str: all messages joined with "\\n", each preceded by its role prefix
    """
    # Validate in the same order the columns are used below.
    for required in ("actor_role", "message"):
        if required not in grp:
            raise Exception(f"Column '{required}' not in DataFrame")

    # Roles not listed here are left as they are by ``replace``.
    prefix_by_role = {
        "texter": texter_prefix,
        "counselor": helper_prefix,
        "helper": helper_prefix,
    }
    speakers = grp["actor_role"].replace(prefix_by_role).str.strip()
    texts = grp["message"].str.strip()
    return "\n".join(speakers + ": " + texts)
def _get_context(grp: pd.DataFrame, **kwargs) -> str:
    """Build the context string for one conversation.

    Picks out of ``kwargs`` only the keyword arguments that ``join_messages``
    accepts (discovered from its signature) and forwards them; any other
    keyword argument is ignored.

    Args:
        grp (pd.DataFrame): conversation with one row per **message**.
            Must have the following columns:
            - actor_role
            - message
        **kwargs: optional keyword arguments; those named like a
            ``join_messages`` parameter (e.g. ``texter_prefix``,
            ``helper_prefix``) are passed through to it.

    Raises:
        Exception: if either required column is missing from ``grp``.

    Returns:
        str: joined messages string separated by prefixes
    """
    for required in ("actor_role", "message"):
        if required not in grp:
            raise Exception(f"Column '{required}' not in DataFrame")

    # Parameter names join_messages understands; membership is by name.
    accepted = inspect.signature(join_messages).parameters

    forwarded = {}
    # Iterate over a snapshot of the keys because entries are popped as we go.
    for name in list(kwargs):
        if name in accepted:
            forwarded[name] = kwargs.pop(name)

    # NOTE: marker-based context selection (get_context_on_marker) is
    # deprecated and no longer applied here.
    return join_messages(grp, **forwarded)
def load_context(
    messages: pd.DataFrame,
    question: str,
    message_col: str,
    col_type: str,
    inference: bool = False,
    **kwargs,
) -> pd.DataFrame:
    """Load and filter conversations from messages for a given question.

    The question selects, via ``QUESTION2FILTERARGS``, the arguments passed to
    ``filter_convo`` so that only the relevant part of each conversation is
    kept before the messages are joined into one context string.

    Args:
        messages (pd.DataFrame): Messages dataframe with conversation_id,
            actor_role, `message_col` and phase prediction. It is copied, so
            the caller's frame is not modified.
        question (str): Question to get context to. Must be a key of
            ``QUESTION2FILTERARGS``.
        message_col (str): Column where messages are. Only read when
            ``col_type == "joined"``.
        col_type (str): type of message_col, can be "individual" (one row per
            message) or "joined" (conversation text already joined).
        inference (bool, optional): currently unused by this function.
            Defaults to False.
        **kwargs: currently unused by this function.

    Raises:
        Exception: If question is not supported

    Returns:
        pd.DataFrame: context per ``conversation_id`` under the name
        "q_context". For "joined" this is a one-column DataFrame; for
        "individual" it is whatever ``groupby(...).apply(...)`` yields,
        renamed to "q_context" (a Series when each group maps to a string).
        For any other ``col_type`` the unmodified copy of ``messages`` is
        returned.
    """
    if question not in QUESTION2FILTERARGS:
        raise Exception(f"Question {question} not supported")

    context_data = messages.copy()

    def convo_cpc_get_context(grp, **join_kwargs):
        """Filter convo according to Convo Phase Classifier (CPC) predictions"""
        context_ = filter_convo(grp, **QUESTION2FILTERARGS[question])
        return _get_context(context_, **join_kwargs)

    if col_type == "individual":
        if "actor_role" in context_data:
            # Rows without a role cannot be attributed to a speaker.
            context_data.dropna(subset=["actor_role"], inplace=True)
        if "delete_message" in context_data:
            # Normalise the flag: 1 -> True, missing -> False.
            # Assign the result back explicitly: calling replace/fillna with
            # inplace=True on ``context_data.delete_message`` is a chained
            # in-place update, which does not modify the DataFrame under
            # pandas Copy-on-Write.
            context_data["delete_message"] = (
                context_data["delete_message"].replace({1: True}).fillna(False)
            )
        context_data = (
            context_data.groupby("conversation_id")
            .apply(
                convo_cpc_get_context,
                helper_prefix=HELPER_PREFIX,
                texter_prefix=TEXTER_PREFIX,
            )
            .rename("q_context")
        )
    elif col_type == "joined":
        # The text is already joined; keep one value per conversation.
        context_data = context_data.groupby("conversation_id")[[message_col]].max()
        context_data.rename(columns={message_col: "q_context"}, inplace=True)
    return context_data