dragoman / app.py
Yurii Paniv
Handle new lines
a94d4e4
raw
history blame
1.76 kB
import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import MistralForCausalLM, TextIteratorStreamer, AutoTokenizer, BitsAndBytesConfig
from time import sleep
from threading import Thread
from torch import float16
import spaces
# Adapter configuration for the "lang-uk/dragoman" PEFT checkpoint.
# NOTE(review): `config` is not read again in the visible part of this file.
config = PeftConfig.from_pretrained("lang-uk/dragoman")

# bitsandbytes settings: 4-bit NF4 weights, float16 compute, no double quantization.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=float16,
    bnb_4bit_use_double_quant=False,
)

# Load the quantized base model first, wrap it with the adapter, then place it on the GPU.
# NOTE: device_map="auto" is left disabled here; placement is done by the explicit
# .to("cuda") call below.
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=quant_config,
)
model = PeftModel.from_pretrained(model, "lang-uk/dragoman")
model = model.to("cuda")

# Slow (non-fast) tokenizer of the base model, configured not to prepend a BOS token.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    use_fast=False,
    add_bos_token=False,
)
# NOTE(review): presumably requests a GPU slot of up to 30 s per call on
# Hugging Face ZeroGPU -- confirm against the `spaces` package documentation.
@spaces.GPU(duration=30)
def translate(input_text):
    """Translate ``input_text`` line by line, streaming the growing result.

    Every non-blank line is wrapped as ``[INST] <line> [/INST]`` and handed to
    ``model.generate`` on a worker thread; decoded pieces are consumed from a
    ``TextIteratorStreamer`` as they are produced. Blank (or whitespace-only)
    lines are never sent to the model and are reproduced as empty lines.

    Args:
        input_text: Source text, possibly spanning several lines. ``None`` is
            treated like an empty string.

    Yields:
        The complete output accumulated so far (not only the newest piece).
        Each processed input line, blank or not, ends with a newline.
    """
    generated_text = ""
    for line in (input_text or "").strip().split("\n"):
        # Strip per line so that "\r\n" line endings and whitespace-only lines
        # do not leak into the prompt as stray characters or empty requests.
        line = line.strip()
        if not line:
            generated_text += "\n"
            yield generated_text
            continue

        prompt = f"[INST] {line} [/INST]"
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200)

        # generate() blocks until it is done, so it runs on a separate thread
        # while this generator forwards the streamed text to the caller.
        worker = Thread(target=model.generate, kwargs=generation_kwargs)
        worker.start()
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
        # The streamer is drained; wait for generate() itself to return so the
        # next line never starts a second generate() call on the same model
        # while the previous one is still finishing.
        worker.join()

        generated_text += "\n"
        yield generated_text
# Text-in / text-out Gradio UI around `translate`, with a single example prompt.
iface = gr.Interface(
    fn=translate,
    inputs="text",
    outputs="text",
    examples=[["who holds this neighborhood?"]],
)
iface.launch()