import torch
import streamlit as st
from transformers import AutoTokenizer, OPTForCausalLM


@st.cache_resource
def load_model():
    # Cache the tokenizer and model across Streamlit reruns so the 30B
    # checkpoint is loaded only once per server process.
    tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-30b")
    model = OPTForCausalLM.from_pretrained(
        "facebook/galactica-30b",
        device_map="auto",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
    # Gradient checkpointing is left off: it is a training-time memory
    # optimization and would only slow generation here.
    return tokenizer, model
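# Note: galactica-30b is roughly 60 GB of weights in float16; device_map="auto"
# (backed by accelerate) shards the layers across available GPUs and spills to
# CPU RAM if needed. For a quick smoke test, a smaller checkpoint such as
# "facebook/galactica-125m" should work as a drop-in substitute.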
|
|
|
|
|
st.set_page_config(
    page_title='BioML-SVM',
    layout="wide"
)

with st.spinner("Loading Models and Tokens..."):
    tokenizer, model = load_model()
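# Widgets inside st.form are batched: the script reruns only when the submit
# button is pressed, not on every keystroke, which matters when each rerun
# could trigger a 30B-parameter generation.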
|
|
|
with st.form(key='my_form'):
    col1, col2 = st.columns([10, 1])
    text_input = col1.text_input(label='Enter the amino acid sequence')
    with col2:
        st.text('')  # blank spacers so the button lines up with the input box
        st.text('')
        submit_button = st.form_submit_button(label='Submit')
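# Hedged note: the Galactica model card wraps protein sequences in special
# [START_AMINO] ... [END_AMINO] tokens. If bare sequences give weak output,
# wrapping the input that way may help; the handler below keeps the original
# bare-string prompt.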
|
|
|
if submit_button:
    st.session_state['result_done'] = False

    with st.spinner('Generating...'):
        formatted_text = f"{text_input}"
        # With device_map="auto" the model may span devices; model.device
        # points at the first shard, which is where the inputs should go.
        input_ids = tokenizer(formatted_text, return_tensors="pt").input_ids.to(model.device)
        # Greedy decoding, capped at 500 new tokens.
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=500
        )
        # Decode only the newly generated tokens; slicing by token count is
        # more robust than stripping the prompt text with str.replace().
        result = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
        st.markdown(result)
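    # Persist the answer so it can be re-rendered after the rerun triggered
    # by the follow-up form below.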
|
|
|
    if 'result_done' not in st.session_state or not st.session_state.result_done:
        st.session_state['result_done'] = True
        st.session_state['previous_state'] = result
else:
    if 'result_done' in st.session_state and st.session_state.result_done:
        st.markdown(st.session_state.previous_state)
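# Once an answer exists, offer a follow-up question form. Submitting it reruns
# the script from the top, which is why the first answer is replayed from
# st.session_state above.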
|
|
|
if 'result_done' in st.session_state and st.session_state.result_done:
    with st.form(key='ask_more'):
        col1, col2 = st.columns([10, 1])
        text_input = col1.text_input(label='Ask another question')
        with col2:
            st.text('')  # spacers, as in the first form
            st.text('')
            submit_button = st.form_submit_button(label='Submit')
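    # The "Q: ... A:" template nudges the model into question-answering mode;
    # sampling with top_k gives more varied follow-ups than the greedy
    # decoding used for the first pass.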
|
|
|
    if submit_button:
        with st.spinner('Generating...'):
            formatted_text = f"Q:{text_input}\n\nA:\n\n"
            input_ids = tokenizer(formatted_text, return_tensors="pt").input_ids.to(model.device)
            # Sampled decoding (top-k), capped at 500 new tokens.
            outputs = model.generate(
                input_ids=input_ids,
                max_new_tokens=500,
                do_sample=True,
                top_k=40,
                num_beams=1,
                num_return_sequences=1
            )
            result = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
            st.markdown(result)