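"""Streamlit app that turns a comma-separated list of tags into a YouTube
script, using the LoRA fine-tuned model Sidharthan/gemma2_scripter."""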
import os
import re

import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList

# Set cache directory and get token
os.environ['HF_HOME'] = '/app/cache'
hf_token = os.getenv('HF_TOKEN')


class StopWordCriteria(StoppingCriteria):
    """Stops generation once the tail of the sequence matches a given stop word."""

    def __init__(self, tokenizer, stop_word):
        self.stop_word_id = tokenizer.encode(stop_word, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        # True when the most recent tokens equal the stop word's token ids
        if len(input_ids[0]) >= len(self.stop_word_id) and \
                input_ids[0][-len(self.stop_word_id):].tolist() == self.stop_word_id:
            return True
        return False


def load_model():
    try:
        # Ensure cache directory exists
        cache_dir = '/app/cache'
        os.makedirs(cache_dir, exist_ok=True)

        # Check for HF token
        if not hf_token:
            st.warning("HuggingFace token not found. Some models may not be accessible.")

        # Check CUDA availability
        if torch.cuda.is_available():
            device = torch.device("cuda")
            st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            device = torch.device("cpu")
            st.warning("CUDA is not available. Using CPU.")

        # Fine-tuned model for generating scripts
        model_name = "Sidharthan/gemma2_scripter"

        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                token=hf_token,
                cache_dir=cache_dir
            )
        except Exception as e:
            st.error(f"Error loading tokenizer: {str(e)}")
            if "401" in str(e):
                st.error("Authentication error. Please check your HuggingFace token.")
            raise

        try:
            # Load model with appropriate device settings
            model = AutoPeftModelForCausalLM.from_pretrained(
                model_name,
                device_map=None,  # Device placement is handled manually below
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                token=hf_token,
                cache_dir=cache_dir
            )
            # Move model to the selected device
            model = model.to(device)
            return model, tokenizer
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            if "401" in str(e):
                st.error("Authentication error. Please check your HuggingFace token.")
            elif "disk space" in str(e).lower():
                st.error("Insufficient disk space in cache directory.")
            raise

    except Exception as e:
        st.error(f"General error during model loading: {str(e)}")
        raise


def generate_script(tags, model, tokenizer, params):
    device = next(model.parameters()).device

    # Create prompt with tags
    prompt = f"<bos><start_of_turn>keywords\n{tags}<end_of_turn>\n<start_of_turn>script\n"

    # Tokenize and move to device. The prompt already contains a literal <bos>,
    # so skip the tokenizer's automatic special tokens to avoid a duplicated BOS.
    inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Stop as soon as the model tries to open another 'script' turn
    stop_word = 'script'
    stopping_criteria = StoppingCriteriaList([StopWordCriteria(tokenizer, stop_word)])

    try:
        outputs = model.generate(
            **inputs,
            max_length=params['max_length'],
            do_sample=True,
            temperature=params['temperature'],
            top_p=params['top_p'],
            top_k=params['top_k'],
            repetition_penalty=params['repetition_penalty'],
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria
        )

        # Move outputs back to CPU for decoding
        outputs = outputs.cpu()
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Clean up response: drop the echoed prompt and any trailing stop word
        response = re.sub(r'keywords\s.*?script\s', '', response, flags=re.DOTALL)
        response = re.sub(r'\bscript\b.*$', '', response, flags=re.IGNORECASE).strip()

        return response
    except RuntimeError as e:
        if "out of memory" in str(e):
            st.error("GPU out of memory error. Try reducing max_length or using CPU.")
            return "Error: GPU out of memory"
        else:
            st.error(f"Error during generation: {str(e)}")
            return f"Error during generation: {str(e)}"


def main():
    st.title("🎥 YouTube Script Generator")

    # Sidebar for model parameters
    st.sidebar.title("Generation Parameters")
    params = {
        'max_length': st.sidebar.slider('Max Length', 64, 1024, 512),
        'temperature': st.sidebar.slider('Temperature', 0.1, 1.0, 0.7),
        'top_p': st.sidebar.slider('Top P', 0.1, 1.0, 0.95),
        'top_k': st.sidebar.slider('Top K', 1, 100, 50),
        'repetition_penalty': st.sidebar.slider('Repetition Penalty', 1.0, 2.0, 1.2)
    }

    # Load model and tokenizer, cached so they are loaded only once
    # rather than on every Streamlit rerun
    @st.cache_resource
    def get_model():
        return load_model()

    try:
        model, tokenizer = get_model()

        # Tag input section
        st.markdown("### Add Tags")
        st.markdown("Enter tags separated by commas to generate a YouTube script")

        # Create columns for tag input and generate button
        col1, col2 = st.columns([3, 1])
        with col1:
            tags = st.text_input("Enter tags", placeholder="tech, AI, future, innovations...")
        with col2:
            generate_button = st.button("Generate Script", type="primary")

        # Generated script section
        if generate_button and tags:
            st.markdown("### Generated Script")
            with st.spinner("Generating script..."):
                script = generate_script(tags, model, tokenizer, params)
            st.text_area("Your script:", value=script, height=400)

            # Add download button
            st.download_button(
                label="Download Script",
                data=script,
                file_name="youtube_script.txt",
                mime="text/plain"
            )
        elif generate_button and not tags:
            st.warning("Please enter some tags first!")

    except Exception as e:
        st.error("Failed to initialize the application. Please check the logs for details.")
        st.error(f"Error: {str(e)}")
if __name__ == "__main__":
    main()