"""Langchain Wrapper around Sambanova LLM APIs.""" | |
import json | |
from typing import Any, Dict, Generator, Iterator, List, Optional, Union | |
import requests | |
from langchain_core.callbacks.manager import CallbackManagerForLLMRun | |
from langchain_core.language_models.llms import LLM | |
from langchain_core.outputs import GenerationChunk | |
from langchain_core.pydantic_v1 import Extra | |
from langchain_core.utils import get_from_dict_or_env, pre_init | |
from langchain_core.runnables import RunnableConfig, ensure_config | |
from langchain_core.language_models.base import ( | |
LanguageModelInput, | |
) | |
from toolformers.sambanova.utils import append_to_usage_tracker | |


class SSEndpointHandler:
    """
    SambaNova Systems interface for SambaStudio model endpoints.

    :param str host_url: Base URL of the DaaS API service
    """

    def __init__(self, host_url: str, api_base_uri: str):
        """
        Initialize the SSEndpointHandler.

        :param str host_url: Base URL of the DaaS API service
        :param str api_base_uri: Base URI of the DaaS API service
        """
        self.host_url = host_url
        self.api_base_uri = api_base_uri
        self.http_session = requests.Session()

    def _process_response(self, response: requests.Response) -> Dict:
        """
        Process the API response and return the resulting dict.

        All resulting dicts, regardless of success or failure, will contain the
        `status_code` key with the API response status code.
        If the API returned an error, the resulting dict will contain the key
        `detail` with the error message.
        If the API call was successful, the resulting dict will contain the key
        `data` with the response data.

        :param requests.Response response: the response object to process
        :return: the response dict
        :type: dict
        """
        result: Dict[str, Any] = {}
        try:
            result = response.json()
        except Exception as e:
            result['detail'] = str(e)
        if 'status_code' not in result:
            result['status_code'] = response.status_code
        return result
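
    # Illustrative only (example values, not from the API docs): a processed response dict
    # might look like {'status_code': 200, 'predictions': [{'completion': '...'}]} on success,
    # or {'status_code': 404, 'detail': 'Expecting value: line 1 column 1 (char 0)'} on failure.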
    def _process_streaming_response(
        self,
        response: requests.Response,
    ) -> Generator[Dict, None, None]:
        """Process the streaming response."""
        if 'api/predict/nlp' in self.api_base_uri:
            try:
                import sseclient
            except ImportError:
                raise ImportError(
                    'Could not import sseclient library. '
                    'Please install it with `pip install sseclient-py`.'
                )
            client = sseclient.SSEClient(response)
            close_conn = False
            for event in client.events():
                if event.event == 'error_event':
                    close_conn = True
                chunk = {
                    'event': event.event,
                    'data': event.data,
                    'status_code': response.status_code,
                }
                yield chunk
            if close_conn:
                client.close()
        elif 'api/v2/predict/generic' in self.api_base_uri or 'api/predict/generic' in self.api_base_uri:
            try:
                for line in response.iter_lines():
                    chunk = json.loads(line)
                    if 'status_code' not in chunk:
                        chunk['status_code'] = response.status_code
                    yield chunk
            except Exception as e:
                raise RuntimeError(f'Error processing streaming response: {e}')
        else:
            raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')

    def _get_full_url(self, path: str) -> str:
        """
        Return the full API URL for a given path.

        :param str path: the sub-path
        :returns: the full API URL for the sub-path
        :type: str
        """
        return f'{self.host_url}/{self.api_base_uri}/{path}'
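
    # Example with hypothetical values: host_url='https://mysambastudio.example.com',
    # api_base_uri='api/predict/generic', path='my-project-id/my-endpoint-id'
    # -> 'https://mysambastudio.example.com/api/predict/generic/my-project-id/my-endpoint-id'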
    def nlp_predict(
        self,
        project: str,
        endpoint: str,
        key: str,
        input: Union[List[str], str],
        params: Optional[str] = '',
        stream: bool = False,
    ) -> Dict:
        """
        NLP predict using inline input string.

        :param str project: Project ID in which the endpoint exists
        :param str endpoint: Endpoint ID
        :param str key: API Key
        :param input: Input string (or list of input strings)
        :param str params: Input params string
        :returns: Prediction results
        :type: dict
        """
        if isinstance(input, str):
            input = [input]
        if 'api/predict/nlp' in self.api_base_uri:
            if params:
                data = {'inputs': input, 'params': json.loads(params)}
            else:
                data = {'inputs': input}
        elif 'api/v2/predict/generic' in self.api_base_uri:
            items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
            if params:
                data = {'items': items, 'params': json.loads(params)}
            else:
                data = {'items': items}
        elif 'api/predict/generic' in self.api_base_uri:
            if params:
                data = {'instances': input, 'params': json.loads(params)}
            else:
                data = {'instances': input}
        else:
            raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
        response = self.http_session.post(
            self._get_full_url(f'{project}/{endpoint}'),
            headers={'key': key},
            json=data,
        )
        return self._process_response(response)
    def nlp_predict_stream(
        self,
        project: str,
        endpoint: str,
        key: str,
        input: Union[List[str], str],
        params: Optional[str] = '',
    ) -> Iterator[Dict]:
        """
        NLP predict using inline input string, with a streaming response.

        :param str project: Project ID in which the endpoint exists
        :param str endpoint: Endpoint ID
        :param str key: API Key
        :param input: Input string (or list of input strings)
        :param str params: Input params string
        :returns: An iterator of response chunk dicts
        """
        if 'api/predict/nlp' in self.api_base_uri:
            if isinstance(input, str):
                input = [input]
            if params:
                data = {'inputs': input, 'params': json.loads(params)}
            else:
                data = {'inputs': input}
        elif 'api/v2/predict/generic' in self.api_base_uri:
            if isinstance(input, str):
                input = [input]
            items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
            if params:
                data = {'items': items, 'params': json.loads(params)}
            else:
                data = {'items': items}
        elif 'api/predict/generic' in self.api_base_uri:
            if isinstance(input, list):
                input = input[0]
            if params:
                data = {'instance': input, 'params': json.loads(params)}
            else:
                data = {'instance': input}
        else:
            raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
        # Streaming output
        response = self.http_session.post(
            self._get_full_url(f'stream/{project}/{endpoint}'),
            headers={'key': key},
            json=data,
            stream=True,
        )
        for chunk in self._process_streaming_response(response):
            yield chunk


class SambaStudio(LLM):
    """
    SambaStudio large language models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
    ``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI.
    ``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
    ``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
    ``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.

    https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite

    Read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html

    Example:
    .. code-block:: python

        from langchain_community.llms.sambanova import SambaStudio

        SambaStudio(
            sambastudio_base_url="your-SambaStudio-environment-URL",
            sambastudio_base_uri="your-SambaStudio-base-URI",
            sambastudio_project_id="your-SambaStudio-project-ID",
            sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
            sambastudio_api_key="your-SambaStudio-endpoint-API-key",
            streaming=False,
            model_kwargs={
                "do_sample": False,
                "max_tokens_to_generate": 1000,
                "temperature": 0.7,
                "top_p": 1.0,
                "repetition_penalty": 1,
                "top_k": 50,
                # "process_prompt": False,
                # "select_expert": "Meta-Llama-3-8B-Instruct"
            },
        )
    """

    sambastudio_base_url: str = ''
    """Base url to use"""

    sambastudio_base_uri: str = ''
    """endpoint base uri"""

    sambastudio_project_id: str = ''
    """Project id on sambastudio for model"""

    sambastudio_endpoint_id: str = ''
    """endpoint id on sambastudio for model"""

    sambastudio_api_key: str = ''
    """sambastudio api key"""

    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model."""

    streaming: Optional[bool] = False
    """Streaming flag to get streamed response."""

    class Config:
        """Configuration for this pydantic object."""

        extra = 'forbid'

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{'model_kwargs': self.model_kwargs}}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return 'Sambastudio LLM'
    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values['sambastudio_base_url'] = get_from_dict_or_env(values, 'sambastudio_base_url', 'SAMBASTUDIO_BASE_URL')
        values['sambastudio_base_uri'] = get_from_dict_or_env(
            values,
            'sambastudio_base_uri',
            'SAMBASTUDIO_BASE_URI',
            default='api/predict/generic',
        )
        values['sambastudio_project_id'] = get_from_dict_or_env(
            values, 'sambastudio_project_id', 'SAMBASTUDIO_PROJECT_ID'
        )
        values['sambastudio_endpoint_id'] = get_from_dict_or_env(
            values, 'sambastudio_endpoint_id', 'SAMBASTUDIO_ENDPOINT_ID'
        )
        values['sambastudio_api_key'] = get_from_dict_or_env(values, 'sambastudio_api_key', 'SAMBASTUDIO_API_KEY')
        return values
    def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
        """
        Get the tuning parameters to use when calling the LLM.

        Args:
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.

        Returns:
            The tuning parameters as a JSON string.
        """
        _model_kwargs = self.model_kwargs or {}
        _kwarg_stop_sequences = _model_kwargs.get('stop_sequences', [])
        _stop_sequences = stop or _kwarg_stop_sequences
        # if not _kwarg_stop_sequences:
        #     _model_kwargs["stop_sequences"] = ",".join(f'"{x}"' for x in _stop_sequences)
        if 'api/v2/predict/generic' in self.sambastudio_base_uri:
            tuning_params_dict = _model_kwargs
        else:
            tuning_params_dict = {k: {'type': type(v).__name__, 'value': str(v)} for k, v in _model_kwargs.items()}
        # _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
        tuning_params = json.dumps(tuning_params_dict)
        return tuning_params
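
    # For illustration (assumed kwargs): with model_kwargs={'max_tokens_to_generate': 1000, 'temperature': 0.7}
    # a v1 generic endpoint receives the typed form
    #   '{"max_tokens_to_generate": {"type": "int", "value": "1000"}, "temperature": {"type": "float", "value": "0.7"}}'
    # while an api/v2/predict/generic endpoint receives the kwargs serialized as-is.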
    def _handle_nlp_predict(self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str) -> str:
        """
        Perform an NLP prediction using the SambaStudio endpoint handler.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            The prediction result.

        Raises:
            ValueError: If the prediction fails.
        """
        response = sdk.nlp_predict(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        )
        if response['status_code'] != 200:
            optional_detail = response.get('detail')
            if optional_detail:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}.\n Details: {optional_detail}"
                )
            else:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}.\n response {response}"
                )
        if 'api/predict/nlp' in self.sambastudio_base_uri:
            return response['data'][0]['completion']
        elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
            return response['items'][0]['value']['completion']
        elif 'api/predict/generic' in self.sambastudio_base_uri:
            return response['predictions'][0]['completion']
        else:
            raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} not implemented')
    def _handle_completion_requests(self, prompt: Union[List[str], str], stop: Optional[List[str]]) -> str:
        """
        Perform a prediction using the SambaStudio endpoint handler.

        Args:
            prompt: The prompt to use for the prediction.
            stop: stop sequences.

        Returns:
            The prediction result.

        Raises:
            ValueError: If the prediction fails.
        """
        ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
        tuning_params = self._get_tuning_params(stop)
        return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
    def _handle_nlp_predict_stream(
        self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
    ) -> Iterator[GenerationChunk]:
        """
        Perform a streaming request to the LLM.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            An iterator of GenerationChunks.
        """
        for chunk in sdk.nlp_predict_stream(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        ):
            if chunk['status_code'] != 200:
                error = chunk.get('error')
                if error:
                    optional_code = error.get('code')
                    optional_details = error.get('details')
                    optional_message = error.get('message')
                    raise ValueError(
                        f"Sambanova /complete call failed with status code "
                        f"{chunk['status_code']}.\n"
                        f"Message: {optional_message}\n"
                        f"Details: {optional_details}\n"
                        f"Code: {optional_code}\n"
                    )
                else:
                    raise RuntimeError(
                        f"Sambanova /complete call failed with status code "
                        f"{chunk['status_code']}. {chunk}."
                    )
            if 'api/predict/nlp' in self.sambastudio_base_uri:
                text = json.loads(chunk['data'])['stream_token']
            elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
                text = chunk['result']['items'][0]['value']['stream_token']
            elif 'api/predict/generic' in self.sambastudio_base_uri:
                if len(chunk['result']['responses']) > 0:
                    text = chunk['result']['responses'][0]['stream_token']
                else:
                    text = ''
            else:
                raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} not implemented')
            generated_chunk = GenerationChunk(text=text)
            yield generated_chunk
    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Yields:
            GenerationChunks with the text streamed by the model.
        """
        ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
        tuning_params = self._get_tuning_params(stop)
        try:
            if self.streaming:
                for chunk in self._handle_nlp_predict_stream(ss_endpoint, prompt, tuning_params):
                    if run_manager:
                        run_manager.on_llm_new_token(chunk.text)
                    yield chunk
            else:
                return
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e
    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments passed directly
                to the sambastudio model in the API call.

        Returns:
            The model output as a string.
        """
        completion = ''
        for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
            completion += chunk.text
        return completion
    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        if stop is not None:
            raise Exception('stop not implemented')
        try:
            if self.streaming:
                return self._handle_stream_request(prompt, stop, run_manager, kwargs)
            return self._handle_completion_requests(prompt, stop)
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e
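

# A minimal usage sketch for SambaStudio (illustrative values only): it assumes the
# SAMBASTUDIO_* environment variables documented above are set and an endpoint is deployed.
#
#     llm = SambaStudio(
#         streaming=False,
#         model_kwargs={'max_tokens_to_generate': 256, 'temperature': 0.7},
#     )
#     print(llm.invoke('Tell me about SambaNova in one sentence'))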


class SambaNovaCloud(LLM):
    """
    SambaNova Cloud large language models.

    To use, you should have the environment variables
    ``SAMBANOVA_URL`` set with your SambaNova Cloud URL.
    ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key.

    http://cloud.sambanova.ai/

    Example:
    .. code-block:: python

        SambaNovaCloud(
            sambanova_url="SambaNova Cloud endpoint URL",
            sambanova_api_key="your SambaNova Cloud API key",
            max_tokens=1024,  # max number of tokens to generate
            stop_tokens=["<|eot_id|>"],  # list of stop tokens
            model="llama3-8b",  # model name
        )
    """

    sambanova_url: str = ''
    """SambaNova Cloud Url"""

    sambanova_api_key: str = ''
    """SambaNova Cloud api key"""

    max_tokens: int = 4000
    """max tokens to generate"""

    stop_tokens: list = ['<|eot_id|>']
    """Stop tokens"""

    model: str = 'llama3-8b'
    """LLM model expert to use"""

    temperature: float = 0.0
    """model temperature"""

    top_p: float = 0.0
    """model top p"""

    top_k: int = 1
    """model top k"""

    stream_api: bool = True
    """use stream api"""

    stream_options: dict = {'include_usage': True}
    """stream options, include usage to get generation metrics"""

    class Config:
        """Configuration for this pydantic object."""

        extra = 'forbid'

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            'model': self.model,
            'max_tokens': self.max_tokens,
            'stop': self.stop_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'top_k': self.top_k,
        }
    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        config = ensure_config(config)
        print('Invoking SambaNovaCloud with input:', input)
        response = self.generate_prompt(
            [self._convert_input(input)],
            stop=stop,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            run_id=config.pop("run_id", None),
            **kwargs,
        )
        run_infos = response.run
        if len(run_infos) > 1:
            raise NotImplementedError('Multiple runs not supported')
        run_id = run_infos[0].run_id
        # print('Raw response:', response.run)
        # print('Run ID:', run_id)
        # print(USAGE_TRACKER)
        # if run_id in USAGE_TRACKER:
        #     print('Usage:', USAGE_TRACKER[run_id])
        # return response
        return response.generations[0][0].text
    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return 'SambaNova Cloud'

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values['sambanova_url'] = get_from_dict_or_env(
            values, 'sambanova_url', 'SAMBANOVA_URL', default='https://api.sambanova.ai/v1/chat/completions'
        )
        values['sambanova_api_key'] = get_from_dict_or_env(values, 'sambanova_api_key', 'SAMBANOVA_API_KEY')
        return values
    def _handle_nlp_predict_stream(
        self,
        prompt: Union[List[str], str],
        stop: List[str],
    ) -> Iterator[GenerationChunk]:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to use for the prediction.
            stop: list of stop tokens

        Returns:
            An iterator of GenerationChunks.
        """
        try:
            import sseclient
        except ImportError:
            raise ImportError(
                'Could not import sseclient library. '
                'Please install it with `pip install sseclient-py`.'
            )
        try:
            formatted_prompt = json.loads(prompt)
        except Exception:
            formatted_prompt = [{'role': 'user', 'content': prompt}]
        http_session = requests.Session()
        if not stop:
            stop = self.stop_tokens
        data = {
            'messages': formatted_prompt,
            'max_tokens': self.max_tokens,
            'stop': stop,
            'model': self.model,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'top_k': self.top_k,
            'stream': self.stream_api,
            'stream_options': self.stream_options,
        }
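        # Illustrative request body built above, using the class defaults (example prompt assumed):
        #   {"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 4000,
        #    "stop": ["<|eot_id|>"], "model": "llama3-8b", "temperature": 0.0,
        #    "top_p": 0.0, "top_k": 1, "stream": true, "stream_options": {"include_usage": true}}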
        # Streaming output
        response = http_session.post(
            self.sambanova_url,
            headers={'Authorization': f'Bearer {self.sambanova_api_key}', 'Content-Type': 'application/json'},
            json=data,
            stream=True,
        )
        client = sseclient.SSEClient(response)
        close_conn = False
        print('Response:', response)
        if response.status_code != 200:
            raise RuntimeError(
                f'Sambanova /complete call failed with status code {response.status_code}. {response.text}.'
            )
        for event in client.events():
            if event.event == 'error_event':
                close_conn = True
            # print('Event:', event.data)
            chunk = {
                'event': event.event,
                'data': event.data,
                'status_code': response.status_code,
            }
            if chunk.get('error'):
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code {chunk['status_code']}. {chunk}."
                )
            try:
                # check if the response is a final event; in that case the event data is '[DONE]'
                # if 'usage' in chunk['data']:
                #     usage = json.loads(chunk['data'])
                #     print('Usage:', usage)
                if chunk['data'] != '[DONE]':
                    data = json.loads(chunk['data'])
                    if data.get('error'):
                        raise RuntimeError(
                            f"Sambanova /complete call failed with status code {chunk['status_code']}. {chunk}."
                        )
                    # check if the response is a final response with usage stats (it does not include content)
                    if data.get('usage') is None:
                        # check it is not an "end of text" response
                        if data['choices'][0]['finish_reason'] is None:
                            text = data['choices'][0]['delta']['content']
                            generated_chunk = GenerationChunk(text=text)
                            yield generated_chunk
                    else:
                        # if data['id'] not in USAGE_TRACKER:
                        #     USAGE_TRACKER[data['id']] = []
                        # USAGE_TRACKER[data['id']].append(data['usage'])
                        append_to_usage_tracker(data['usage'])
                        # print(f'Usage for id {data["id"]}:', data['usage'])
            except Exception as e:
                raise RuntimeError(f'Error getting content chunk from raw streamed response: {chunk}') from e
    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Yields:
            GenerationChunks with the text streamed by the model.
        """
        try:
            for chunk in self._handle_nlp_predict_stream(prompt, stop):
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)
                yield chunk
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e
    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments passed directly
                to the SambaNova Cloud model in the API call.

        Returns:
            The model output as a string.
        """
        completion = ''
        for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
            completion += chunk.text
        return completion
    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        try:
            return self._handle_stream_request(prompt, stop, run_manager, kwargs)
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e
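

# A minimal usage sketch for SambaNovaCloud (illustrative, not part of the library API);
# it assumes SAMBANOVA_API_KEY is set and the default chat-completions URL is reachable:
#
#     llm = SambaNovaCloud(model='llama3-8b', max_tokens=256)
#     print(llm.invoke('Say hello in one short sentence.'))
#
#     # Streaming variant via the LangChain Runnable interface:
#     for token in llm.stream('Say hello in one short sentence.'):
#         print(token, end='', flush=True)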