# NOTE: "Spaces: Running" -- page-status text captured together with this
# listing; it is not part of the module.
import os | |
import sys | |
from typing import List, Optional | |
from langchain.embeddings.base import Embeddings | |
from langchain_qdrant import Qdrant | |
from langchain.schema import Document | |
from langchain_openai import OpenAIEmbeddings | |
from qdrant_client import QdrantClient | |
from qdrant_client.models import Distance, VectorParams | |
import config as config | |
# Qdrant Collections Params
# A single embeddings client, built from config.EMBEDDINGS_MODEL_NAME, is
# shared by every collection entry below.
openai_embeddings = OpenAIEmbeddings(model=config.EMBEDDINGS_MODEL_NAME)

# Per-collection defaults, keyed by collection name. The registered
# collections differ only in their name, so the table is generated from the
# list of names; each entry is its own dict.
# NOTE(review): 'embeddings_model_name' is hard-coded here while the client
# above is built from config.EMBEDDINGS_MODEL_NAME -- confirm the two agree.
QDRANT_COLLECTIONS_PARAMS = {
    name: {
        'collection_name': name,
        'embeddings_model_name': 'text-embedding-3-large',
        'vector_size': 3072,
        'distance': Distance.COSINE,
        'embeddings_model': openai_embeddings,
    }
    for name in (
        'openai_large_chunks_1500char',
        'openai_large_chunks_500char',
    )
}
class QdrantManager:
    """Manage a single Qdrant collection.

    Wraps a ``QdrantClient`` and a LangChain ``Qdrant`` vector store bound to
    ``collection_name`` and offers helpers to create the collection, add
    documents in batches, obtain the vector store, and delete the collection.
    To see available collections, run ``manager.client.get_collections()``.
    """

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Embeddings] = None,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        vector_size: Optional[int] = None,
        distance: Optional[Distance] = None
    ):
        """Resolve the settings and build the client and vector store.

        Each optional argument, when falsy, falls back to a default:
        ``embeddings``, ``vector_size`` and ``distance`` come from the
        ``QDRANT_COLLECTIONS_PARAMS`` entry for ``collection_name``; ``url``
        and ``api_key`` come from the ``QDRANT_URL`` / ``QDRANT_API_KEY``
        environment variables and then from optional ``'url'`` / ``'api_key'``
        keys of that same entry.

        Args:
            collection_name: Name of the Qdrant collection to operate on.
            embeddings: Embeddings model handed to the vector store.
            url: URL of the Qdrant server.
            api_key: API key for the Qdrant server; may stay ``None``.
            vector_size: Vector dimensionality used by ``create_collection``.
            distance: Distance metric used by ``create_collection``.

        Raises:
            KeyError: A required default is needed but ``collection_name``
                has no matching entry in ``QDRANT_COLLECTIONS_PARAMS``.
            ValueError: No Qdrant URL could be resolved from any source.
        """
        defaults = QDRANT_COLLECTIONS_PARAMS.get(collection_name, {})

        def _default(key):
            # Looked up lazily so that explicitly passed values never require
            # the collection to be registered.
            try:
                return defaults[key]
            except KeyError:
                raise KeyError(
                    f"No '{key}' given and no default registered for "
                    f"collection '{collection_name}' in "
                    f"QDRANT_COLLECTIONS_PARAMS"
                ) from None

        self.collection_name = collection_name
        self.embeddings = embeddings or _default('embeddings_model')

        # The last fallback used to subscript the `config` module, which
        # raises TypeError (modules are not subscriptable). The per-collection
        # params are consulted instead.
        self.url = url or os.getenv("QDRANT_URL") or defaults.get('url')
        if not self.url:
            raise ValueError(
                "No Qdrant URL configured: pass `url` or set the QDRANT_URL "
                "environment variable."
            )
        self.api_key = (
            api_key or os.getenv("QDRANT_API_KEY") or defaults.get('api_key')
        )

        self.vector_size = vector_size or _default('vector_size')
        self.distance = distance or _default('distance')

        self.client = QdrantClient(url=self.url, api_key=self.api_key)
        self.qdrant = Qdrant(
            client=self.client,
            collection_name=self.collection_name,
            embeddings=self.embeddings,
            content_payload_key="page_content",
            metadata_payload_key="metadata",
        )

    def create_collection(self) -> None:
        """Create the collection if the server does not already list it."""
        existing = {c.name for c in self.client.get_collections().collections}
        if self.collection_name in existing:
            print(f"Collection '{self.collection_name}' already exists.")
            return

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.vector_size, distance=self.distance
            ),
        )
        print(f"Collection '{self.collection_name}' created.")

    def add_documents(self, documents: List[Document], batch_size: int = 1000) -> None:
        """Add documents to the collection in batches.

        Args:
            documents: Documents to add through the vector store.
            batch_size: Maximum number of documents per call; must be >= 1.

        Raises:
            ValueError: ``batch_size`` is smaller than 1.
        """
        # A negative step would make range() yield nothing, so the method
        # would report documents as added without sending any of them.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        for start in range(0, len(documents), batch_size):
            batch = documents[start:start + batch_size]
            self.qdrant.add_documents(batch)
            print(f"Added batch {start//batch_size + 1} ({len(batch)} documents)")
        print(f"Total documents added: {len(documents)}")

    def get_vectorstore(self):
        """Get the Qdrant vectorstore."""
        return self.qdrant

    def delete_collection(self) -> None:
        """Delete the collection."""
        self.client.delete_collection(self.collection_name)
        print(f"Collection '{self.collection_name}' deleted.")