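# Gradio demo: loads the Unsloth 4-bit model, translates the user's prompt to English,
# streams a generated reply, and translates it back to the detected prompt language.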
from threading import Thread
from transformers import TextStreamer, TextIteratorStreamer
from unsloth import FastLanguageModel
import torch
import gradio as gr

max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.

model_name = "Danielrahmai1991/llama32_ganjoor_adapt_basic_model_16bit_v1"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    trust_remote_code = True,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
FastLanguageModel.for_inference(model)
print("model loaded")
import re
from deep_translator import (GoogleTranslator,
                             PonsTranslator,
                             LingueeTranslator,
                             MyMemoryTranslator,
                             YandexTranslator,
                             DeeplTranslator,
                             QcriTranslator,
                             single_detection,
                             batch_detection)

# from pyaspeller import YandexSpeller

# def error_correct_pyspeller(sample_text):
#     """Grammar correction of the input text."""
#     speller = YandexSpeller()
#     fixed = speller.spelled(sample_text)
#     return fixed

# def postprocessing(inp_text: str):
#     """Post-processing of the LLM response."""
#     inp_text = re.sub('<[^>]+>', '', inp_text)
#     inp_text = inp_text.split('##', 1)[0]
#     inp_text = error_correct_pyspeller(inp_text)
#     return inp_text

# streamer = TextStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
messages = []
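
# generate_text keeps the running chat history in the global `messages` list:
# it detects the prompt's language, translates the prompt to English for the model,
# streams the reply, and hands translated partial outputs back to the Gradio UI.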
def generate_text(prompt, max_length, top_p, top_k):
    global messages
    lang = single_detection(prompt, api_key='4ab77f25578d450f0902fb42c66d5e11')
    # if lang == 'en':
    #     prompt = error_correct_pyspeller(prompt)
    en_translated = GoogleTranslator(source='auto', target='en').translate(prompt)
    messages.append({"role": "user", "content": en_translated})
    # messages.append({"role": "user", "content": prompt})

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True,
        return_tensors = "pt",
    ).to(model.device)  # keep the prompt tensor on the same device as the model
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        max_new_tokens=int(max_length), top_p=float(top_p), do_sample=True,
        top_k=int(top_k), streamer=streamer, temperature=0.6, repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )
    # _ = model.generate(input_ids, streamer = streamer, max_new_tokens = int(max_length), pad_token_id = tokenizer.eos_token_id,
    #                    temperature=0.6,       # Adjust this value
    #                    top_k=int(top_k),      # Adjust this value
    #                    top_p=float(top_p),    # Adjust this value
    #                    repetition_penalty=1.2
    #                    )

    # run generation in a background thread so the streamer can be consumed as tokens arrive
    t = Thread(target=model.generate, args=(input_ids,), kwargs=generate_kwargs)
    t.start()
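
    # consume the streamer incrementally: append each chunk, translate the running
    # text back to the detected language, and yield it so the UI updates live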
    generated_text = []
    for text in streamer:
        generated_text.append(text)
        # print(generated_text)
        # yield "".join(generated_text)
        yield GoogleTranslator(source='auto', target=lang).translate("".join(generated_text))
    messages.append({"role": "assistant", "content": "".join(generated_text)})
description = """
# Deploy our LLM
"""

inputs = [
    gr.Textbox(label="Prompt text", lines=5),
    gr.Textbox(label="max-length generation", value=100),
    gr.Slider(0.0, 1.0, label="top-p value", value=0.95),
    gr.Textbox(label="top-k", value=50),
]
outputs = [gr.Textbox(label="Generated Text", lines=10)]

demo = gr.Interface(fn=generate_text, inputs=inputs, outputs=outputs, description=description)
demo.launch(debug=True, share=True)