license: mit
EndpointHandler
EndpointHandler is a Python class that processes image and text data to generate embeddings and similarity scores using the ColQwen2 model, a visual retriever based on Qwen2-VL-2B-Instruct that follows the ColBERT multi-vector strategy. It is designed for retrieving documents by both their visual and textual content.
Overview
- Efficient Document Retrieval: Uses the ColQwen2 model to embed images and text for accurate document retrieval.
- Multi-vector Representation: Generates ColBERT-style multi-vector embeddings for improved similarity search.
- Flexible Image Resolution: Supports dynamic image resolution without altering the aspect ratio, capped at 768 patches for memory efficiency.
- Device Compatibility: Uses a CUDA device when one is available and falls back to CPU otherwise.
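The 768-patch cap can be illustrated with Qwen2-VL-style resizing. The sketch below assumes each visual token covers a 28 × 28 pixel area (14-pixel patches merged 2 × 2) and that the cap is enforced as `max_pixels = 768 * 28 * 28`. The function name `fit_to_patch_budget` and the `MIN_PIXELS` value are illustrative assumptions, not taken from the handler. Both sides are snapped to multiples of 28 and, if the page is too large, scaled down by a common factor so the aspect ratio is approximately preserved.

```python
import math
from typing import Tuple

PATCH = 28                       # assumed pixels per visual-token side (14 px patch, 2x2 merge)
MAX_PIXELS = 768 * PATCH * PATCH  # assumed form of the 768-patch cap
MIN_PIXELS = 4 * PATCH * PATCH    # assumed lower bound (illustrative)


def fit_to_patch_budget(height: int, width: int,
                        factor: int = PATCH,
                        min_pixels: int = MIN_PIXELS,
                        max_pixels: int = MAX_PIXELS) -> Tuple[int, int]:
    """Return (height, width) snapped to multiples of `factor` within the pixel budget."""
    # Snap each side to the nearest multiple of the patch size.
    h_bar = max(factor, round(height / factor) * factor)
    w_bar = max(factor, round(width / factor) * factor)
    if h_bar * w_bar > max_pixels:
        # Too large: shrink both sides by the same factor, rounding down to stay under budget.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Too small: enlarge both sides by the same factor, rounding up.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


# A US Letter page rendered at 200 DPI is 1700 x 2200 pixels (width x height).
h, w = fit_to_patch_budget(2200, 1700)
num_tokens = (h // PATCH) * (w // PATCH)
print(h, w, num_tokens)
```

Scaling both sides by the same `beta` is what keeps the page's proportions intact while bounding the number of visual tokens, and therefore memory use.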
Model Details
The ColQwen2 model extends Qwen2-VL-2B with a focus on vision-language tasks, making it suitable for content indexing and retrieval. Key features include:
- Training: Trained with a batch size of 256 for 5 epochs, with a modified pad token.
- Input Flexibility: Handles varying image resolutions without distorting the aspect ratio, producing a multi-vector representation for each image.
- Similarity Scoring: Utilizes a ColBERT-style scoring approach for efficient retrieval across image and text modalities.
This base version is untrained: it guarantees a deterministic initialization of the projection layer and is intended as a starting point for further fine-tuning.
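ColBERT-style scoring (late interaction, or MaxSim) compares every query-token embedding against every document-token embedding, keeps the best match per query token, and sums those maxima. The pure-Python sketch below shows that computation on small hand-made vectors; the function names are hypothetical, and a real implementation would operate on batched tensors of L2-normalized embeddings rather than nested lists.

```python
from typing import List

Vector = List[float]
MultiVector = List[Vector]  # one vector per token (query token or image patch)


def dot(a: Vector, b: Vector) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def maxsim(query: MultiVector, doc: MultiVector) -> float:
    """Sum over query tokens of the best dot product against any document token."""
    return sum(max(dot(q, d) for d in doc) for q in query)


def score_matrix(queries: List[MultiVector], docs: List[MultiVector]) -> List[List[float]]:
    """Return scores with one row per query and one column per document."""
    return [[maxsim(q, d) for d in docs] for q in queries]


query = [[1.0, 0.0], [0.0, 1.0]]   # two query-token embeddings
page_a = [[1.0, 0.0], [0.5, 0.5]]  # two patch embeddings
page_b = [[0.0, 1.0]]              # one patch embedding
print(score_matrix([query], [page_a, page_b]))  # [[1.5, 1.0]]
```

Because each query token only needs its single best match, a page can score highly even if the relevant content occupies a small region of it.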
How to Use
The following example shows how to send PDF documents and text to a deployed EndpointHandler over HTTP. Each PDF page is rendered as a base64-encoded PNG image, and the images are sent alongside the text queries in a single JSON payload.
Example Script
```python
import base64
from io import BytesIO

import requests
from pdf2image import convert_from_path


def pil_image_to_base64(image):
    """Convert a PIL Image to a base64-encoded PNG string."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def convert_pdf_to_base64_images(pdf_path):
    """Render each PDF page to an image and return the pages as base64 strings."""
    pages = convert_from_path(pdf_path)  # pdf2image requires poppler to be installed
    return [pil_image_to_base64(page) for page in pages]


def query_api(payload, api_url, headers):
    """Send a POST request to the endpoint and return the decoded JSON response."""
    response = requests.post(api_url, headers=headers, json=payload, timeout=300)
    response.raise_for_status()  # surface HTTP errors instead of parsing an error page
    return response.json()


if __name__ == "__main__":
    # Convert PDF pages to base64-encoded images
    encoded_images = convert_pdf_to_base64_images("document.pdf")

    # Prepare the payload
    payload = {
        "inputs": [],
        "image": encoded_images,
        "text": ["example query text"],
    }

    # API configuration (replace with your endpoint URL and access token)
    API_URL = "https://your-api-url"
    headers = {
        "Accept": "application/json",
        "Authorization": "Bearer your_access_token",
        "Content-Type": "application/json",
    }

    # Query the API and print the output
    output = query_api(payload=payload, api_url=API_URL, headers=headers)
    print(output)
```
Inputs and Outputs
Input Format
The EndpointHandler expects a dictionary containing:
- image: A list of base64-encoded strings for images (e.g., PDF pages converted to images).
- text: A list of text strings representing queries or document contents.
- batch_size (optional): The batch size for processing images and text. Defaults to 4.
Example payload:
```json
{
  "image": ["base64_image_string_1", "base64_image_string_2"],
  "text": ["sample text 1", "sample text 2"],
  "batch_size": 4
}
```
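A minimal sketch of how a handler might read this payload and split the inputs into batches. It assumes that a missing `image` or `text` key is treated as an empty list; `parse_payload` and `batched` are hypothetical helper names, not functions from the handler itself.

```python
from typing import Any, Dict, Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")
DEFAULT_BATCH_SIZE = 4  # matches the documented default


def parse_payload(data: Dict[str, Any]) -> Tuple[List[str], List[str], int]:
    """Extract images, texts, and batch size from a request payload (hypothetical helper)."""
    images = data.get("image", [])  # assumption: missing key means no images
    texts = data.get("text", [])    # assumption: missing key means no texts
    batch_size = int(data.get("batch_size", DEFAULT_BATCH_SIZE))
    for name, values in (("image", images), ("text", texts)):
        if not isinstance(values, list) or not all(isinstance(v, str) for v in values):
            raise TypeError(f"'{name}' must be a list of strings")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    return images, texts, batch_size


def batched(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most `batch_size` items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


images, texts, batch_size = parse_payload(
    {"image": ["img1", "img2", "img3", "img4", "img5"], "text": ["sample text 1"]}
)
for batch in batched(images, batch_size):
    print(len(batch))  # 4, then 1
```

Smaller batches lower peak GPU memory at the cost of more forward passes, which is why `batch_size` is exposed as a request parameter.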
Output Format
The handler returns a dictionary with the following keys:
- image: A list of embeddings, one entry per input image.
- text: A list of embeddings, one entry per input text.
- scores: A list of lists (a matrix) of similarity scores between the text and image embeddings.
Example output:
```json
{
  "image": [[0.12, 0.34, ...], [0.56, 0.78, ...]],
  "text": [[0.11, 0.22, ...], [0.33, 0.44, ...]],
  "scores": [[0.87, 0.45], [0.23, 0.67]]
}
```
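To turn the `scores` matrix into a retrieval result, you can rank the images for each query. The sketch below assumes that each row of `scores` corresponds to one text query and each column to one image (page); confirm the handler's actual orientation before relying on it. `rank_pages` is a hypothetical helper.

```python
from typing import List


def rank_pages(scores: List[List[float]], k: int = 0) -> List[List[int]]:
    """For each query row, return image indices by descending score (top-k if k > 0)."""
    ranked = []
    for row in scores:
        order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        ranked.append(order[:k] if k > 0 else order)
    return ranked


scores = [[0.87, 0.45], [0.23, 0.67]]  # the example output above
print(rank_pages(scores))       # [[0, 1], [1, 0]]
print(rank_pages(scores, k=1))  # [[0], [1]]
```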
Error Handling
If any issues occur during processing (e.g., decoding images or model inference), the handler logs the error and returns an error message in the output dictionary.
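The log-and-return pattern described above can be sketched as follows. The `"error"` key name and the `safe_call` wrapper are assumptions for illustration; inspect a real error response from your endpoint to confirm which key it uses.

```python
import logging
from typing import Any, Callable, Dict

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("endpoint_handler")


def safe_call(process: Callable[[Dict[str, Any]], Dict[str, Any]],
              data: Dict[str, Any]) -> Dict[str, Any]:
    """Run `process` on `data`; on failure, log the traceback and return an error dict."""
    try:
        return process(data)
    except Exception as exc:  # broad catch so the endpoint always returns JSON
        logger.exception("Failed to process request")
        return {"error": str(exc)}  # hypothetical key name


def failing_process(data: Dict[str, Any]) -> Dict[str, Any]:
    # Stand-in for image decoding or model inference that raises.
    raise ValueError("could not decode image 0")


output = safe_call(failing_process, {"image": ["not-base64"], "text": ["query"]})
if "error" in output:
    print("Endpoint reported an error:", output["error"])
```

On the client side, checking for the error key before reading `scores` avoids a confusing `KeyError` when the request failed.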