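"""Chat message store with spaCy/NLTK keyword extraction, BM25-based retrieval
of relevant history, and a simple transformer-backed response generator."""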
import nltk
from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder
from nltk.corpus import stopwords
import spacy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from datetime import datetime, timedelta
import numpy as np
import heapq
from concurrent.futures import ThreadPoolExecutor
from annoy import AnnoyIndex
from transformers import pipeline
from rank_bm25 import BM25Okapi
from functools import partial
import ssl
import sys
import subprocess

# Fall back to an unverified HTTPS context (when available) so the model and
# corpus downloads below still work on machines with broken certificate stores.
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context


def download_spacy_model(model_name):
    print(f"Downloading spaCy model: {model_name}")
    subprocess.check_call([sys.executable, "-m", "spacy", "download", model_name])
    print(f"Model {model_name} downloaded successfully")


# Load the small English pipeline, downloading it first if it is missing.
try:
    nlp = spacy.load('en_core_web_sm')
except OSError:
    download_spacy_model('en_core_web_sm')
    nlp = spacy.load('en_core_web_sm')

print("spaCy model loaded successfully")

# Fetch the NLTK data used below (a no-op if it is already present).
nltk.download('stopwords')
nltk.download('punkt')

# Call counter so long keyword-extraction runs report progress every 1000 texts.
q = 0


def get_keywords(text, cache):
    """Return lowercased nouns, proper nouns and verbs plus the ten
    highest-PMI bigrams found in `text`, memoised in `cache`."""
    global q
    if q % 1000 == 0:
        print(q)
    q += 1

    if text in cache:
        return cache[text]

    doc = nlp(text)

    # Keep content-bearing parts of speech, then drop English stopwords.
    keywords = [token.text.lower() for token in doc
                if token.pos_ in ('NOUN', 'PROPN', 'VERB')]
    stop_words = set(stopwords.words('english'))
    keywords = [word for word in keywords if word not in stop_words]

    # Add the ten highest-PMI bigrams as multi-word keywords.
    bigram_measures = BigramAssocMeasures()
    finder = BigramCollocationFinder.from_words([token.text for token in doc])
    bigrams = finder.nbest(bigram_measures.pmi, 10)
    keywords.extend(' '.join(bigram) for bigram in bigrams)

    cache[text] = keywords
    return keywords
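
# Illustrative use (output depends on the tagger, so treat it as an example only):
#   kw_cache = {}
#   get_keywords("Alice booked the flight to Paris", kw_cache)
#   -> something like ['alice', 'booked', 'flight', 'paris', 'Alice booked', ...]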

def calculate_weight(message, sender_messages, cache):
    """Weight a message by how often its keywords recur in the sender's other
    messages sent within five hours of it."""
    message_time = datetime.strptime(message[1], '%Y-%m-%d %H:%M:%S')
    times = np.array([datetime.strptime(m[1], '%Y-%m-%d %H:%M:%S') for m in sender_messages])
    # Keep messages whose timestamp lies within +/- 5 hours of this message.
    offsets = np.abs((times - message_time).astype('timedelta64[s]').astype(int))
    recent_messages = sender_messages[offsets <= 5 * 3600]

    recent_keywords = [get_keywords(m[2], cache) for m in recent_messages]
    keyword_counts = [sum(k.count(keyword) for k in recent_keywords)
                      for keyword in get_keywords(message[2], cache)]
    return sum(keyword_counts)
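
# Illustrative use (names are examples only; message rows are (sender, time, text, tag)):
#   calculate_weight(sender_rows[0], sender_rows, kw_cache)
# The result grows when the first message's keywords also appear in other
# messages the same sender wrote inside the five-hour window.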

class ChatDatabase:
    def __init__(self, filename):
        self.filename = filename
        self.messages = []
        self.messages_array = None
        self.sender_array = None
        self.load_messages()
        # TF-IDF vectorizer and Annoy index; both must be built separately
        # before predict_response_separate can be used.
        self.index = None
        self.tfidf = None

    def load_messages(self):
        # Create the backing file if it does not exist yet.
        with open(self.filename, 'a'):
            pass
        with open(self.filename, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) == 4:
                    sender, time, text, tag = parts
                elif len(parts) == 3:
                    sender, time, text = parts
                    tag = None
                else:
                    # Skip malformed lines instead of crashing on unpacking.
                    continue
                self.messages.append((sender, time, text, tag))

        self.messages_array = np.array(self.messages, dtype=object)
        if len(self.messages_array) == 0:
            self.sender_array = np.array([], dtype=object)
        else:
            self.sender_array = self.messages_array[:, 0]
        print(f'Database loaded. Number of messages: {len(self.messages_array)}')

    def add_message(self, sender, time, text, tag=None):
        message = (sender, time, text, tag)
        self.messages.append(message)
        # Rebuild the array views so boolean indexing keeps working even when
        # the database started out empty.
        self.messages_array = np.array(self.messages, dtype=object)
        self.sender_array = self.messages_array[:, 0]
        # Persist the message as one tab-separated line.
        with open(self.filename, 'a') as f:
            f.write(f'{sender}\t{time}\t{text}\t{tag}\n')
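
    # Illustrative call (values are examples only); it appends one tab-separated
    # line to the backing file:
    #   db.add_message('Alice', '2024-05-01 09:30:00', 'See you at the gym', 'fitness')
    # writes the line "Alice<TAB>2024-05-01 09:30:00<TAB>See you at the gym<TAB>fitness".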

    def predict_response_separate(self, query, sender, cache):
        """Find the stored message from `sender` most similar to `query` (using
        the TF-IDF vectors in the Annoy index) and return the next message from
        a different sender as the predicted response.

        Requires self.tfidf and self.index to have been built beforehand.
        """
        if self.messages_array is None or self.tfidf is None or self.index is None:
            print("Error: messages_array, tfidf or index is not initialised")
            return None

        sender_mask = self.sender_array == sender
        sender_messages = self.messages_array[sender_mask]

        if len(sender_messages) == 0:
            print(f"No messages found for sender: {sender}")
            return None

        query_keywords = ' '.join(get_keywords(query, cache))
        query_vector = self.tfidf.transform([query_keywords]).toarray()[0]

        # Nearest stored message according to the Annoy index.
        relevant_indices = self.index.get_nns_by_vector(query_vector, 1)
        relevant_index = np.where(sender_mask)[0][relevant_indices[0]]

        # The predicted response is the first later message from another sender.
        other_indices = np.where(~sender_mask)[0]
        other_indices = other_indices[other_indices > relevant_index]
        if len(other_indices) > 0:
            return tuple(self.messages_array[other_indices[0]])
        return None
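
    # The TF-IDF vectorizer and Annoy index consumed above are built elsewhere
    # (the usage example at the bottom calls db.build_index_separate(cache),
    # which is not shown in this section). A minimal sketch of such a helper,
    # stated as an assumption rather than the original implementation:
    #
    #   def build_index_separate(self, cache):
    #       texts = [' '.join(get_keywords(m[2], cache)) for m in self.messages]
    #       self.tfidf = TfidfVectorizer()
    #       vectors = self.tfidf.fit_transform(texts).toarray()
    #       self.index = AnnoyIndex(vectors.shape[1], 'angular')
    #       for i, vec in enumerate(vectors):
    #           self.index.add_item(i, vec)
    #       self.index.build(10)
    #
    # Whether the index should cover all messages or only one sender's messages
    # depends on the original build_index_separate, which this section omits.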

    def get_relevant_messages(self, sender, query, N, cache, query_tag=None, n_threads=30, tag_boost=1.5):
        """Return the N messages from `sender` that best match `query`, scored by
        a mix of BM25 similarity, recency and an optional tag boost."""
        if self.messages_array is None or len(self.messages_array) == 0:
            print("Error: messages_array is empty")
            return []

        query_keywords = query.lower().split()

        sender_mask = self.sender_array == sender
        print(f"Number of messages from sender {sender}: {np.count_nonzero(sender_mask)}")

        # Keep only the sender's messages that contain at least one query keyword.
        keyword_mask = np.array([any(keyword in message.lower() for keyword in query_keywords)
                                 for message in self.messages_array[:, 2]])
        sender_messages = self.messages_array[sender_mask & keyword_mask]

        if len(sender_messages) == 0:
            print(f"No messages found for sender: {sender} with the given keywords")
            return []
        print(len(sender_messages))

        def process_batch(batch, query_keywords, current_time, query_tag):
            # BM25 relevance of the query against each message's keyword list.
            batch_keywords = [get_keywords(message[2], cache) for message in batch]
            bm25 = BM25Okapi(batch_keywords)
            bm25_scores = bm25.get_scores(query_keywords)

            # Recency: newer messages score close to 1, older ones decay with age in days.
            ages_days = (current_time - np.array([datetime.strptime(m[1], '%Y-%m-%d %H:%M:%S') for m in batch])).astype('timedelta64[D]').astype(int)
            time_scores = 1 / (1 + ages_days)

            # Boost messages whose tag matches the query tag.
            tag_scores = np.where(np.array([m[3] for m in batch]) == query_tag, tag_boost, 1)

            combined_scores = 0.6 * np.array(bm25_scores) + 0.2 * time_scores + 0.2 * tag_scores
            return combined_scores, batch

        current_time = datetime.now()

        # Score batches in parallel, then keep the N highest combined scores.
        batch_size = max(1, len(sender_messages) // n_threads)
        batches = [sender_messages[i:i + batch_size] for i in range(0, len(sender_messages), batch_size)]

        with ThreadPoolExecutor(max_workers=n_threads) as executor:
            process_func = partial(process_batch, query_keywords=query_keywords,
                                   current_time=current_time, query_tag=query_tag)
            results = list(executor.map(process_func, batches))

        all_scores = np.concatenate([r[0] for r in results])
        all_messages = np.concatenate([r[1] for r in results])

        top_indices = np.argsort(all_scores)[-N:][::-1]
        relevant_messages = all_messages[top_indices]

        return relevant_messages.tolist()
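
    # Illustrative call (values are examples only): the three best matches from
    # 'Alice', with messages tagged 'travel' boosted by tag_boost:
    #   db.get_relevant_messages('Alice', 'trip to the coast', 3, {}, query_tag='travel')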

    def generate_response(self, query, sender, cache, query_tag=None):
        """Generate a reply to `query`, conditioning the language model on the
        sender's most relevant stored messages."""
        relevant_messages = self.get_relevant_messages(sender, query, 5, cache, query_tag)

        context = ' '.join([message[2] for message in relevant_messages])

        # Note: this loads the 2.7B-parameter model on every call, which is slow;
        # consider building the pipeline once and reusing it.
        generator = pipeline('text-generation', model='EleutherAI/gpt-neo-2.7B')
        response = generator(f'{context} {query}', max_length=100, do_sample=True)[0]['generated_text']

        # Keep only the text generated after the query itself.
        response = response.split(query)[-1].strip()

        return response
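
    # A minimal sketch of reusing one generator across calls (an assumed
    # restructuring, not part of the original code):
    #
    #   _generator = None
    #
    #   def _get_generator(self):
    #       if ChatDatabase._generator is None:
    #           ChatDatabase._generator = pipeline('text-generation',
    #                                              model='EleutherAI/gpt-neo-2.7B')
    #       return ChatDatabase._generator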

'''
# Usage example
db = ChatDatabase('memory.txt')

# Example 1: Get relevant messages
query = 'fisical'
sender = 'Arcana'
N = 10
cache = {}
query_tag = None

relevant_messages = db.get_relevant_messages(sender, query, N, cache, query_tag)

print("Relevant messages:")
for message in relevant_messages:
    print(f"Sender: {message[0]}, Time: {message[1]}, Tag: {message[3]}")
    print(f"Message: {message[2][:100]}...")
    print()

# Example 2: Predict response (using the original method)
query = "what was that?"
sender = 'David'

db.build_index_separate(cache)
predicted_response = db.predict_response_separate(query, sender, cache)

print("\nPredicted response:")
if predicted_response is not None:
    print(f"Sender: {predicted_response[0]}, Time: {predicted_response[1]}, Tag: {predicted_response[3]}")
    print(f"Message: {predicted_response[2][:100]}...")
else:
    print('No predicted response found')

# Example 3: Generate response
query = "Let's plan a trip"
sender = 'Alice'
query_tag = 'travel'

generated_response = db.generate_response(query, sender, cache, query_tag)

print("\nGenerated response:")
print(generated_response)
'''