vectorstore

#50
by philgrey - opened

Thanks for the innovative model.
Now I want to save documents to a vector store using this model, after deploying it to a SageMaker endpoint.
Does this model support embedding vector generation?
In other words, does
response['body'].read().decode("utf-8")
contain a "vectors" key?
(It seems that most Mistral models don't have this key.)

Mistral AI_ org

Could you provide a bit more context on how you are using the model? TGI? Inference endpoint? etc

import json
from typing import Dict, List

from langchain.llms.sagemaker_endpoint import LLMContentHandler

class ContentHandler(LLMContentHandler):
    """Convert between Python values and the JSON bytes exchanged with a
    SageMaker endpoint.

    ``transform_input`` builds the UTF-8 encoded JSON request body and
    ``transform_output`` decodes the JSON response body.
    """

    # Both the request body and the requested response are declared as JSON.
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
        """
        Transforms the input into bytes that can be consumed by SageMaker endpoint.
        Args:
            inputs: List of input strings, or a single prompt string.
            model_kwargs: Additional keyword arguments to be passed to the endpoint.
        Returns:
            The transformed bytes input: a UTF-8 encoded JSON object with the
            text under the "inputs" key and every entry of ``model_kwargs``
            merged in at the top level.
        """
        # A bare string is itself an iterable of characters; joining it would
        # put a space between every character, so pass it through unchanged.
        if isinstance(inputs, str):
            input_str = inputs
        else:
            input_str = " ".join(inputs)
        # Example: inference.py expects a JSON string with a "inputs" key:
        payload = json.dumps({"inputs": input_str, **model_kwargs})
        return payload.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        """
        Transforms the output from the endpoint into a Python object.
        Args:
            output: The response body from SageMaker endpoint, either raw
                bytes or a file-like object exposing ``read()``.
        Returns:
            The decoded JSON document, returned unchanged.
        Note:
            No key (such as "vectors") is extracted from the response, so the
            shape of the result is whatever the endpoint sends back. The
            annotated return type only holds if the endpoint replies with a
            bare list of embeddings -- TODO confirm against the deployed model.
        """
        # The endpoint response body is normally a stream; accept plain bytes
        # as well so the annotation and the implementation agree.
        raw = output.read() if hasattr(output, "read") else output
        return json.loads(raw.decode("utf-8"))

============================================================================================

from langchain.chains.question_answering import load_qa_chain
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint

# Handler that serialises requests for, and parses responses from, the endpoint.
content_handler = ContentHandler()

# LLM wrapper bound to the deployed SageMaker endpoint; the model_kwargs are
# handed to the content handler together with each input.
llms = SagemakerEndpoint(
    endpoint_name="huggingface-pytorch-tgi-inference-2023-12-18-06-04-44-513",
    region_name="eu-west-2",
    model_kwargs={
        "temperature": 0,
        "maxTokens": 1024,
        "numResults": 3,
    },
    content_handler=content_handler,
)

This comment has been hidden

Sign up or log in to comment