import logging
import os
import re
from typing import Any, Dict, List, Mapping, Optional

from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate

from app_config import ENDPOINT_NAMES, SOURCES
from models.custom_parsers import CustomStringOutputParser
from models.databricks.custom_databricks_llm import CustomDatabricksLLM
from utils.app_utils import get_random_name

logger = logging.getLogger(__name__)

# Prompt layout: full conversation history, then the helper's latest
# message; the model is expected to complete the texter's reply.
_DATABRICKS_TEMPLATE_ = """{history}
helper: {input}
texter:"""


def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
    """Build a ConversationChain backed by a Databricks serving endpoint.

    Args:
        source: Key into ``ENDPOINT_NAMES`` selecting the serving endpoint;
            falls back to ``"texter_simulator"`` when the key is unknown.
        issue: Crisis issue descriptor forwarded to the LLM wrapper.
        language: Conversation language forwarded to the LLM wrapper.
        memory: LangChain memory object supplying the ``{history}`` variable.
        temperature: Sampling temperature for generation.
        texter_name: Persona name for the simulated texter.

    Returns:
        Tuple ``(ConversationChain, None)``; the trailing ``None`` keeps the
        return shape uniform with sibling chain factories.

    Raises:
        KeyError: If ``DATABRICKS_URL`` or ``DATABRICKS_TOKEN`` is missing
            from the environment.
    """
    endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")

    prompt = PromptTemplate(
        input_variables=["history", "input"],
        template=_DATABRICKS_TEMPLATE_,
    )

    llm = CustomDatabricksLLM(
        # DATABRICKS_URL is a format string with an {endpoint_name} slot,
        # e.g. "https://<host>/serving-endpoints/{endpoint_name}/invocations".
        endpoint_url=os.environ["DATABRICKS_URL"].format(endpoint_name=endpoint_name),
        bearer_token=os.environ["DATABRICKS_TOKEN"],
        texter_name=texter_name,
        issue=issue,
        language=language,
        temperature=temperature,
        max_tokens=256,
    )

    llm_chain = ConversationChain(
        llm=llm,
        prompt=prompt,
        memory=memory,
        output_parser=CustomStringOutputParser(),
        verbose=True,
    )

    logger.debug("Loaded Databricks model")
    return llm_chain, None