# Hugging Face Spaces status: Runtime error
import os | |
import sqlite3 | |
import requests | |
import json | |
import pyttsx3 # For local TTS (if desired) | |
import speech_recognition as sr # For local STT (if desired) | |
class StarMaintAI:
    """Console assistant for industrial reliability and maintenance.

    Persistent state lives in a local SQLite database with three tables:
    ``long_term_memory`` (key/value pairs), ``prompts`` (title/description)
    and ``functions`` (function_name/description). Speech-to-text,
    text-to-audio and chat completion are delegated to the ModelsLab HTTP
    API using the key read from the ``MODELSLAB_API_KEY`` environment
    variable.
    """

    # DDL executed by ensure_database_exists(). Every statement uses
    # IF NOT EXISTS, so running it against an existing database is harmless.
    _SCHEMA_STATEMENTS = (
        """
        CREATE TABLE IF NOT EXISTS long_term_memory (
            key TEXT PRIMARY KEY,
            value TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS prompts (
            title TEXT PRIMARY KEY,
            description TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS functions (
            function_name TEXT PRIMARY KEY,
            description TEXT
        )
        """,
    )

    # Seconds to wait for a ModelsLab HTTP response. requests.post() has no
    # timeout by default and would otherwise block the console indefinitely
    # on a stalled connection. Override per instance if needed.
    request_timeout = 60

    def __init__(self, db_path):
        """Prepare the database at *db_path*, connect to it and load memory.

        Args:
            db_path: Filesystem path of the SQLite database file. The file
                and its tables are created if they do not exist yet.
        """
        self.db_path = db_path
        # Make sure the schema exists before opening the long-lived
        # connection, so load_long_term_memory() finds its table.
        self.ensure_database_exists()
        self.connection = sqlite3.connect(db_path)
        self.short_term_memory = []
        self.medium_term_memory = []
        self.long_term_memory = {}
        # ModelsLab API key; an empty string disables the remote integrations.
        self.modelslab_api_key = os.getenv("MODELSLAB_API_KEY", "")
        # StarMaint-specific system rules and prompts
        self.system_prompt = (
            "You are StarMaint AI, the ultimate assistant for industrial reliability and maintenance. "
            "Your purpose is to assist with predictive maintenance, task automation, voice interactions, and knowledge management. "
            "Be professional, concise, and helpful, adhering to the highest standards of AI performance."
        )
        self.rules = [
            "Always provide accurate and contextually relevant information.",
            "Follow the user’s intent and prioritize clarity in responses.",
            "Ensure all actions align with industrial safety and reliability principles.",
            "Operate efficiently and avoid unnecessary verbosity."
        ]
        self.load_long_term_memory()

    def ensure_database_exists(self):
        """Create the database file if needed and any missing tables.

        The CREATE TABLE IF NOT EXISTS statements run on every call, not only
        when the file is new. This means a pre-existing but empty or partially
        initialised database is repaired instead of failing later with
        "no such table". SQLite errors raised while creating tables are
        printed rather than raised.

        A short-lived connection is used and always closed. The instance's
        long-lived ``self.connection`` is opened by ``__init__``.
        """
        if not os.path.exists(self.db_path):
            print(f"Database not found at {self.db_path}. Initializing new database.")
        connection = sqlite3.connect(self.db_path)
        try:
            for statement in self._SCHEMA_STATEMENTS:
                connection.execute(statement)
            connection.commit()
        except sqlite3.Error as e:
            print(f"Error during database initialization: {e}")
        finally:
            connection.close()

    def load_long_term_memory(self):
        """Load persistent memory from the ``long_term_memory`` table.

        Populates ``self.long_term_memory`` as a ``{key: value}`` dict. If the
        query raises ``sqlite3.OperationalError`` (for example, the table is
        missing), the schema is recreated and memory is reset to ``{}``
        instead of propagating the error.
        """
        try:
            rows = self.connection.execute(
                "SELECT key, value FROM long_term_memory"
            ).fetchall()
        except sqlite3.OperationalError as e:
            print(f"Error loading long-term memory: {e}. Reinitializing database.")
            self.ensure_database_exists()
            self.long_term_memory = {}
        else:
            self.long_term_memory = dict(rows)

    def process_user_input(self, user_input):
        """Run the full pipeline for one line of user input.

        The steps are: interpret the input, look up a matching function row,
        execute it, and format the final response.

        NOTE(review): nlp_parse() returns a descriptive sentence
        ("Interpreted Command: ... with prompt context: ...") and find_data()
        compares that whole sentence for equality against ``function_name``.
        As written, a function is therefore only found if a row is literally
        named with that sentence. Confirm the intended intent format.

        Args:
            user_input: Raw text typed by the user.

        Returns:
            The string produced by generate_response().
        """
        # Step 1: Interpret
        prompt = self.fetch_prompt("Process")
        processed_intent = self.nlp_parse(user_input, prompt)
        # Step 2: Retrieve relevant data from DB
        query_data = self.find_data(processed_intent)
        # Step 3: Execute the desired function
        action_response = self.execute_function(query_data)
        # Step 4: Generate final text response
        return self.generate_response(user_input, action_response)

    def fetch_prompt(self, prompt_name):
        """Return the description stored under *prompt_name* in ``prompts``.

        Returns an empty string when no such prompt exists.
        """
        row = self.connection.execute(
            "SELECT description FROM prompts WHERE title = ?", (prompt_name,)
        ).fetchone()
        return row[0] if row else ""

    def nlp_parse(self, text, prompt):
        """
        Basic natural language parsing — can be replaced with advanced NLU.

        Currently this only formats *text* and *prompt* into a single
        descriptive string; no parsing is performed.
        """
        return f"Interpreted Command: {text} with prompt context: {prompt}"

    def find_data(self, intent):
        """Return the ``functions`` row whose ``function_name`` equals *intent*.

        Returns:
            The full row tuple (``function_name`` first), or ``None`` if no
            row matches exactly.
        """
        return self.connection.execute(
            "SELECT * FROM functions WHERE function_name = ?", (intent,)
        ).fetchone()

    def execute_function(self, function_data):
        """Dispatch to an integration based on a ``functions`` table row.

        The integrations are currently called with hard-coded demo arguments
        (example.com URLs and sample text).

        Args:
            function_data: A row from find_data() whose first element is the
                function name, or a falsy value when nothing matched.

        Returns:
            The integration's result string, or a message explaining that no
            function matched or the name is not implemented.
        """
        if not function_data:
            return "No matching function found in database."
        function_name = function_data[0]
        if function_name == "transcribe_audio":
            audio_url = "https://example.com/test.wav"
            return self.transcribe_audio(audio_url, input_language="en")
        elif function_name == "generate_audio":
            text_prompt = "Hello, this is a sample text for voice synthesis."
            init_audio_url = "https://example.com/voice_clip.wav"
            return self.generate_audio(text_prompt, init_audio_url)
        elif function_name == "uncensored_chat":
            chat_prompt = "Write a tagline for an ice cream shop."
            return self.uncensored_chat_completion(chat_prompt)
        else:
            return f"Function '{function_name}' not recognized or not yet implemented."

    def _post_json(self, url, payload, headers):
        """POST *payload* serialized as JSON to *url*; return the response.

        ``request_timeout`` is applied so a stalled server cannot hang the
        caller. Any ``requests`` exception propagates to the caller.
        """
        return requests.post(
            url,
            headers=headers,
            data=json.dumps(payload),
            timeout=self.request_timeout,
        )

    def transcribe_audio(self, audio_url, input_language="en"):
        """Send a transcription request to the ModelsLab Whisper endpoint.

        Args:
            audio_url: URL of the audio file, passed to ModelsLab as-is.
            input_language: Language of the audio, e.g. ``"en"``.

        Returns:
            A status string that embeds the raw response body. If the API key
            is missing or the request fails, an error message is returned
            instead.
        """
        if not self.modelslab_api_key:
            return "API key not found; cannot transcribe audio."
        url = "https://modelslab.com/api/v6/whisper/transcribe"
        payload = {
            "key": self.modelslab_api_key,
            "audio_url": audio_url,
            "input_language": input_language,
            "timestamp_level": None,
            "webhook": None,
            "track_id": None
        }
        headers = {"Content-Type": "application/json"}
        try:
            response = self._post_json(url, payload, headers)
            return f"Transcription request sent. Response: {response.text}"
        except Exception as e:
            # Best-effort integration: report the failure to the console user
            # instead of aborting the interactive session.
            return f"Error during transcription: {e}"

    def generate_audio(self, text_prompt, init_audio_url=None, voice_id=None, language="english"):
        """Send a text-to-audio request to ModelsLab.

        Args:
            text_prompt: Text to synthesize.
            init_audio_url: Optional reference audio URL. If given, it takes
                precedence over *voice_id*.
            voice_id: Optional voice identifier, used only when
                *init_audio_url* is not given.
            language: Language name sent to the API.

        Returns:
            A status string that embeds the raw response body. If the API key
            is missing or the request fails, an error message is returned
            instead.
        """
        if not self.modelslab_api_key:
            return "API key not found; cannot generate audio."
        url = "https://modelslab.com/api/v6/voice/text_to_audio"
        payload = {
            "key": self.modelslab_api_key,
            "prompt": text_prompt,
            "language": language,
            "webhook": None,
            "track_id": None
        }
        if init_audio_url:
            payload["init_audio"] = init_audio_url
        elif voice_id:
            payload["voice_id"] = voice_id
        headers = {"Content-Type": "application/json"}
        try:
            response = self._post_json(url, payload, headers)
            return f"Audio generation request sent. Response: {response.text}"
        except Exception as e:
            # Best-effort integration: report, don't crash the session.
            return f"Error during audio generation: {e}"

    def uncensored_chat_completion(self, prompt):
        """Request a chat completion from the ModelsLab uncensored-chat API.

        Args:
            prompt: Prompt text sent to the model.

        Returns:
            The ``text`` of the first returned choice. If the response has no
            choices, a diagnostic string is returned instead. If the API key
            is missing, or the request or JSON decoding fails, an error
            message is returned.
        """
        if not self.modelslab_api_key:
            return "API key not found; cannot complete uncensored chat."
        base_url = "https://modelslab.com/api/uncensored-chat/v1/completions"
        payload = {
            "model": "ModelsLab/Llama-3.1-8b-Uncensored-Dare",
            "prompt": prompt,
            "max_tokens": 50,
            "temperature": 0.7
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.modelslab_api_key}"
        }
        try:
            data = self._post_json(base_url, payload, headers).json()
            if "choices" in data and data["choices"]:
                return data["choices"][0].get("text", "")
            return f"Unexpected chat response: {data}"
        except Exception as e:
            # Covers network errors and non-JSON bodies alike.
            return f"Error during uncensored chat completion: {e}"

    def generate_response(self, user_input, action_response):
        """Combine the system prompt, rules, user input and action result.

        Returns:
            A four-line string labelled ``System Prompt``, ``Rules``,
            ``User Input`` and ``System Action``.
        """
        return (
            f"System Prompt: {self.system_prompt}\n"
            f"Rules: {'; '.join(self.rules)}\n"
            f"User Input: {user_input}\n"
            f"System Action: {action_response}"
        )
def run_app(db_path="central_data.db"):
    """Run an interactive console session with StarMaint AI.

    Lines are read from standard input until one of the following happens:
    the user types ``exit`` or ``quit`` (case-insensitive, surrounding
    whitespace ignored), input ends (EOF), or the user presses Ctrl-C.

    Args:
        db_path: SQLite database file to use. Defaults to
            ``central_data.db`` in the current working directory.
    """
    starmaint_ai = StarMaintAI(db_path)
    print("Welcome to StarMaint AI.")
    try:
        while True:
            try:
                user_input = input("You: ")
            except (EOFError, KeyboardInterrupt):
                # stdin closed (e.g. a non-interactive container) or Ctrl-C:
                # end the session cleanly instead of dying with a traceback.
                print("\nExiting application.")
                break
            if user_input.strip().lower() in ("exit", "quit"):
                print("Exiting application.")
                break
            response = starmaint_ai.process_user_input(user_input)
            print(f"AI: {response}")
    finally:
        # Release the SQLite connection that StarMaintAI.__init__ opened.
        starmaint_ai.connection.close()
if __name__ == "__main__":
    # Start the interactive console only when this file is executed
    # directly, not when it is imported as a module.
    run_app()