Spaces:
Runtime error
Runtime error
# 1. Configure logging first
import logging | |
# Root logging at INFO so the progress messages emitted by this module are shown.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(name=__name__)
# 2. Remaining imports
import os | |
import time | |
from datetime import datetime | |
import gradio as gr | |
import torch | |
import requests | |
from pathlib import Path | |
import cv2 | |
from PIL import Image | |
import json | |
import spaces | |
import torchaudio | |
import tempfile | |
try:
    import mmaudio
except ImportError:
    # mmaudio is not importable yet: install the package in the current directory
    # in editable mode, then retry the import.
    # Use the running interpreter's pip (os.system("pip ...") may target a different
    # Python) and fail loudly if the install fails instead of ignoring its exit code.
    import importlib
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])
    # Let the import system notice modules that appeared after start-up.
    importlib.invalidate_caches()
    import mmaudio
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video, | |
setup_eval_logging) | |
from mmaudio.model.flow_matching import FlowMatching | |
from mmaudio.model.networks import MMAudio, get_my_mmaudio | |
from mmaudio.model.sequence_config import SequenceConfig | |
from mmaudio.model.utils.features_utils import FeaturesUtils | |
# Translation model import (Korean-to-English translator for user prompts)
from transformers import pipeline | |
# Korean -> English translator; process_and_generate_video uses it on Korean prompts.
translator = pipeline(task="translation", model="Helsinki-NLP/opus-mt-ko-en")
# 3. API settings
# userhash sent with every catbox.moe upload. It is a credential, so prefer supplying it
# through the CATBOX_USER_HASH environment variable; the literal remains only as a
# fallback so existing deployments keep working.
CATBOX_USER_HASH = os.getenv("CATBOX_USER_HASH", "30f52c895fd9d9cb387eee489")
# NOTE(review): not referenced by the code in this file. generate_video reads the same
# API_KEY variable itself and sends it to the MiniMax API, so the "REPLICATE" name
# looks stale.
REPLICATE_API_TOKEN = os.getenv("API_KEY")
# 4. Audio model settings
# Target device and precision for the MMAudio network and feature extractors.
device, dtype = 'cuda', torch.bfloat16
# 5. Define get_model
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the MMAudio network and feature extractors from the global ``model`` config.

    Both modules are moved to the module-level ``device``/``dtype`` and switched
    to eval mode; the network's weights are loaded from ``model.model_path``.

    Returns:
        ``(net, feature_utils, seq_cfg)``, where ``seq_cfg`` is ``model.seq_cfg``.
    """
    cfg = model
    sequence_config = cfg.seq_cfg

    network: MMAudio = get_my_mmaudio(cfg.model_name).to(device, dtype).eval()
    checkpoint = torch.load(cfg.model_path, map_location=device, weights_only=True)
    network.load_weights(checkpoint)
    logger.info(f'Loaded weights from {cfg.model_path}')

    extractors = FeaturesUtils(
        tod_vae_ckpt=cfg.vae_path,
        synchformer_ckpt=cfg.synchformer_ckpt,
        enable_conditions=True,
        mode=cfg.mode,
        bigvgan_vocoder_ckpt=cfg.bigvgan_16k_path,
        need_vae_encoder=False,
    )
    extractors = extractors.to(device, dtype).eval()

    return network, extractors, sequence_config
# 6. Model initialisation
# Which entry of mmaudio's config registry to load.
_MODEL_VARIANT = 'large_44k_v2'

output_dir = Path('./output/gradio')  # NOTE(review): not referenced elsewhere in this file

model: ModelConfig = all_model_cfg[_MODEL_VARIANT]
model.download_if_needed()  # presumably fetches any missing checkpoint files
setup_eval_logging()
# get_model() reads the global `model` assigned above, so it must run after it.
net, feature_utils, seq_cfg = get_model()
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
                   seed: int = -1, num_steps: int = 15,
                   cfg_strength: float = 4.0, target_duration: float = 4.0):
    """Generate a soundtrack for a video with MMAudio and mux it into a new file.

    Best-effort: any failure is logged and the original ``video_path`` is
    returned, so callers detect success by comparing the result to their input.

    Args:
        video_path: Path of the input video.
        prompt: Text describing the desired audio.
        negative_prompt: Text describing audio to steer away from.
        seed: RNG seed; a negative value draws a fresh random seed.
        num_steps: Number of Euler steps for the flow-matching sampler.
        cfg_strength: Classifier-free guidance strength passed to ``generate``.
        target_duration: Seconds of video to load; the duration actually used is
            whatever ``load_video`` reports (it may be shorter).

    Returns:
        Path of a new temporary ``.mp4`` containing the audio, or ``video_path``
        on failure.
    """
    try:
        logger.info("Starting audio generation process")
        torch.cuda.empty_cache()

        rng = torch.Generator(device=device)
        if seed >= 0:
            rng.manual_seed(seed)
        else:
            rng.seed()
        fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

        video_info = load_video(video_path, duration_sec=target_duration)
        if video_info is None:
            logger.error("Failed to load video")
            return video_path
        clip_frames = video_info.clip_frames
        sync_frames = video_info.sync_frames
        actual_duration = video_info.duration_sec
        if clip_frames is None or sync_frames is None:
            logger.error("Failed to extract frames from video")
            return video_path

        # load_video already receives the duration and (per upstream MMAudio) trims both
        # frame streams to it. Those streams use MMAudio's own sampling rates, not the
        # source fps, so re-slicing them by `duration * video_info.fps` was at best a
        # no-op and could drop frames the sequence lengths below expect.
        # Cast to the model's dtype (bfloat16) instead of float16 to match the weights.
        clip_frames = clip_frames.unsqueeze(0).to(device, dtype=dtype)
        sync_frames = sync_frames.unsqueeze(0).to(device, dtype=dtype)

        # NOTE(review): mutates the shared global seq_cfg/net state per request.
        seq_cfg.duration = actual_duration
        net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

        logger.info("Generating audio...")
        # torch.cuda.amp.autocast() is deprecated and defaults to float16, which
        # disagrees with the bfloat16 model; autocast to the model dtype instead.
        with torch.autocast(device_type=device, dtype=dtype):
            audios = generate(clip_frames,
                              sync_frames,
                              [prompt],
                              negative_text=[negative_prompt],
                              feature_utils=feature_utils,
                              net=net,
                              fm=fm,
                              rng=rng,
                              cfg_strength=cfg_strength)
        if audios is None:
            logger.error("Failed to generate audio")
            return video_path
        audio = audios.float().cpu()[0]

        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
            output_path = tmp.name
        logger.info(f"Creating final video with audio at {output_path}")
        make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)

        # NamedTemporaryFile already created the file, so existence alone proves
        # nothing; require make_video to have actually written data into it.
        if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
            logger.error("Failed to create output video")
            return video_path
        logger.info(f'Successfully saved video with audio to {output_path}')
        return output_path
    except Exception as e:
        logger.exception(f"Error in video_to_audio: {str(e)}")
        return video_path
    finally:
        torch.cuda.empty_cache()
def upload_to_catbox(file_path, timeout=60):
    """Upload a local file to catbox.moe and return its public URL.

    Files whose extension is not a known image type are first re-encoded as an
    RGB PNG next to the original (same name, ``.png`` suffix). That temporary
    PNG is deleted again once the upload attempt finishes.

    Args:
        file_path: Path of the file to upload.
        timeout: Seconds to wait for catbox.moe before giving up. Optional; before
            this parameter existed the request could hang indefinitely.

    Returns:
        The URL returned by catbox.moe, or ``None`` if conversion or upload
        failed (the reason is logged).
    """
    url = "https://catbox.moe/user/api.php"
    mime_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp',
        '.jfif': 'image/jpeg'
    }
    converted_path = None
    try:
        logger.info(f"Preparing to upload file: {file_path}")
        file_extension = Path(file_path).suffix.lower()
        if file_extension not in mime_types:
            try:
                with Image.open(file_path) as img:
                    rgb = img if img.mode == 'RGB' else img.convert('RGB')
                    # with_suffix replaces only the file's own suffix; rsplit('.') would
                    # cut at a dot in a directory name when the file has no extension.
                    converted_path = str(Path(file_path).with_suffix('.png'))
                    rgb.save(converted_path, 'PNG')
            except Exception as e:
                logger.error(f"Failed to convert image: {str(e)}")
                return None
            file_path = converted_path
            file_extension = '.png'
            logger.info(f"Converted image to PNG: {file_path}")

        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH
        }
        # Context manager so the upload handle is closed even if the request fails.
        with open(file_path, 'rb') as fh:
            files = {
                'fileToUpload': (
                    os.path.basename(file_path),
                    fh,
                    mime_types.get(file_extension, 'application/octet-stream')
                )
            }
            response = requests.post(url, files=files, data=data, timeout=timeout)

        file_url = response.text.strip()
        if response.status_code == 200 and file_url.startswith('http'):
            logger.info(f"File uploaded successfully: {file_url}")
            return file_url
        logger.error(f"File upload error: Upload failed: {response.text}")
        return None
    except Exception as e:
        logger.error(f"File upload error: {str(e)}")
        return None
    finally:
        # Remove the PNG we created (also if saving it only partially succeeded).
        if converted_path is not None and os.path.exists(converted_path):
            try:
                os.remove(converted_path)
            except OSError as e:
                logger.warning(f"Could not remove converted image {converted_path}: {e}")
def add_watermark(video_path):
    """Draw a "GiniGEN.AI" text watermark in the bottom-right corner of a video with OpenCV.

    Frames are re-encoded with the ``mp4v`` codec into ``<stem>_watermarked.mp4``
    next to the input. Only video frames are read and written, so any audio track
    in the input is not carried over.

    Args:
        video_path: Path of the input video.

    Returns:
        Path of the watermarked video, or ``video_path`` unchanged if the input
        cannot be read or the output cannot be written (the reason is logged).
    """
    src = Path(video_path)
    # Per-input output name: a fixed file in the CWD let concurrent calls overwrite
    # each other and could hand back a stale result from an earlier run.
    output_path = str(src.with_name(f"{src.stem}_watermarked.mp4"))
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Error adding watermark: cannot open {video_path}")
            return video_path

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional frame rates (e.g. 29.97); truncating to int makes the
        # re-encoded video run at a slightly different speed.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also rejects NaN
            logger.error(f"Error adding watermark: invalid FPS {fps} for {video_path}")
            return video_path

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30  # text size scales with frame height
        thickness = 2
        color = (255, 255, 255)
        (text_width, _), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        x_pos = max(0, width - text_width - margin)  # never start off-frame
        y_pos = height - margin  # putText's origin is the text baseline

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            logger.error(f"Error adding watermark: cannot write {output_path}")
            return video_path

        frames_written = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)
            frames_written += 1
        if frames_written == 0:
            logger.error(f"Error adding watermark: no frames read from {video_path}")
            return video_path
        return output_path
    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path
    finally:
        # Release on every path; the writer only finalizes the file on release().
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
# Base URL of the MiniMax video API used by generate_video.
_MINIMAX_API = "https://api.minimaxi.chat/v1"
# (connect, read) timeouts in seconds for every HTTP call below, so a stalled
# connection cannot block a request handler indefinitely.
_HTTP_TIMEOUT = (10, 120)


def _remove_quietly(path):
    """Delete ``path``; log a warning instead of raising if that fails."""
    try:
        os.remove(path)
    except OSError as e:
        logger.warning(f"Error cleaning up temporary files: {str(e)}")


def _finalize_video(output_path, prompt):
    """Watermark a downloaded video, then try to add generated audio to it.

    Intermediate files that have been superseded are deleted.

    Returns:
        Path of the watermarked video with audio, or of the watermarked video
        without generated audio if audio generation did not succeed.
    """
    final_path = add_watermark(output_path)
    if output_path != final_path:
        # The raw download is superseded by the watermarked copy. Previously it was
        # only deleted when audio generation succeeded, leaking it otherwise.
        _remove_quietly(output_path)
    try:
        logger.info("Starting audio generation process")
        final_path_with_audio = video_to_audio(
            final_path,
            prompt=prompt,
            negative_prompt="music",
            seed=-1,
            num_steps=20,
            cfg_strength=4.5,
            target_duration=6.0
        )
    except Exception as e:
        logger.error(f"Error in audio processing: {str(e)}")
        return final_path  # fall back to the watermarked video without generated audio
    if final_path_with_audio == final_path:
        # video_to_audio returns its input unchanged when it fails.
        logger.warning("Audio generation skipped, using original video")
        return final_path
    logger.info("Audio generation successful")
    _remove_quietly(final_path)
    return final_path_with_audio


def _download_result(status_data, auth_headers, image_url, prompt, temp_dir):
    """Retrieve and post-process the video of a task whose status is 'Success'.

    Returns:
        Final video path on success, otherwise a human-readable error string.
    """
    file_id = status_data.get('file_id')
    if not file_id:
        return "Failed to get file ID"
    file_response = requests.get(
        f"{_MINIMAX_API}/files/retrieve",
        headers=auth_headers,
        params={'file_id': file_id},
        timeout=_HTTP_TIMEOUT,
    )
    if not file_response.ok:
        return "Failed to retrieve video file"
    try:
        download_url = file_response.json().get('file', {}).get('download_url')
        if not download_url:
            return "Failed to get download URL"
        result_info = {
            "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
            "input_image": image_url,
            "output_video_url": download_url,
            "prompt": prompt
        }
        logger.info(f"Video generation result: {json.dumps(result_info, indent=2)}")

        video_response = requests.get(download_url, timeout=_HTTP_TIMEOUT)
        if not video_response.ok:
            return "Failed to download video"
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
        with open(output_path, 'wb') as f:
            f.write(video_response.content)
        return _finalize_video(output_path, prompt)
    except Exception as e:
        logger.error(f"Error processing video file: {str(e)}")
        return "Error processing video file"


def generate_video(image, prompt):
    """Generate a video with the MiniMax ``video-01`` API, then watermark it and add audio.

    Args:
        image: Optional path of a first-frame image. It is uploaded to catbox.moe
            so the API can fetch it by URL.
        prompt: Text prompt for the video (may be empty).

    Returns:
        Path of the final video file on success. On failure a human-readable
        error message string is returned instead of raising.
    """
    logger.info("Starting video generation with API")
    try:
        API_KEY = os.getenv("API_KEY", "").strip()
        if not API_KEY:
            return "API key not properly configured"
        auth_headers = {'authorization': f'Bearer {API_KEY}'}
        temp_dir = "temp_videos"
        os.makedirs(temp_dir, exist_ok=True)

        image_url = None
        if image:
            image_url = upload_to_catbox(image)
            if not image_url:
                return "Failed to upload image"
            logger.info(f"Input image URL: {image_url}")

        payload = {
            "model": "video-01",
            "prompt": prompt if prompt else "",
            "prompt_optimizer": True
        }
        if image_url:
            payload["first_frame_image"] = image_url
        logger.info(f"Sending request with payload: {payload}")
        response = requests.post(
            f"{_MINIMAX_API}/video_generation",
            headers={**auth_headers, 'Content-Type': 'application/json'},
            json=payload,
            timeout=_HTTP_TIMEOUT,
        )
        if not response.ok:
            error_msg = f"Failed to create video generation task: {response.text}"
            logger.error(error_msg)
            return error_msg
        task_id = response.json().get('task_id')
        if not task_id:
            return "Failed to get task ID from response"

        # Poll every 10 s, at most 30 times (about 5 minutes).
        max_attempts = 30
        for _ in range(max_attempts):
            time.sleep(10)
            try:
                query_response = requests.get(
                    f"{_MINIMAX_API}/query/video_generation",
                    headers=auth_headers,
                    params={'task_id': task_id},
                    timeout=_HTTP_TIMEOUT,
                )
            except requests.RequestException as e:
                # A transient network error costs one attempt instead of the whole job.
                logger.warning(f"Status query failed: {str(e)}")
                continue
            if not query_response.ok:
                continue
            status_data = query_response.json()
            status = status_data.get('status')
            if status == 'Success':
                return _download_result(status_data, auth_headers, image_url, prompt, temp_dir)
            if status == 'Fail':
                return "Video generation failed"
        return "Timeout waiting for video generation"
    except Exception as e:
        logger.error(f"Error in video generation: {str(e)}")
        return f"Error in video generation process: {str(e)}"
def _contains_hangul(text):
    """Return True if ``text`` contains at least one precomposed Hangul syllable.

    Checks the Hangul Syllables block U+AC00..U+D7A3. The bounds are written as
    escapes so they survive re-encoding of this source file; the previous
    mis-encoded two-character literals made ``ord()`` raise TypeError.
    """
    return any('\uac00' <= ch <= '\ud7a3' for ch in text)


css = """
footer {
    visibility: hidden;
}
.gradio-container {max-width: 1200px !important}
"""
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            video_prompt = gr.Textbox(
                label="Video Description",
                placeholder="Enter video description...",
                lines=3
            )
            upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
            # U+1F3AC (clapper board emoji), escaped so the label survives re-encoding.
            video_generate_btn = gr.Button("\U0001F3AC Generate Video")
        with gr.Column(scale=4):
            video_output = gr.Video(label="Generated Video")

    def process_and_generate_video(image, prompt):
        """Translate a Korean prompt if needed, normalise the image to PNG, and generate.

        Args:
            image: Filepath of the uploaded first-frame image (``None`` if missing).
            prompt: User prompt, possibly in Korean.

        Returns:
            Path of the generated video.

        Raises:
            gr.Error: When no image is given or generation fails. ``gr.Video``
                expects a file path, so failures are surfaced as a visible error
                rather than returned as a message string.
        """
        if image is None:
            raise gr.Error("Please upload an image")
        prompt = prompt or ""
        temp_path = None
        try:
            # Translate Korean prompts to English before sending them to the APIs.
            if _contains_hangul(prompt):
                translated = translator(prompt)[0]['translation_text']
                logger.info(f"Translated prompt from '{prompt}' to '{translated}'")
                prompt = translated
            with Image.open(image) as img:
                rgb = img if img.mode == 'RGB' else img.convert('RGB')
                # Unique temp file: a second-resolution timestamp in the CWD could
                # collide between concurrent requests.
                with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
                    temp_path = tmp.name
                rgb.save(temp_path, 'PNG')
            result = generate_video(temp_path, prompt)
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            result = "Error processing image"
        finally:
            # Clean up on every path, including errors after the file was created.
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError as e:
                    logger.warning(f"Could not remove temporary image {temp_path}: {e}")
        # generate_video reports failures as message strings, not file paths.
        if not (isinstance(result, str) and os.path.isfile(result)):
            raise gr.Error(str(result))
        return result

    video_generate_btn.click(
        process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output
    )
if __name__ == "__main__":
    # Listen on every interface at port 7860, without a public Gradio share link.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=False)
    demo.launch(**launch_options)