Spaces:
Sleeping
Sleeping
ivnban27-ctl
commited on
Commit
·
75b7dbb
1
Parent(s):
fe3114b
added databricks model
Browse files- app_config.py +7 -3
- app_utils.py +10 -2
- models/databricks/scenario_sim_biz.py +66 -0
app_config.py
CHANGED
@@ -1,12 +1,16 @@
|
|
1 |
ISSUES = ['Anxiety','Suicide']
|
2 |
-
SOURCES = ['OA_rolemodel',
|
|
|
|
|
3 |
SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
|
4 |
-
"OA_finetuned":'Finetuned OpenAI'
|
|
|
|
|
5 |
|
6 |
def source2label(source):
|
7 |
return SOURCES_LAB[source]
|
8 |
|
9 |
-
ENVIRON = "
|
10 |
|
11 |
DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
|
12 |
DB_CONVOS = 'conversations'
|
|
|
1 |
ISSUES = ['Anxiety','Suicide']
|
2 |
+
SOURCES = ['OA_rolemodel',
|
3 |
+
# 'OA_finetuned',
|
4 |
+
"CTL_llama2"]
|
5 |
SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
|
6 |
+
"OA_finetuned":'Finetuned OpenAI',
|
7 |
+
"CTL_llama2": "Custom CTL"
|
8 |
+
}
|
9 |
|
10 |
def source2label(source):
|
11 |
return SOURCES_LAB[source]
|
12 |
|
13 |
+
ENVIRON = "dev"
|
14 |
|
15 |
DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
|
16 |
DB_CONVOS = 'conversations'
|
app_utils.py
CHANGED
@@ -7,6 +7,7 @@ from langchain.memory import ConversationBufferMemory
|
|
7 |
from app_config import ENVIRON
|
8 |
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
|
9 |
from models.openai.role_models import get_role_chain, role_templates
|
|
|
10 |
from mongo_utils import new_convo
|
11 |
|
12 |
langchain.verbose = ENVIRON=="dev"
|
@@ -35,7 +36,7 @@ def change_memories(memories, username, language, changed_source=False):
|
|
35 |
if (memory not in st.session_state) or changed_source:
|
36 |
source = params['source']
|
37 |
logger.info(f"Source for memory {memory} is {source}")
|
38 |
-
if source in ('OA_rolemodel','OA_finetuned'):
|
39 |
st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
|
40 |
|
41 |
if ("convo_id" in st.session_state) and changed_source:
|
@@ -63,4 +64,11 @@ def get_chain(issue, language, source, memory, temperature):
|
|
63 |
return get_finetuned_chain(OA_engine, memory, temperature)
|
64 |
elif source in ('OA_rolemodel'):
|
65 |
template = role_templates[f"{issue}-{language}"]
|
66 |
-
return get_role_chain(template, memory, temperature)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
from app_config import ENVIRON
|
8 |
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
|
9 |
from models.openai.role_models import get_role_chain, role_templates
|
10 |
+
from models.databricks.scenario_sim_biz import get_databricks_chain
|
11 |
from mongo_utils import new_convo
|
12 |
|
13 |
langchain.verbose = ENVIRON=="dev"
|
|
|
36 |
if (memory not in st.session_state) or changed_source:
|
37 |
source = params['source']
|
38 |
logger.info(f"Source for memory {memory} is {source}")
|
39 |
+
if source in ('OA_rolemodel','OA_finetuned',"CTL_llama2"):
|
40 |
st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
|
41 |
|
42 |
if ("convo_id" in st.session_state) and changed_source:
|
|
|
64 |
return get_finetuned_chain(OA_engine, memory, temperature)
|
65 |
elif source in ('OA_rolemodel'):
|
66 |
template = role_templates[f"{issue}-{language}"]
|
67 |
+
return get_role_chain(template, memory, temperature)
|
68 |
+
elif source in ('CTL_llama2'):
|
69 |
+
if language == "English":
|
70 |
+
language = "en"
|
71 |
+
elif language == "Spanish":
|
72 |
+
language = "es"
|
73 |
+
return get_databricks_chain(issue, language, memory, temperature)
|
74 |
+
|
models/databricks/scenario_sim_biz.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import requests
|
4 |
+
import logging
|
5 |
+
from models.custom_parsers import CustomStringOutputParser
|
6 |
+
from langchain.chains import ConversationChain
|
7 |
+
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
8 |
+
from langchain_core.language_models.llms import LLM
|
9 |
+
from langchain.prompts import PromptTemplate
|
10 |
+
|
11 |
+
from typing import Any, List, Mapping, Optional, Dict
|
12 |
+
|
13 |
+
class DatabricksCustomLLM(LLM):
|
14 |
+
issue:str
|
15 |
+
language:str
|
16 |
+
temperature:float = 0.8
|
17 |
+
db_url:str = os.environ['DATABRICKS_URL']
|
18 |
+
headers:Mapping[str,str] = {'Authorization': f'Bearer {os.environ.get("DATABRICKS_TOKEN")}', 'Content-Type': 'application/json'}
|
19 |
+
|
20 |
+
@property
|
21 |
+
def _llm_type(self) -> str:
|
22 |
+
return "custom_databricks"
|
23 |
+
|
24 |
+
def _call(
|
25 |
+
self,
|
26 |
+
prompt: str,
|
27 |
+
stop: Optional[List[str]] = None,
|
28 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
29 |
+
**kwargs: Any,
|
30 |
+
) -> str:
|
31 |
+
data_ = {'inputs': {
|
32 |
+
'prompt': [prompt],
|
33 |
+
'issue': [self.issue],
|
34 |
+
'language': [self.language],
|
35 |
+
'temperature': [self.temperature]
|
36 |
+
}}
|
37 |
+
data_json = json.dumps(data_, allow_nan=True)
|
38 |
+
response = requests.request(method='POST', headers=self.headers, url=self.db_url, data=data_json)
|
39 |
+
|
40 |
+
if response.status_code != 200:
|
41 |
+
raise Exception(f'Request failed with status {response.status_code}, {response.text}')
|
42 |
+
return response.json()["predictions"][0]["generated_text"]
|
43 |
+
|
44 |
+
_DATABRICKS_TEMPLATE_ = """{history}
|
45 |
+
helper: {input}
|
46 |
+
texter:"""
|
47 |
+
|
48 |
+
def get_databricks_chain(issue, language, memory, temperature=0.8):
|
49 |
+
|
50 |
+
PROMPT = PromptTemplate(
|
51 |
+
input_variables=['history', 'input'],
|
52 |
+
template=_DATABRICKS_TEMPLATE_
|
53 |
+
)
|
54 |
+
llm = DatabricksCustomLLM(
|
55 |
+
issue=issue,
|
56 |
+
language=language,
|
57 |
+
temperature=temperature
|
58 |
+
)
|
59 |
+
llm_chain = ConversationChain(
|
60 |
+
llm=llm,
|
61 |
+
prompt=PROMPT,
|
62 |
+
memory=memory,
|
63 |
+
output_parser=CustomStringOutputParser()
|
64 |
+
)
|
65 |
+
logging.debug(f"loaded Databricks Scenario Sim model")
|
66 |
+
return llm_chain, "helper:"
|