fdurant commited on
Commit
4c96de6
·
1 Parent(s): 4b3c9e4

Add handler.py, start_emulator.sh and test scripts

Browse files
embed_single_query.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+ set -x
3
+
4
+ curl \
5
+ --request POST \
6
+ --url http://localhost:4999 \
7
+ --header 'Content-Type: application/json' \
8
+ --data '{"inputs": "Please embed me"}' \
9
+ -w "\n"
embed_two_chunks.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+ set -x
3
+
4
+ curl \
5
+ --request POST \
6
+ --url http://localhost:4999 \
7
+ --header 'Content-Type: application/json' \
8
+ --data '{"inputs": ["Please embed me", "And me too, please!"]}' \
9
+ -w "\n"
handler.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ from colbert.infra import ColBERTConfig
4
+ from colbert.modeling.checkpoint import Checkpoint
5
+ import torch
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ MODEL = "fdurant/colbert-xm-for-inference-api"
11
+
12
+ class EndpointHandler():
13
+
14
+ def __init__(self, path=""):
15
+ self._config = ColBERTConfig(
16
+ # Defaults copied from https://github.com/datastax/ragstack-ai/blob/main/libs/colbert/ragstack_colbert/colbert_embedding_model.py
17
+ doc_maxlen=512, # Maximum number of tokens for document chunks. Should equal the chunk_size.
18
+ nbits=2, # The number bits that each dimension encodes to.
19
+ kmeans_niters=4, # Number of iterations for k-means clustering during quantization.
20
+ nranks=-1, # Number of ranks (processors) to use for distributed computing; -1 uses all available CPUs/GPUs.
21
+ checkpoint=MODEL,
22
+ )
23
+ self._checkpoint = Checkpoint(self._config.checkpoint, colbert_config=self._config, verbose=3)
24
+
25
+ def __call__(self, data: Any) -> List[Dict[str, Any]]:
26
+ inputs = data["inputs"]
27
+ texts = []
28
+ if isinstance(inputs, str):
29
+ texts = [inputs]
30
+ elif isinstance(inputs, list) and all(isinstance(text, str) for text in inputs):
31
+ texts = inputs
32
+ else:
33
+ raise ValueError("Invalid input data format")
34
+ with torch.inference_mode():
35
+
36
+ if len(texts) == 1:
37
+ # It's a query
38
+ logger.info(f"Query: {texts}")
39
+ embedding = self._checkpoint.queryFromText(
40
+ queries=texts,
41
+ full_length_search=False, # Indicates whether to encode the query for a full-length search.
42
+ )
43
+ logger.info(f"Query embedding shape: {embedding.shape}")
44
+ return [
45
+ {"input": inputs, "query_embedding": embedding.tolist()[0]}
46
+ ]
47
+ elif len(texts) > 1:
48
+ # It's a batch of chunks
49
+ logger.info(f"Batch of chunks: {texts}")
50
+ embeddings, token_counts = self._checkpoint.docFromText(
51
+ docs=texts,
52
+ bsize=self._config.bsize, # Batch size
53
+ keep_dims=True, # Do NOT flatten the embeddings
54
+ return_tokens=True, # Return the tokens as well
55
+ )
56
+ for text, embedding, token_count in zip(texts, embeddings, token_counts):
57
+ logger.info(f"Chunk: {text}")
58
+ logger.info(f"Chunk embedding shape: {embedding.shape}")
59
+ logger.info(f"Chunk count: {token_count}")
60
+ return [
61
+ {"input": _input, "chunk_embedding": embedding.tolist(), "token_count": token_count.tolist()}
62
+ for _input, embedding, token_count in zip(texts, embeddings, token_counts)
63
+ ]
64
+ else:
65
+ raise ValueError("No data to process")
start_emulator.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ #!/bin/bash -e
2
+ export SHELL=/bin/bash
3
+
4
+ hf-endpoints-emulator "$@"
test_endpoint.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+ import requests
4
+
5
+ URL = "http://localhost:4999/"
6
+ HEADERS = {"Content-Type": "application/json"}
7
+
8
+ def test_returns_200():
9
+ payload = {"inputs": "try me"}
10
+
11
+ response = requests.request("POST", URL, json=payload, headers=HEADERS)
12
+
13
+ assert response.status_code == 200
14
+
15
+ def test_query_returns_expected_result():
16
+ query = "try me"
17
+ payload = {"inputs": query}
18
+
19
+ response = requests.request("POST", URL, json=payload, headers=HEADERS)
20
+ response_data = response.json()
21
+
22
+ # print(response_data)
23
+
24
+ # Check structure and input
25
+ assert isinstance(response_data, list)
26
+ assert len(response_data) == 1
27
+ assert isinstance(response_data[0], dict)
28
+ assert response_data[0].get("input") == query
29
+
30
+ # Check query embedding (actually a list of embeddings, one per token in the query)
31
+ query_embedding = response_data[0].get("query_embedding")
32
+ assert isinstance(query_embedding, list)
33
+ assert len(query_embedding) == 32
34
+
35
+ # Check first of the token embeddings
36
+ first_token_embedding = query_embedding[0]
37
+ assert isinstance(first_token_embedding, list)
38
+ assert len(first_token_embedding) == 128
39
+ assert all(isinstance(value, float) for value in first_token_embedding)
40
+
41
+ def test_batch_returns_expected_result():
42
+ chunks = ["try me", "try me again and again and again"]
43
+ expected_token_counts = [11, 11] # Including start and stop tokens, I presume. Not exactly clear!
44
+ payload = {"inputs": chunks}
45
+
46
+ response = requests.request("POST", URL, json=payload, headers=HEADERS)
47
+ response_data = response.json()
48
+
49
+ # Check structure
50
+ assert isinstance(response_data, list)
51
+ assert len(response_data) == len(chunks)
52
+
53
+ for i, response_chunk in enumerate(response_data):
54
+ # Check input
55
+ assert response_chunk.get("input") == chunks[i]
56
+
57
+ # Check chunk embedding (actually a list of embeddings, one per token in the chunk)
58
+ chunk_embedding = response_chunk.get("chunk_embedding")
59
+ token_count = response_chunk.get("token_count")
60
+ assert isinstance(chunk_embedding, list)
61
+ assert len(chunk_embedding) == len(token_count)
62
+ assert len(token_count) == expected_token_counts[i]
63
+
64
+ # Check first of the token embeddings
65
+ first_token_embedding = chunk_embedding[0]
66
+ assert len(first_token_embedding) == 128
67
+ assert all(isinstance(value, float) for value in first_token_embedding)