nileshhanotia commited on
Commit
675f5c9
·
verified ·
1 Parent(s): 3a16d21

Create sql_generator.py

Browse files
Files changed (1) hide show
  1. sql_generator.py +70 -0
sql_generator.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from utils.logger import setup_logger
3
+ from utils.model_loader import ModelLoader
4
+ from api.shopify_client import ShopifyClient
5
+
6
+ logger = setup_logger(__name__)
7
+
8
+ class SQLGenerator:
9
+ def __init__(self):
10
+ try:
11
+ self.model_name = "premai-io/prem-1B-SQL"
12
+ self.tokenizer = ModelLoader.load_model_with_retry(
13
+ self.model_name,
14
+ AutoTokenizer
15
+ )
16
+ self.model = ModelLoader.load_model_with_retry(
17
+ self.model_name,
18
+ AutoModelForCausalLM
19
+ )
20
+ self.shopify_client = ShopifyClient()
21
+ except Exception as e:
22
+ logger.error(f"Failed to initialize SQLGenerator: {str(e)}")
23
+ raise
24
+
25
+ def generate_query(self, natural_language_query):
26
+ try:
27
+ schema_info = """
28
+ CREATE TABLE products (
29
+ id DECIMAL(8,2) PRIMARY KEY,
30
+ title VARCHAR(255),
31
+ body_html VARCHAR(255),
32
+ vendor VARCHAR(255),
33
+ product_type VARCHAR(255),
34
+ created_at VARCHAR(255),
35
+ handle VARCHAR(255),
36
+ updated_at DATE,
37
+ published_at VARCHAR(255),
38
+ template_suffix VARCHAR(255),
39
+ published_scope VARCHAR(255),
40
+ tags VARCHAR(255),
41
+ status VARCHAR(255),
42
+ admin_graphql_api_id DECIMAL(8,2),
43
+ variants VARCHAR(255),
44
+ options VARCHAR(255),
45
+ images VARCHAR(255),
46
+ image VARCHAR(255)
47
+ );
48
+ """
49
+
50
+ prompt = f"""### Task: Generate a SQL query to answer the following question.
51
+ ### Database Schema: {schema_info}
52
+ ### Question: {natural_language_query}
53
+ ### SQL Query:"""
54
+
55
+ inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
56
+ outputs = self.model.generate(
57
+ inputs["input_ids"],
58
+ max_length=256,
59
+ do_sample=False,
60
+ num_return_sequences=1,
61
+ eos_token_id=self.tokenizer.eos_token_id,
62
+ pad_token_id=self.tokenizer.pad_token_id
63
+ )
64
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
65
+ except Exception as e:
66
+ logger.error(f"Query generation error: {str(e)}")
67
+ return "Failed to generate SQL query due to an error."
68
+
69
+ def fetch_shopify_data(self, endpoint):
70
+ return self.shopify_client.fetch_data(endpoint)