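# Copied from the Hugging Face file view; the page header is kept as a comment
# so the module parses.
# Author: Mxode · commit 482a4be "Update app.py" (verified) · 3.1 kB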
import streamlit as st
from transformers import (
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
AutoModelForCausalLM,
)
# Sidebar display name -> model id handed to `from_pretrained` by load_model.
# Every checkpoint lives under the "Mxode" namespace with a repo name equal to
# its display name, so the id is derived instead of being spelled out twice.
# Insertion order is the order the models appear in the sidebar selectbox.
model_dict = {
    name: f"Mxode/{name}"
    for name in (
        "NanoTranslator-XS",
        "NanoTranslator-S",
        "NanoTranslator-M",
        "NanoTranslator-M2",
        "NanoTranslator-L",
        "NanoTranslator-XL",
        "NanoTranslator-XXL",
        "NanoTranslator-XXL2",
    )
}
# initialize model
@st.cache_resource
def load_model(model_path: str):
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)
return model, tokenizer
def translate(text: str, model, tokenizer: PreTrainedTokenizerBase, **kwargs):
    """Translate ``text`` with a NanoTranslator causal LM.

    The source text is wrapped as ``<|im_start|>{text}<|endoftext|>`` and the
    model continues from there; only the newly generated tokens are decoded.

    Args:
        text: Source text to translate.
        model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Tokenizer matching ``model``.
        **kwargs: Extra ``model.generate`` arguments. ``max_new_tokens`` (64),
            ``do_sample`` (True), ``temperature`` (0.55), ``top_p`` (0.8) and
            ``top_k`` (40) receive the defaults shown.

    Returns:
        The decoded continuation of the prompt, special tokens removed.
    """
    do_sample = kwargs.pop("do_sample", True)
    max_new_tokens = kwargs.pop("max_new_tokens", 64)
    # Always pop the sampling knobs so they cannot leak through **kwargs, but
    # forward them only when sampling: greedy decoding ignores them, and
    # recent transformers versions warn when they are set with do_sample=False.
    sampling_args = {
        "temperature": kwargs.pop("temperature", 0.55),
        "top_p": kwargs.pop("top_p", 0.8),
        "top_k": kwargs.pop("top_k", 40),
    }
    generation_args = dict(
        max_new_tokens=max_new_tokens, do_sample=do_sample, **kwargs
    )
    if do_sample:
        generation_args.update(sampling_args)

    prompt = "<|im_start|>" + text + "<|endoftext|>"
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    input_ids = model_inputs["input_ids"]
    # Hand generate() the tokenizer's attention mask rather than letting it
    # infer one (which warns and may mask tokens equal to the pad token).
    # setdefault keeps a caller-supplied attention_mask working as before.
    generation_args.setdefault("attention_mask", model_inputs.get("attention_mask"))

    generated_ids = model.generate(input_ids, **generation_args)
    # generate() returns prompt + continuation; keep only the new tokens.
    new_token_ids = [
        output_ids[len(prompt_ids):]
        for prompt_ids, output_ids in zip(input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]
st.title("NanoTranslator Demo")

# Sidebar: model choice and generation settings forwarded to translate().
st.sidebar.title("Options")
model_names = list(model_dict)
model_choice = st.sidebar.selectbox(
    "Model", model_names, index=model_names.index("NanoTranslator-XXL2")
)
do_sample = st.sidebar.checkbox("do_sample", value=True)
max_new_tokens = st.sidebar.slider(
    "max_new_tokens", min_value=1, max_value=256, value=64
)
# temperature / top_p / top_k only influence sampling, so grey them out when
# do_sample is unchecked instead of offering knobs that have no effect.
# Disabled sliders still return their value, so the variables stay defined.
temperature = st.sidebar.slider(
    "temperature",
    min_value=0.01,
    max_value=1.5,
    value=0.55,
    step=0.01,
    disabled=not do_sample,
)
top_p = st.sidebar.slider(
    "top_p",
    min_value=0.01,
    max_value=1.0,
    value=0.8,
    step=0.01,
    disabled=not do_sample,
)
top_k = st.sidebar.slider(
    "top_k",
    min_value=1,
    max_value=100,
    value=40,
    step=1,
    disabled=not do_sample,
)
# Load the model selected in the sidebar (load_model caches it across reruns).
model_path = model_dict[model_choice]
model, tokenizer = load_model(model_path)

input_text = st.text_area(
    "Please input the text to be translated (Currently supports only English to Chinese):",
    "Each step of the cell cycle is monitored by internal.",
)

if st.button("translate"):
    # Translate exactly the text the emptiness check validated, so leading or
    # trailing whitespace/newlines from the text area never reach the prompt.
    source_text = input_text.strip()
    if source_text:
        with st.spinner("Translating..."):
            translation = translate(
                source_text,
                model,
                tokenizer,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
        # Shown after the spinner closes.
        st.success("Translated successfully!")
        st.write(translation)
    else:
        st.warning("Please input text before translation!")