from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from src.classmodels.inputforgeneration import InputForGeneration
from errorlog.errorlog import log_error
from pathlib import Path

#Model name as published on Hugging Face: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
current_folderpath = Path(__file__).resolve().parent

#Local folders holding the downloaded model and tokenizer; the loaded objects are cached in module-level globals
tokenizer = None
quantised_model = None
tokenizer_path = current_folderpath / "tokenizer"
model_path = current_folderpath / "model"

additional_kwargs = {
    "do_sample": True,
    "early_stopping": True,
    "num_beams": 5,
    "no_repeat_ngram_size": 5,
    "truncation": True
}

TASK_NAME = "text-generation"

def isModelAvailable():
    #The model folder must actually exist on disk next to this file
    return model_path.is_dir()

def isTokenizerAvailable():
    #The tokenizer folder must actually exist on disk next to this file
    return tokenizer_path.is_dir()

def warmupTextGenerationModel():
    #Load the tokenizer and model once and cache them in the module-level globals
    global tokenizer, quantised_model
    try:
        if isModelAvailable() and isTokenizerAvailable():
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            quantised_model = AutoModelForCausalLM.from_pretrained(model_path, use_safetensors=True)
            return "text generation model is warmed up"
        else:
            return "No model/tokenizer folder found..."
    except Exception as ex:
        log_error(str(ex))
        return "Issue occurred when warming up the text generation model. Please try again.."

def generateText(inputSettings: InputForGeneration):
    try:
        if tokenizer is not None and quantised_model is not None:
            pipe = pipeline(task=TASK_NAME, model=quantised_model, tokenizer=tokenizer, device_map="auto")
            #Prompt formatted with the Llama instruction template
            prompt = f"<s>[INST] {inputSettings.input_for_generation} [/INST]"
            generated_text = pipe(prompt, temperature=inputSettings.temperature, max_length=inputSettings.max_length,
                                  **additional_kwargs)
            if generated_text is not None and generated_text[0]['generated_text'] is not None:
                #Strip the template tokens before returning the generated text
                return generated_text[0]['generated_text'].replace("<s>", "").replace("[INST]", "").replace("[/INST]", "")
            else:
                return None
        else:
            #If the tokenizer or model has not been loaded yet, report it as a generation issue
            return None
    except Exception as ex:
        log_error(str(ex))
        return ""