Spaces:
Sleeping
Sleeping
import os | |
import re | |
import logging | |
from models.custom_parsers import CustomStringOutputParser | |
from utils.app_utils import get_random_name | |
from app_config import ENDPOINT_NAMES, SOURCES | |
from models.databricks.custom_databricks_llm import CustomDatabricksLLM | |
from langchain.chains import ConversationChain | |
from langchain.prompts import PromptTemplate | |
from typing import Any, List, Mapping, Optional, Dict | |
_DATABRICKS_TEMPLATE_ = """{history} | |
helper: {input} | |
texter:""" | |
def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
    """Build a ``ConversationChain`` backed by a Databricks serving endpoint.

    Args:
        source: Key looked up in ``ENDPOINT_NAMES``; a key that is not present
            falls back to the ``"texter_simulator"`` endpoint name.
        issue: Passed through to ``CustomDatabricksLLM`` as ``issue``.
        language: Passed through to ``CustomDatabricksLLM`` as ``language``.
        memory: Passed through to ``ConversationChain`` as ``memory``.
        temperature: Passed through to ``CustomDatabricksLLM`` as
            ``temperature``.
        texter_name: Passed through to ``CustomDatabricksLLM`` as
            ``texter_name``.

    Returns:
        A ``(chain, None)`` tuple. The second element is always ``None``.

    Raises:
        KeyError: If ``DATABRICKS_URL`` or ``DATABRICKS_TOKEN`` is missing
            from the environment.
    """
    endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")

    prompt = PromptTemplate(
        input_variables=['history', 'input'],
        template=_DATABRICKS_TEMPLATE_,
    )

    # DATABRICKS_URL is treated as a format string and receives the resolved
    # endpoint name as the ``endpoint_name`` field.
    endpoint_url = os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name)

    llm = CustomDatabricksLLM(
        endpoint_url=endpoint_url,
        bearer_token=os.environ["DATABRICKS_TOKEN"],
        texter_name=texter_name,
        issue=issue,
        language=language,
        temperature=temperature,
        max_tokens=256,
    )

    llm_chain = ConversationChain(
        llm=llm,
        prompt=prompt,
        memory=memory,
        output_parser=CustomStringOutputParser(),
        verbose=True,
    )

    logging.debug("loaded Databricks model")
    return llm_chain, None