TextGeneration/src/TinyLLama/text_generation.py
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from src.classmodels.inputforgeneration import InputForGeneration
from errorlog.errorlog import log_error
from pathlib import Path
# Model name on Hugging Face: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
current_folderpath = Path(__file__).resolve().parent
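# Module-level state: these stay None until warmupTextGenerationModel() loads them.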
tokenizer = None
quantised_model = None
tokenizer_path = None
model_path = None
# Default generation kwargs applied to every pipeline call.
additional_kwargs = {
    "do_sample": True,
    "early_stopping": True,
    "num_beams": 5,
    "no_repeat_ngram_size": 5,
    "truncation": True
}
TASK_NAME = "text-generation"

def isModelAvailable():
    global model_path
    model_path = current_folderpath / "model"
    # The folder must actually exist on disk for the model to be loadable.
    return model_path.is_dir()

def isTokenizerAvailable():
    global tokenizer_path
    tokenizer_path = current_folderpath / "tokenizer"
    return tokenizer_path.is_dir()

def warmupTextGenerationModel():
    global tokenizer, quantised_model
    try:
        if isModelAvailable() and isTokenizerAvailable():
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            quantised_model = AutoModelForCausalLM.from_pretrained(model_path, use_safetensors=True)
            return "text generation model is warmed up"
        else:
            return "No model/tokenizer folder found..."
    except Exception as ex:
        log_error(str(ex))
        return "Issue occurred when warming up the text generation model. Please try again.."

def generateText(inputSettings: InputForGeneration):
    try:
        if tokenizer is not None and quantised_model is not None:
            pipe = pipeline(task=TASK_NAME, model=quantised_model, tokenizer=tokenizer, device_map="auto")
            # Prompt formatted in the Llama instruction style expected by the model
            prompt = f"<s>[INST] {inputSettings.input_for_generation} [/INST]"
            generated_text = pipe(prompt, temperature=inputSettings.temperature, max_length=inputSettings.max_length,
                                  **additional_kwargs)
            if generated_text is not None and generated_text[0]['generated_text'] is not None:
                # Strip the prompt markers before returning the generated text
                return generated_text[0]['generated_text'].replace("<s>", "").replace("[INST]", "").replace("[/INST]", "")
            else:
                return None
        else:
            # Tokenizer or model was never loaded; signal the failure to the caller
            return None
    except Exception as ex:
        log_error(str(ex))
        return ""