# Captured from a Hugging Face Spaces page; the Space status shown at capture
# time was "Runtime error".
import os

import cohere
import dotenv
import numpy as np
import pandas as pd
from annoy import AnnoyIndex

# Load environment variables (e.g. COHERE_API_KEY) from a local .env file.
dotenv.load_dotenv()

model_name = "embed-english-v3.0"
api_key = os.environ['COHERE_API_KEY']
# Topics are embedded as documents; captions are later embedded as queries.
input_type_embed = "search_document"
# Cohere's embed endpoint accepts at most 96 texts per request, so the topic
# corpus is embedded in batches of this size instead of in a single call.
EMBED_BATCH_SIZE = 96

# Set up the Cohere client.
co = cohere.Client(api_key)

# Load the topic catalogue; the 'topic_cleaned' column holds the candidate topics.
topics = pd.read_csv("aicovers_topics.csv")
topic_texts = topics['topic_cleaned'].tolist()
if not topic_texts:
    # Without at least one embedding the index dimension is unknown.
    raise ValueError("aicovers_topics.csv contains no topics to index")

# Embed every topic batch by batch, preserving row order so that Annoy item i
# corresponds to positional row i of `topics`.
list_embeds = []
for start in range(0, len(topic_texts), EMBED_BATCH_SIZE):
    batch = topic_texts[start:start + EMBED_BATCH_SIZE]
    list_embeds.extend(
        co.embed(texts=batch, model=model_name, input_type=input_type_embed).embeddings
    )

# Create the search index sized to the embedding dimension, using Annoy's
# angular (cosine-based) metric.
search_index = AnnoyIndex(len(list_embeds[0]), metric='angular')
for item_id, vector in enumerate(list_embeds):
    search_index.add_item(item_id, vector)
search_index.build(10)  # 10 trees: more trees give better recall but a larger index
search_index.save('test.ann')
def topic_from_caption(caption):
    """
    Return the topic from the uploaded list that is semantically closest to a caption.

    The caption is embedded with the same Cohere model as the topics, using the
    "search_query" input type, and looked up in the module-level Annoy index.

    Args:
    - caption (str): The image caption generated by MS Azure.

    Returns:
    - str: The 'topic_cleaned' value of the nearest topic row.
    """
    input_type_query = "search_query"
    # Embed the caption as a query (the topics were embedded as documents).
    caption_embed = co.embed(texts=[caption], model=model_name, input_type=input_type_query).embeddings
    # Only the single nearest neighbour is needed; distances are not used.
    nearest_ids = search_index.get_nns_by_vector(caption_embed[0], n=1)
    # Annoy item ids were assigned from the positional row order of `topics`.
    return str(topics['topic_cleaned'].iloc[nearest_ids[0]])