NLP Course documentation

Búsqueda semántica con FAISS

Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

Búsqueda semántica con FAISS

Ask a Question Open In Colab Open In Studio Lab

En la sección 5 creamos un dataset de issues y comentarios del repositorio de GitHub de 🤗 Datasets. En esta sección usaremos esta información para construir un motor de búsqueda que nos ayude a responder nuestras preguntas más apremiantes sobre la librería.

Usando embeddings para la búsqueda semántica

Como vimos en el Capítulo 1, los modelos de lenguaje basados en Transformers representan cada token en un texto como un vector de embeddings. Resulta que podemos agrupar los embeddings individuales en representaciones vectoriales para oraciones, párrafos o (en algunos casos) documentos completos. Estos embeddings pueden ser usados para encontrar documentos similares en el corpus al calcular la similaridad del producto punto (o alguna otra métrica de similaridad) entre cada embedding y devolver los documentos con la mayor coincidencia.

En esta sección vamos a usar embeddings para desarrollar un motor de búsqueda semántica. Estos motores de búsqueda tienen varias ventajas sobre abordajes convencionales basados en la coincidencia de palabras clave en una búsqueda con los documentos.

Semantic search.

Cargando y preparando el dataset

Lo primero que tenemos que hacer es descargar el dataset de issues de GitHub, así que usaremos la librería 🤗 Hub para resolver la URL en la que está almacenado nuestro archivo en el Hub de Hugging Face:

from huggingface_hub import hf_hub_url

data_files = hf_hub_url(
    repo_id="lewtun/github-issues",
    filename="datasets-issues-with-comments.jsonl",
    repo_type="dataset",
)

Con la URL almacenada en data_files, podemos cargar el dataset remoto usando el método introducido en la sección 2:

from datasets import load_dataset

issues_dataset = load_dataset("json", data_files=data_files, split="train")
issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 2855
})

Hemos especificado el conjunto train por defecto en load_dataset(), de tal manera que devuelva un objeto Dataset en vez de un DatasetDict. Lo primero que debemos hacer es filtrar los pull requests, dado que estos no se suelen usar para resolver preguntas de usuarios e introducirán ruido en nuestro motor de búsqueda. Como ya debe ser familiar para ti, podemos usar la función Dataset.filter() para excluir estas filas en nuestro dataset. A su vez, filtremos las filas que no tienen comentarios, dado que no van a darnos respuestas para las preguntas de los usuarios.

issues_dataset = issues_dataset.filter(
    lambda x: (x["is_pull_request"] == False and len(x["comments"]) > 0)
)
issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 771
})

Podemos ver que hay un gran número de columnas en nuestro dataset, muchas de las cuales no necesitamos para construir nuestro motor de búsqueda. Desde la perspectiva de la búsqueda, las columnas más informativas son title, body y comments, mientras que html_url nos indica un link al issue correspondiente. Usemos la función Dataset.remove_columns() para eliminar el resto:

columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 771
})

Para crear nuestros embeddings, vamos a ampliar cada comentario añadiéndole el título y el cuerpo del issue, dado que estos campos suelen incluir información de contexto útil. Dado que nuestra función comments es una lista de comentarios para cada issue, necesitamos “explotar” la columna para que cada fila sea una tupla (html_url, title, body, comment). Podemos hacer esto en Pandas con la función DataFrame.explode(), que crea una nueva fila para cada elemento en una columna que está en forma de lista, al tiempo que replica el resto de los valores de las otras columnas. Para verlo en acción, primero debemos cambiar al formato DataFrame de Pandas:

issues_dataset.set_format("pandas")
df = issues_dataset[:]

Si inspeccionamos la primera fila en este DataFrame podemos ver que hay 4 comentarios asociados con este issue:

df["comments"][0].tolist()
['the bug code locate in :\r\n    if data_args.task_name is not None:\r\n        # Downloading and loading a dataset from the hub.\r\n        datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)',
 'Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com\r\n\r\nNormally, it should work if you wait a little and then retry.\r\n\r\nCould you please confirm if the problem persists?',
 'cannot connect,even by Web browser,please check that  there is some  problems。',
 'I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem...']

Cuando “explotamos” df, queremos obtener una fila para cada uno de estos comentarios. Veamos si este es el caso:

comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)
html_url title comments body
0 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com the bug code locate in :\r\n if data_args.task_name is not None... Hello,\r\nI am trying to run run_glue.py and it gives me this error...
1 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com... Hello,\r\nI am trying to run run_glue.py and it gives me this error...
2 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com cannot connect,even by Web browser,please check that there is some problems。 Hello,\r\nI am trying to run run_glue.py and it gives me this error...
3 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem... Hello,\r\nI am trying to run run_glue.py and it gives me this error...

Genial, podemos ver que las filas se han replicado y que la columna comments incluye los comentarios individuales. Ahora que hemos terminado con Pandas, podemos volver a cambiar el formato a Dataset cargando el DataFrame en memoria:

from datasets import Dataset

comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 2842
})

¡Esto nos ha dado varios miles de comentarios con los que trabajar!

✏️ ¡Inténtalo! Prueba si puedes usar la función Dataset.map() para “explotar” la columna comments en issues_dataset sin necesidad de usar Pandas. Esto es un poco complejo; te recomendamos revisar la sección de “Batch mapping” de la documentación de 🤗 Datasets para completar esta tarea.

Ahora que tenemos un comentario para cada fila, creemos una columna comments_length que contenga el número de palabras por comentario:

comments_dataset = comments_dataset.map(
    lambda x: {"comment_length": len(x["comments"].split())}
)

Podemos usar esta nueva columna para filtrar los comentarios cortos, que típicamente incluyen cosas como “cc @letwun” o “¡Gracias!”, que no son relevantes para nuestro motor de búsqueda. No hay un número preciso que debamos filtrar, pero alrededor de 15 palabras es un buen comienzo:

comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 2098
})

Ahora que hemos limpiado un poco el dataset, vamos a concatenar el título, la descripción y los comentarios del issue en una nueva columna text. Como lo hemos venido haciendo, escribiremos una función para pasarla a Dataset.map():

def concatenate_text(examples):
    return {
        "text": examples["title"]
        + " \n "
        + examples["body"]
        + " \n "
        + examples["comments"]
    }


comments_dataset = comments_dataset.map(concatenate_text)

¡Por fin estamos listos para crear embeddings!

Creando embeddings de texto

En el Capítulo 2 vimos que podemos obtener embeddings usando la clase AutoModel. Todo lo que tenemos que hacer es escoger un punto de control adecuado para cargar el modelo. Afortunadamente, existe una librería llamada sentence-transformers que se especializa en crear embeddings. Como se describe en la documentación de esta librería, nuestro caso de uso es un ejemplo de búsqueda semántica asimétrica porque tenemos una pregunta corta cuya respuesta queremos encontrar en un documento más grande, como un comentario de un issue. La tabla de resumen de modelos en la documentación nos indica que el punto de control multi-qa-mpnet-base-dot-v1 tiene el mejor desempeño para la búsqueda semántica, así que lo usaremos para nuestra aplicación. También cargaremos el tokenizador usando el mismo punto de control:

from transformers import AutoTokenizer, AutoModel

model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

Para acelerar el proceso de embedding, es útil ubicar el modelo y los inputs en un dispositivo GPU, así que hagámoslo:

import torch

device = torch.device("cuda")
model.to(device)

Como mencionamos con anterioridad, queremos representar cada entrada en el corpus de issues de GitHub como un vector individual, así que necesitamos agrupar o promediar nuestros embeddings de tokes de alguna manera. Un abordaje popular es ejecutar CLS pooling en los outputs de nuestro modelo, donde simplemente vamos a recolectar el último estado oculto para el token especial [CLS]. La siguiente función nos ayudará con esto:

def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]

Ahora crearemos una función que va a tokenizar una lista de documentos, ubicar los tensores en la GPU, alimentarlos al modelo y aplicar CLS pooling a los outputs:

def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    model_output = model(**encoded_input)
    return cls_pooling(model_output)

Podemos probar que la función sirve al pasarle la primera entrada de texto en el corpus e inspeccionando la forma de la salida:

embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape
torch.Size([1, 768])

¡Hemos convertido la primera entrada del corpus en un vector de 768 dimensiones! Ahora podemos usar Dataset.map() para aplicar nuestra función get_embeddings() a cada fila del corpus, así que creemos una columna embeddings así:

embeddings_dataset = comments_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)

Los embeddings se han convertido en arrays de NumPy, esto es porque 🤗 Datasets los necesita en este formato cuando queremos indexarlos con FAISS, que es lo que haremos a continuación.

