# Scriptr-Gemma / app.py
import os

# Set the cache directory before importing transformers so HF_HOME is
# picked up when huggingface_hub resolves its cache paths at import time
os.environ['HF_HOME'] = '/app/cache'
hf_token = os.getenv('HF_TOKEN')

import re

import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
class StopWordCriteria(StoppingCriteria):
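    """Stop generation once the tail of the sequence matches stop_word's token ids."""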
def __init__(self, tokenizer, stop_word):
self.stop_word_id = tokenizer.encode(stop_word, add_special_tokens=False)
def __call__(self, input_ids, scores, **kwargs):
if len(input_ids[0]) >= len(self.stop_word_id) and input_ids[0][-len(self.stop_word_id):].tolist() == self.stop_word_id:
return True
return False
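# transformers invokes each StoppingCriteria after every decoding step with the
# tokens generated so far, so checking the tail of input_ids is sufficient.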
def load_model():
try:
# Ensure cache directory exists
cache_dir = '/app/cache'
os.makedirs(cache_dir, exist_ok=True)
# Check for HF token
if not hf_token:
st.warning("HuggingFace token not found. Some models may not be accessible.")
# Check CUDA availability
if torch.cuda.is_available():
device = torch.device("cuda")
st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
device = torch.device("cpu")
st.warning("CUDA is not available. Using CPU.")
# Fine-tuned model for generating scripts
model_name = "Sidharthan/gemma2_scripter"
try:
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
token=hf_token,
cache_dir=cache_dir
)
except Exception as e:
st.error(f"Error loading tokenizer: {str(e)}")
if "401" in str(e):
st.error("Authentication error. Please check your HuggingFace token.")
            raise
try:
# Load model with appropriate device settings
model = AutoPeftModelForCausalLM.from_pretrained(
model_name,
device_map=None, # We'll handle device placement manually
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
trust_remote_code=True,
low_cpu_mem_usage=True,
token=hf_token,
cache_dir=cache_dir
)
# Move model to device
model = model.to(device)
return model, tokenizer
except Exception as e:
st.error(f"Error loading model: {str(e)}")
if "401" in str(e):
st.error("Authentication error. Please check your HuggingFace token.")
elif "disk space" in str(e).lower():
st.error("Insufficient disk space in cache directory.")
            raise
except Exception as e:
st.error(f"General error during model loading: {str(e)}")
        raise
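# Note: AutoPeftModelForCausalLM loads the base Gemma weights and attaches the
# LoRA adapter on top. If latency matters, the adapter weights could be folded
# into the base model after loading with model.merge_and_unload() (peft API).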
def generate_script(tags, model, tokenizer, params):
device = next(model.parameters()).device
    # Build the prompt in the turn format the fine-tuned model expects:
    # a 'keywords' turn holding the tags, followed by an open 'script' turn
    prompt = f"<bos><start_of_turn>keywords\n{tags}<end_of_turn>\n<start_of_turn>script\n"
# Tokenize and move to device
inputs = tokenizer(prompt, return_tensors='pt')
inputs = {k: v.to(device) for k, v in inputs.items()}
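    # Stop generation when the literal word 'script' is produced again, which
    # in this prompt format signals the model opening a spurious new turn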
stop_word = 'script'
stopping_criteria = StoppingCriteriaList([StopWordCriteria(tokenizer, stop_word)])
try:
outputs = model.generate(
**inputs,
            max_length=params['max_length'],  # counts prompt + generated tokens
do_sample=True,
temperature=params['temperature'],
top_p=params['top_p'],
top_k=params['top_k'],
repetition_penalty=params['repetition_penalty'],
num_return_sequences=1,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
stopping_criteria=stopping_criteria
)
# Move outputs back to CPU for decoding
outputs = outputs.cpu()
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Clean up response
response = re.sub(r'keywords\s.*?script\s', '', response, flags=re.DOTALL)
response = re.sub(r'\bscript\b.*$', '', response, flags=re.IGNORECASE).strip()
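        # Illustrative (assumed) raw decode before cleanup:
        #   "keywords\ntech, AI\nscript\n<generated text> ... script"
        # The first sub removes the echoed keyword turn; the second drops the
        # trailing stop word left behind by the stopping criterion.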
return response
except RuntimeError as e:
if "out of memory" in str(e):
st.error("GPU out of memory error. Try reducing max_length or using CPU.")
return "Error: GPU out of memory"
else:
st.error(f"Error during generation: {str(e)}")
return f"Error during generation: {str(e)}"
def main():
st.title("🎥 YouTube Script Generator")
# Sidebar for model parameters
st.sidebar.title("Generation Parameters")
params = {
'max_length': st.sidebar.slider('Max Length', 64, 1024, 512),
'temperature': st.sidebar.slider('Temperature', 0.1, 1.0, 0.7),
'top_p': st.sidebar.slider('Top P', 0.1, 1.0, 0.95),
'top_k': st.sidebar.slider('Top K', 1, 100, 50),
'repetition_penalty': st.sidebar.slider('Repetition Penalty', 1.0, 2.0, 1.2)
}
    # Cache the model and tokenizer with st.cache_resource so the (expensive)
    # load happens once per process rather than on every Streamlit rerun
    @st.cache_resource
    def get_model():
        return load_model()
try:
model, tokenizer = get_model()
# Tag input section
st.markdown("### Add Tags")
st.markdown("Enter tags separated by commas to generate a YouTube script")
# Create columns for tag input and generate button
col1, col2 = st.columns([3, 1])
with col1:
tags = st.text_input("Enter tags", placeholder="tech, AI, future, innovations...")
with col2:
generate_button = st.button("Generate Script", type="primary")
# Generated script section
if generate_button and tags:
st.markdown("### Generated Script")
with st.spinner("Generating script..."):
script = generate_script(tags, model, tokenizer, params)
st.text_area("Your script:", value=script, height=400)
# Add download button
st.download_button(
label="Download Script",
data=script,
file_name="youtube_script.txt",
mime="text/plain"
)
elif generate_button and not tags:
st.warning("Please enter some tags first!")
except Exception as e:
st.error("Failed to initialize the application. Please check the logs for details.")
st.error(f"Error: {str(e)}")
if __name__ == "__main__":
main()