#import transformers

from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# Load the pretrained 'dsivakumar/text2sql' checkpoint: its tokenizer and the
# T5 encoder-decoder model. Both are kept as module-level globals.
_TEXT2SQL_CHECKPOINT = 'dsivakumar/text2sql'
tokenizer = T5Tokenizer.from_pretrained(_TEXT2SQL_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(_TEXT2SQL_CHECKPOINT)

# Predict function
# NOTE(review): a second `def get_sql(query)` later in this file replaces this
# definition for any code that runs after it.

def get_sql(query, tokenizer, model, *, max_input_length=128,
            max_output_length=150, num_beams=2):
    """Translate an English question into SQL with a T5 text-to-SQL model.

    The prompt is ``"English to SQL: " + query`` with every run of whitespace
    collapsed to a single space.

    Args:
        query: Natural-language question (a ``str``).
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Seq2seq model exposing ``generate`` and ``device``.
        max_input_length: Token length the prompt is padded/truncated to.
        max_output_length: Maximum length of each generated sequence.
        num_beams: Beam-search width.

    Returns:
        List of decoded SQL strings, one per generated sequence, with special
        tokens removed.
    """
    source_text = " ".join(("English to SQL: " + query).split())
    # Calling the tokenizer directly replaces the deprecated
    # batch_encode_plus. padding="max_length" alone requests fixed-length
    # padding; the old pad_to_max_length flag was a redundant deprecated alias.
    encoded = tokenizer(
        [source_text],
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # For return_tensors="pt" the tokenizer already yields int64 tensors, so
    # the previous `.to(dtype=torch.long)` casts did nothing useful -- and they
    # raised NameError because `torch` is not imported here. Moving the inputs
    # to the model's device instead keeps a GPU-resident model working.
    generated_ids = model.generate(
        input_ids=encoded["input_ids"].to(model.device),
        attention_mask=encoded["attention_mask"].to(model.device),
        max_length=max_output_length,
        num_beams=num_beams,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )
    return [
        tokenizer.decode(g, skip_special_tokens=True,
                         clean_up_tokenization_spaces=True)
        for g in generated_ids
    ]
 
# Smoke test: translate one sample question with the module-level model.

query = "Show me the average age of wines in Italy by provinces"
sql = get_sql(query, tokenizer, model)
# get_sql returns a list of decoded predictions; print each one rather than
# the list's repr.
for prediction in sql:
    print(prediction)

# Prompt format from https://huggingface.co/mrm8488/t5-small-finetuned-wikiSQL
# NOTE(review): this definition replaces the three-argument get_sql defined
# earlier in this file. By default it uses the module-level `tokenizer` and
# `model`, which this file loads from 'dsivakumar/text2sql', not from the
# mrm8488 checkpoint this prompt format comes from. Pass the matching objects
# through `sql_tokenizer` / `sql_model` to use that checkpoint instead.
def get_sql(query, *, sql_tokenizer=None, sql_model=None, max_length=150):
    """Translate an English question into a single SQL string.

    Args:
        query: Natural-language question.
        sql_tokenizer: Tokenizer to use; defaults to the module-level
            ``tokenizer``.
        sql_model: Seq2seq model to use; defaults to the module-level
            ``model``.
        max_length: Maximum generated length. Without an explicit value,
            ``generate`` falls back to the library default (20 tokens), which
            can cut the SQL off.

    Returns:
        The decoded first generated sequence, with special tokens such as
        ``<pad>`` and ``</s>`` stripped.
    """
    tok = tokenizer if sql_tokenizer is None else sql_tokenizer
    mdl = model if sql_model is None else sql_model

    # No manual "</s>" suffix: the T5 tokenizer appends the end-of-sequence
    # token itself, so writing it into the text produced a doubled EOS.
    input_text = f"translate English to SQL: {query}"
    features = tok([input_text], return_tensors="pt")

    output = mdl.generate(
        input_ids=features["input_ids"].to(mdl.device),
        attention_mask=features["attention_mask"].to(mdl.device),
        max_length=max_length,
    )
    return tok.decode(output[0], skip_special_tokens=True)

query = "How many models were finetuned using BERT as base model?"

# Outside a notebook a bare call's return value is discarded; print it.
print(get_sql(query))
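# ---------------------------------------------------------------------------
# Text copied from the Hugging Face model page (kept as comments so that this
# file remains valid Python):
#
# Downloads last month: 25
#
# Inference examples: this model does not have enough activity to be deployed
# to the Inference API (serverless) yet. Increase its social visibility and
# check back later, or deploy to Inference Endpoints (dedicated) instead.
#
# Dataset used to train dsivakumar/text2sql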