Usando FAISS para una búsqueda eficiente por similaridad

Ahora que tenemos un dataset de embeddings, necesitamos una manera de buscar sobre ellos. Para hacerlo, usaremos una estructura especial de datos en 🤗 Datasets llamada índice FAISS. [FAISS] (https://faiss.ai/) (siglas para Facebook AI Similarity Search) es una librería que contiene algoritmos eficientes para buscar y agrupar rápidamente vectores de embeddings.

La idea básica detrás de FAISS es que crea una estructura especial de datos, llamada índice, que te permite encontrar cuáles embeddings son parecidos a un embedding de entrada. La creación de un índice FAISS en 🤗 Datasets es muy simple: usamos la función Dataset.add_faiss_index() y especificamos cuál columna del dataset queremos indexar:

embeddings_dataset.add_faiss_index(column="embeddings")

Ahora podemos hacer búsquedas sobre este índice al hacer una búsqueda del vecino más cercano con la función Dataset.get_nearest_examples(). Probémoslo al hacer el embedding de una pregunta de la siguiente manera:

question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()
question_embedding.shape
torch.Size([1, 768])

Tal como en los documentos, ahora tenemos un vector de 768 dimensiones que representa la pregunta, que podemos comparar con el corpus entero para encontrar los embeddings más parecidos:

scores, samples = embeddings_dataset.get_nearest_examples(
    "embeddings", question_embedding, k=5
)

La función Dataset.get_nearest_examples() devuelve una tupla de puntajes que calcula un ranking de la coincidencia entre la pregunta y el documento, así como un conjunto correspondiente de muestras (en este caso, los 5 mejores resultados). Recojámoslos en un pandas.DataFrame para ordenarlos fácilmente:

import pandas as pd

samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)

Podemos iterar sobre las primeras filas para ver qué tanto coincide la pregunta con los comentarios disponibles:

for _, row in samples_df.iterrows():
    print(f"COMMENT: {row.comments}")
    print(f"SCORE: {row.scores}")
    print(f"TITLE: {row.title}")
    print(f"URL: {row.html_url}")
    print("=" * 50)
    print()
"""
COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.

@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505046844482422
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
\`\`\`python
datasets = load_dataset("text", data_files=data_files)
\`\`\`

We'll do a new release soon
SCORE: 24.555509567260742
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's no internet.

Let me know if you know other ways that can make the offline mode experience better. I'd be happy to add them :)

I already note the "freeze" modules option, to prevent local modules updates. It would be a cool feature.

----------

> @mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?

Indeed `load_dataset` allows to load remote dataset script (squad, glue, etc.) but also you own local ones.
For example if you have a dataset script at `./my_dataset/my_dataset.py` then you can do
\`\`\`python
load_dataset("./my_dataset")
\`\`\`
and the dataset script will generate your dataset once and for all.

----------

About I'm looking into having `csv`, `json`, `text`, `pandas` dataset builders already included in the `datasets` package, so that they are available offline by default, as opposed to the other datasets that require the script to be downloaded.
cf #1724
SCORE: 24.14896583557129
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: > here is my way to load a dataset offline, but it **requires** an online machine
>
> 1. (online machine)
>
> ```
>
> import datasets
>
> data = datasets.load_dataset(...)
>
> data.save_to_disk(/YOUR/DATASET/DIR)
>
> ```
>
> 2. copy the dir from online to the offline machine
>
> 3. (offline machine)
>
> ```
>
> import datasets
>
> data = datasets.load_from_disk(/SAVED/DATA/DIR)
>
> ```
>
>
>
> HTH.


SCORE: 22.893993377685547
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: here is my way to load a dataset offline, but it **requires** an online machine
1. (online machine)
\`\`\`
import datasets
data = datasets.load_dataset(...)
data.save_to_disk(/YOUR/DATASET/DIR)
\`\`\`
2. copy the dir from online to the offline machine
3. (offline machine)
\`\`\`
import datasets
data = datasets.load_from_disk(/SAVED/DATA/DIR)
\`\`\`

HTH.
SCORE: 22.406635284423828
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
"""

¡No está mal! El segundo comentario parece responder la pregunta.

✏️ ¡Inténtalo! Crea tu propia pregunta y prueba si puedes encontrar una respuesta en los documentos devueltos. Puede que tengas que incrementar el parámetro k en Dataset.get_nearest_examples() para aumentar la búsqueda.

< > Update on GitHub