Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM | |
import matplotlib.pyplot as plt | |
# Load the sentiment classifier and the GPT-Neo text generator once, at import
# time, so every request reuses the same in-memory models.
SENTIMENT_CHECKPOINT = "cardiffnlp/twitter-roberta-base-sentiment"
GENERATION_CHECKPOINT = "EleutherAI/gpt-neo-2.7B"

tokenizer_sentiment = AutoTokenizer.from_pretrained(SENTIMENT_CHECKPOINT)
model_sentiment = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_CHECKPOINT)
sentiment_classifier = pipeline(
    "sentiment-analysis",
    model=model_sentiment,
    tokenizer=tokenizer_sentiment,
)

tokenizer_gpt = AutoTokenizer.from_pretrained(GENERATION_CHECKPOINT)
model_gpt = AutoModelForCausalLM.from_pretrained(GENERATION_CHECKPOINT)
# Updated generate_text function with padding token fix | |
def generate_text(prompt): | |
# Add '[PAD]' as a padding token if not set | |
if tokenizer_gpt.pad_token is None: | |
tokenizer_gpt.add_special_tokens({'pad_token': '[PAD]'}) | |
# Prepare the input tensors with attention_mask and padding | |
inputs = tokenizer_gpt(prompt, return_tensors="pt", padding=True, truncation=True, max_length=50) | |
# Generate text using max_new_tokens instead of max_length | |
outputs = model_gpt.generate( | |
inputs["input_ids"], | |
attention_mask=inputs["attention_mask"], | |
max_new_tokens=50, | |
num_return_sequences=1, | |
no_repeat_ngram_size=2, | |
pad_token_id=tokenizer_gpt.pad_token_id # Ensure pad_token_id is set | |
) | |
generated_text = tokenizer_gpt.decode(outputs[0], skip_special_tokens=True) | |
return generated_text.strip() | |
# Function to analyze each line of the script | |
def analyze_script(script): | |
global all_scores, descriptions, music_cues | |
lines = script.strip().split("\n") | |
all_scores = [] | |
descriptions = [] | |
music_cues = [] | |
analysis_results = [] | |
for i, line in enumerate(lines): | |
result = sentiment_classifier(line)[0] | |
sentiment = result['label'] | |
score = result['score'] | |
description_prompt = f"Describe a scene with the sentiment '{sentiment}' for the line: '{line}'" | |
description = generate_text(description_prompt) | |
music_cue_prompt = f"Suggest music elements (like tempo, key, and instrumentation) that would fit a scene with the sentiment '{sentiment}': '{line}'" | |
music_cue = generate_text(music_cue_prompt) | |
all_scores.append(score) | |
descriptions.append(description) | |
music_cues.append(music_cue) | |
analysis_results.append( | |
{ | |
"Line": f"Line {i + 1}: {line}", | |
"Sentiment": f"{sentiment} (Score: {round(score, 2)})", | |
"Description Suggestion": description, | |
"Music Cue": music_cue | |
} | |
) | |
graph_path = generate_script_graph() | |
return analysis_results, graph_path | |
# Generate the emotional arc graph with music cues for the entire script | |
def generate_script_graph(): | |
plt.figure(figsize=(12, 6)) | |
plt.plot(all_scores, marker='o', linestyle='-', color='b', label='Sentiment Intensity') | |
for i, score in enumerate(all_scores): | |
plt.text(i, score, music_cues[i], fontsize=8, ha='right', rotation=45) | |
plt.title('Emotional and Musical Arc for Entire Script') | |
plt.xlabel('Script Lines (Accumulative)') | |
plt.ylabel('Sentiment Intensity') | |
plt.legend() | |
plt.tight_layout() | |
plot_path = "script_emotional_arc.png" | |
plt.savefig(plot_path) | |
plt.close() | |
return plot_path | |
# Custom Gradio component to display dashboard results with icons | |
def format_dashboard(results): | |
formatted_results = "" | |
for result in results: | |
formatted_results += f""" | |
<div class="dashboard-box"> | |
<p><img src="https://geist-ui.dev/icons/activity.svg" class="dashboard-icon" alt="Line icon"> <strong>{result['Line']}</strong></p> | |
<p><img src="https://geist-ui.dev/icons/bar-chart.svg" class="dashboard-icon" alt="Sentiment icon"> <strong>Sentiment:</strong> {result['Sentiment']}</p> | |
<p><img src="https://geist-ui.dev/icons/eye.svg" class="dashboard-icon" alt="Description icon"> <strong>Description Suggestion:</strong> {result['Description Suggestion']}</p> | |
<p><img src="https://geist-ui.dev/icons/music.svg" class="dashboard-icon" alt="Music Cue icon"> <strong>Music Cue:</strong> {result['Music Cue']}</p> | |
</div> | |
""" | |
return formatted_results | |
# Gradio interface to analyze script and display the dashboard | |
with gr.Blocks(css="custom.css") as interface: | |
gr.Markdown("## Script Sentiment and Music Cue Analyzer", elem_id="title") | |
gr.Markdown("Enter your script line-by-line, and this tool will analyze sentiment, generate scene descriptions, suggest music cues, and show an emotional and musical arc.", elem_id="description") | |
script_input = gr.Textbox(lines=10, placeholder="Enter your script here, one line per thought or dialogue.", label="Script") | |
display_dashboard_button = gr.Button("Analyze Script") | |
output_dashboard = gr.HTML(label="Dashboard Results") | |
output_graph = gr.Image(label="Emotional and Musical Arc for Entire Script") | |
def display_dashboard(script): | |
analysis_results, graph_path = analyze_script(script) | |
dashboard_content = format_dashboard(analysis_results) | |
return dashboard_content, graph_path | |
display_dashboard_button.click(display_dashboard, inputs=script_input, outputs=[output_dashboard, output_graph]) | |
interface.launch() | |