Spaces:
Sleeping
Sleeping
import streamlit as st | |
from streamlit.logger import get_logger | |
from langchain_core.messages import HumanMessage | |
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain | |
from models.openai.role_models import get_role_chain, get_template_role_models | |
from models.databricks.scenario_sim_biz import get_databricks_biz_chain | |
from models.databricks.texter_sim_llm import get_databricks_chain | |
from models.ta_models.cpc_utils import cpc_predict_message | |
# Module-level logger from Streamlit's logging helper; custom_chain_predict
# uses it to log the predicted phase at DEBUG level.
logger = get_logger(__name__)
# Display language name -> short code handed to the Databricks chain builders.
_LANGUAGE_CODES = {"English": "en", "Spanish": "es"}


def _to_language_code(language):
    """Return the short code for ``language``; unmapped values pass through unchanged."""
    return _LANGUAGE_CODES.get(language, language)


def get_chain(issue, language, source, memory, temperature, texter_name=""):
    """Build the conversation chain for the requested model ``source``.

    Args:
        issue: Issue/scenario identifier forwarded to the chain builders.
        language: Display language name (e.g. ``"English"``, ``"Spanish"``).
            The ``OA_*`` sources receive it unchanged; the ``CTL_*`` sources
            receive ``"en"``/``"es"`` for English/Spanish and any other value
            unchanged.
        source: Exactly one of ``"OA_finetuned"``, ``"OA_rolemodel"``,
            ``"CTL_llama2"`` or ``"CTL_llama3"``.
        memory: Memory object passed through to the chain builder.
        temperature: Temperature passed through to the chain builder.
        texter_name: Optional name forwarded to the ``OA_rolemodel`` and
            ``CTL_llama3`` builders only.

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: If ``source`` is not one of the supported values.
        For ``"OA_finetuned"``, the ``finetuned_models[f"{issue}-{language}"]``
        lookup may also raise (presumably ``KeyError``) for an unknown
        issue/language combination.
    """
    # Exact comparisons on purpose: ``source in ("OA_finetuned")`` would be a
    # substring test (parentheses around one string do not form a tuple), so
    # values such as "OA", "CTL_llama" or "" would be accepted by mistake.
    if source == "OA_finetuned":
        oa_engine = finetuned_models[f"{issue}-{language}"]
        return get_finetuned_chain(oa_engine, memory, temperature)
    if source == "OA_rolemodel":
        template = get_template_role_models(issue, language, texter_name=texter_name)
        return get_role_chain(template, memory, temperature)
    if source == "CTL_llama2":
        return get_databricks_biz_chain(
            source, issue, _to_language_code(language), memory, temperature
        )
    if source == "CTL_llama3":
        return get_databricks_chain(
            source,
            issue,
            _to_language_code(language),
            memory,
            temperature,
            texter_name=texter_name,
        )
    # Fail loudly instead of returning None, which would only surface later
    # as an obscure AttributeError in the caller.
    raise ValueError(f"Unsupported chain source: {source!r}")
def custom_chain_predict(llm_chain, input, stop):
    """Run ``llm_chain`` on one user turn and record the exchange in its memory.

    The chain is driven through its lower-level hooks (``prep_inputs``,
    ``_validate_inputs``, ``_call``, ``_validate_outputs``) instead of a public
    call, and the chat history is then updated by hand so the stored user
    message can carry a ``phase`` entry in its ``response_metadata``.

    Args:
        llm_chain: Chain exposing the hooks above plus ``memory`` and
            ``output_key``.
        input: User message, passed as the ``"input"`` chain input.
        stop: Value passed as the ``"stop"`` chain input.

    Returns:
        ``outputs[llm_chain.output_key]`` from the chain call; each element of
        it is also appended to the chat memory as an AI message.

    Side effects:
        Reads ``st.session_state['context']`` and
        ``st.session_state['last_message']``, writes
        ``st.session_state['last_phase']``, and logs the phase at DEBUG level.
    """
    prepared = llm_chain.prep_inputs({"input": input, "stop": stop})
    llm_chain._validate_inputs(prepared)
    result = llm_chain._call(prepared)
    llm_chain._validate_outputs(result)

    # Phase label comes from cpc_predict_message, fed the session's stored
    # context and last message; it is remembered for later reruns.
    session = st.session_state
    predicted_phase = cpc_predict_message(session['context'], session['last_message'])
    session['last_phase'] = predicted_phase
    logger.debug(predicted_phase)

    history = llm_chain.memory.chat_memory
    history.add_user_message(
        HumanMessage(prepared['input'], response_metadata={"phase": predicted_phase})
    )

    generated = result[llm_chain.output_key]
    # Every element of the output becomes its own AI message.
    # NOTE(review): a plain-string output would be stored one character per
    # message — presumably the chain yields a list of messages; confirm.
    for message in generated:
        history.add_ai_message(message)
    return generated