from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from src.classmodels.inputforgeneration import InputForGeneration
from errorlog.errorlog import log_error
from pathlib import Path

# Model name as on Hugging Face: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
current_folderpath = Path(__file__).resolve().parent

tokenizer = None
quantised_model = None
tokenizer_path = current_folderpath / "tokenizer"
model_path = current_folderpath / "model"

additional_kwargs = {
    "do_sample": True,
    "early_stopping": True,
    "num_beams": 5,
    "no_repeat_ngram_size": 5,
    "truncation": True
}

TASK_NAME = "text-generation"

def isModelAvailable():
    # The local model folder must actually exist on disk,
    # not merely be a non-empty path string.
    return model_path.is_dir()

def isTokenizerAvailable():
    # The local tokenizer folder must actually exist on disk.
    return tokenizer_path.is_dir()

def warmupTextGenerationModel():
    # Load the tokenizer and model once and cache them at module level
    # (hence the global statement) so generateText() can reuse them.
    global tokenizer, quantised_model
    try:
        if isModelAvailable() and isTokenizerAvailable():
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            quantised_model = AutoModelForCausalLM.from_pretrained(model_path, use_safetensors=True)
            return "text generation model is warmed up"
        else:
            return "No model/tokenizer folder found..."
    except Exception as ex:
        log_error(str(ex))
        return "Issue occurred when warming up the text generation model. Please try again."

def generateText(inputSettings: InputForGeneration):
    try:
        if tokenizer is not None and quantised_model is not None:
            pipe = pipeline(task=TASK_NAME, model=quantised_model, tokenizer=tokenizer, device_map="auto")
            # Prompt formatted for the Llama instruction template
            prompt = f"[INST] {inputSettings.input_for_generation} [/INST]"
            generated_text = pipe(prompt,
                                  temperature=inputSettings.temperature,
                                  max_length=inputSettings.max_length,
                                  **additional_kwargs)
            if generated_text is not None and generated_text[0]["generated_text"] is not None:
                # Strip the special tokens and the instruction markers from the output
                return (generated_text[0]["generated_text"]
                        .replace("<s>", "")
                        .replace("</s>", "")
                        .replace("[INST]", "")
                        .replace("[/INST]", ""))
            else:
                return None
        else:
            # Tokenizer or model was never loaded; notify as an issue in generation.
            return None
    except Exception as ex:
        log_error(str(ex))
        return ""
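
# --- Usage sketch, not part of the module's public surface. Assumptions:
# InputForGeneration can be constructed with keyword arguments named
# input_for_generation, temperature, and max_length (matching the attribute
# accesses in generateText above; the real class may differ), and the "model"
# and "tokenizer" folders sit next to this file as expected by the
# availability checks. Run warm-up first so the module-level tokenizer and
# model are populated before generateText is called.
if __name__ == "__main__":
    print(warmupTextGenerationModel())
    settings = InputForGeneration(
        input_for_generation="Explain what model warm-up means in one sentence.",
        temperature=0.7,
        max_length=256,
    )
    print(generateText(settings))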