import torch from transformers import BertModel, BertTokenizer from sklearn.base import BaseEstimator, TransformerMixin import numpy as np import pandas as pd from sklearn.neighbors import KNeighborsClassifier class BertSentenceEmbedder(BaseEstimator, TransformerMixin): def __init__(self, device="cpu",padding_length=50): """ Args: `device`: pytorch device for inference. Either 'cpu' or a specific type of GPU. `padding_length`: The max sentence token length. Shorter sentences are padded to this length. """ self._device = device self._tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True) self._model = model.to(device) self._model.eval() self._padding_length = padding_length def transform(self, X:list) -> np.ndarray: """ Transforms sentences into embeddings Args: `X`: a dataset of sentences of shape (n_sentences,) Returns: Embeddings of the provided sentences of shape (n_sentences, embedding_dims) """ tokens = self._tokenizer( X, return_token_type_ids=False, return_attention_mask=False, padding=True, truncation=True, max_length=self._padding_length, return_tensors="pt" ) tokens = tokens["input_ids"].to(self._device) with torch.no_grad(): hidden_states = self._model( input_ids=tokens, output_hidden_states=True )["hidden_states"] embeddings = torch.cat(hidden_states[-4:], dim=-1) embeddings = torch.mean(embeddings, dim=1) return embeddings.cpu().numpy() if __name__ == "__main__": df = pd.read_csv("course_sentences.csv") embedder = BertSentenceEmbedder("mps", padding_length=1000) embeddings = embedder.transform(list(df["sentence"])) labels = df["program"] classifier = KNeighborsClassifier(n_neighbors=10) classifier.fit(embeddings, labels) num_suggestions = 10 prompt = "Covers methods currently available to address complexity, including systems thinking, model based systems engineering and life cycle governance." embedding = embedder.transform([prompt]) probs = classifier.predict_proba(embedding)[0] idx = np.argsort(-probs)[:num_suggestions] label_map = np.array(sorted(set(labels))) print(prompt, label_map[idx], probs[idx])