File size: 2,112 Bytes
f3e0ba5
1e91476
f3e0ba5
9ff00d4
 
20b3b4a
1e91476
f3e0ba5
1e91476
 
9ff00d4
 
 
 
 
 
1e91476
9ff00d4
1e91476
9ff00d4
 
 
 
20b3b4a
1e91476
20b3b4a
 
 
 
1e91476
20b3b4a
 
 
 
 
 
f3e0ba5
 
 
 
 
 
20b3b4a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import streamlit as st
from streamlit.logger import get_logger
from langchain_core.messages import HumanMessage
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
from models.openai.role_models import get_role_chain, get_template_role_models
from models.databricks.scenario_sim_biz import get_databricks_biz_chain
from models.databricks.texter_sim_llm import get_databricks_chain
from models.ta_models.cpc_utils import cpc_predict_message

logger = get_logger(__name__)

def get_chain(issue, language, source, memory, temperature, texter_name=""):
    """Build the conversation chain for the requested model source.

    Args:
        issue: Issue identifier. Combined with ``language`` to look up a
            fine-tuned engine, and forwarded to the other chain builders.
        language: Language name. For the ``CTL_*`` sources, "English" and
            "Spanish" are converted to "en" and "es"; any other value is
            forwarded unchanged.
        source: Backend selector, matched exactly against "OA_finetuned",
            "OA_rolemodel", "CTL_llama2" or "CTL_llama3".
        memory: Memory object handed to the selected chain builder.
        temperature: Sampling temperature handed to the selected builder.
        texter_name: Forwarded to the role-model template and to the
            "CTL_llama3" builder; ignored by the other sources.

    Returns:
        The object returned by the selected chain builder, or ``None`` when
        ``source`` matches none of the known selectors.
    """
    # NOTE: the comparisons must be exact. The previous form
    # `source in ("OA_finetuned")` was a substring test against a plain
    # string (parentheses without a trailing comma do not create a tuple),
    # so values such as "" or "OA" wrongly selected a backend.
    if source == "OA_finetuned":
        OA_engine = finetuned_models[f"{issue}-{language}"]
        return get_finetuned_chain(OA_engine, memory, temperature)

    if source == "OA_rolemodel":
        template = get_template_role_models(issue, language, texter_name=texter_name)
        return get_role_chain(template, memory, temperature)

    if source in ("CTL_llama2", "CTL_llama3"):
        # Both CTL builders receive the short language code; unknown
        # language names are passed through untouched, as before.
        language_code = {"English": "en", "Spanish": "es"}.get(language, language)
        if source == "CTL_llama2":
            return get_databricks_biz_chain(
                source, issue, language_code, memory, temperature
            )
        return get_databricks_chain(
            source, issue, language_code, memory, temperature,
            texter_name=texter_name,
        )

    # Unknown source: keep the historical behaviour of returning None.
    return None
    
def custom_chain_predict(llm_chain, input, stop):
    """Run ``llm_chain`` once and record the exchange in its chat memory.

    The chain is driven through its own ``prep_inputs`` / ``_validate_inputs``
    / ``_call`` / ``_validate_outputs`` steps. Afterwards the value returned
    by ``cpc_predict_message`` for the session's ``context`` and
    ``last_message`` is stored in ``st.session_state['last_phase']`` and
    attached to the stored user message as ``response_metadata["phase"]``.

    Args:
        llm_chain: Chain object exposing the methods above plus ``memory``
            and ``output_key``.
        input: User text passed to the chain under the "input" key.
        stop: Value passed to the chain under the "stop" key.

    Returns:
        ``outputs[llm_chain.output_key]`` as produced by the chain; each item
        obtained by iterating it is also appended to memory as an AI message.
    """
    prepared = llm_chain.prep_inputs({"input": input, "stop": stop})
    llm_chain._validate_inputs(prepared)
    raw_outputs = llm_chain._call(prepared)
    llm_chain._validate_outputs(raw_outputs)

    # Classify the turn from session state and remember the result.
    session = st.session_state
    predicted_phase = cpc_predict_message(session["context"], session["last_message"])
    session["last_phase"] = predicted_phase
    logger.debug(predicted_phase)

    # Store the user turn first, tagged with the predicted phase.
    history = llm_chain.memory.chat_memory
    user_turn = HumanMessage(
        prepared["input"],
        response_metadata={"phase": predicted_phase},
    )
    history.add_user_message(user_turn)

    # Then store every generated item as its own AI message.
    replies = raw_outputs[llm_chain.output_key]
    for reply in replies:
        history.add_ai_message(reply)
    return replies