"""Gradio app: turn a first-frame image and a text prompt into a short video.

Pipeline: Korean prompts are translated to English, the image is uploaded to
catbox.moe, the MiniMax ``video-01`` API renders the video, a text watermark is
burned in with OpenCV, and MMAudio generates a soundtrack muxed into the result.
"""
# 1. Configure logging first so messages emitted during the imports below are captured.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 2. Remaining imports
import importlib
import os
import subprocess
import sys
import time
from datetime import datetime
import gradio as gr
import torch
import requests
from pathlib import Path
import cv2
from PIL import Image
import json
import spaces
import torchaudio
import tempfile

try:
    import mmaudio
except ImportError:
    # Install the local mmaudio package (and its dependencies) with the pip of
    # the *running* interpreter. A bare `pip` on PATH may belong to another
    # Python, and os.system() silently ignored a failed install.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])
    # Drop stale path-finder caches so freshly installed modules are visible.
    importlib.invalidate_caches()
    import mmaudio
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video,
                                make_video, setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils

# Korean -> English translation model, used to translate Korean prompts
# before they are sent to the video API.
from transformers import pipeline

translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# 3. API settings
# NOTE(review): a credential committed to source; it can now be overridden via
# the CATBOX_USER_HASH environment variable (the old value stays the default).
CATBOX_USER_HASH = os.getenv("CATBOX_USER_HASH", "30f52c895fd9d9cb387eee489")
# NOTE(review): not referenced elsewhere in this file; generate_video reads
# the API_KEY environment variable directly.
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# 4. Audio model settings
device = 'cuda'
dtype = torch.bfloat16


# 5. Model construction
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the MMAudio network and its feature extractors from the global ``model`` config.

    Must be called after the module-level ``model`` (a ``ModelConfig``) has been
    assigned and its checkpoints downloaded.

    Returns:
        ``(net, feature_utils, seq_cfg)``: the MMAudio network in eval mode on
        ``device``/``dtype`` with weights loaded from ``model.model_path``; the
        ``FeaturesUtils`` (VAE, Synchformer and vocoder checkpoints from
        ``model``) in eval mode on ``device``/``dtype``; and ``model.seq_cfg``.
    """
    seq_cfg = model.seq_cfg

    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
    logger.info(f'Loaded weights from {model.model_path}')

    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=False)
    feature_utils = feature_utils.to(device, dtype).eval()

    return net, feature_utils, seq_cfg
# 6. Model initialization
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')  # NOTE(review): not referenced elsewhere in this file

setup_eval_logging()
net, feature_utils, seq_cfg = get_model()

# Settings shared by the outbound HTTP calls below.
_MINIMAX_API_BASE = "https://api.minimaxi.chat/v1"
_HTTP_TIMEOUT = (10, 120)  # (connect, read) seconds; requests has no default timeout
_POLL_INTERVAL_SEC = 10
_MAX_POLL_ATTEMPTS = 30


def _discard_files(paths, keep=None):
    """Best-effort delete every existing path in *paths* except *keep*; failures are only logged."""
    for path in set(paths):
        if not path or path == keep or not os.path.exists(path):
            continue
        try:
            os.remove(path)
        except OSError as e:
            logger.warning(f"Error cleaning up temporary files: {str(e)}")


@spaces.GPU(duration=30)  # cap the GPU allocation at 30 seconds
@torch.inference_mode()
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
                   seed: int = -1, num_steps: int = 15, cfg_strength: float = 4.0,
                   target_duration: float = 4.0):
    """Generate a soundtrack for a video with MMAudio and mux it into a new .mp4.

    Args:
        video_path: Input video file.
        prompt: Text describing the desired sound.
        negative_prompt: Text describing sounds to steer away from.
        seed: RNG seed; a negative value draws a fresh random seed.
        num_steps: Number of Euler flow-matching steps.
        cfg_strength: Guidance strength passed to ``generate``.
        target_duration: Seconds of video requested from ``load_video``; the
            duration it reports back is what the model is configured for.

    Returns:
        Path of a new temporary .mp4 containing the generated audio, or
        ``video_path`` unchanged if any step fails (failures are logged, not raised).
    """
    output_path = None
    try:
        logger.info("Starting audio generation process")
        torch.cuda.empty_cache()

        rng = torch.Generator(device=device)
        if seed >= 0:
            rng.manual_seed(seed)
        else:
            rng.seed()

        fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

        video_info = load_video(video_path, duration_sec=target_duration)
        if video_info is None:
            logger.error("Failed to load video")
            return video_path

        clip_frames = video_info.clip_frames
        sync_frames = video_info.sync_frames
        actual_duration = video_info.duration_sec

        if clip_frames is None or sync_frames is None:
            logger.error("Failed to extract frames from video")
            return video_path

        # load_video(duration_sec=...) already trims each frame stream at its
        # own sampling rate. The previous extra slice by the *source* video fps
        # could cut the sync stream short. Frames go to the model's dtype
        # (bfloat16), not float16.
        clip_frames = clip_frames.unsqueeze(0).to(device, dtype=dtype)
        sync_frames = sync_frames.unsqueeze(0).to(device, dtype=dtype)

        # NOTE: mutates the shared global seq_cfg/net for this request's duration.
        seq_cfg.duration = actual_duration
        net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

        logger.info("Generating audio...")
        # Autocast to the configured model dtype; the deprecated
        # torch.cuda.amp.autocast() defaulted to float16 on a bfloat16 model.
        with torch.autocast(device_type=torch.device(device).type, dtype=dtype):
            audios = generate(clip_frames,
                              sync_frames,
                              [prompt],
                              negative_text=[negative_prompt],
                              feature_utils=feature_utils,
                              net=net,
                              fm=fm,
                              rng=rng,
                              cfg_strength=cfg_strength)

        if audios is None:
            logger.error("Failed to generate audio")
            return video_path

        audio = audios.float().cpu()[0]

        fd, output_path = tempfile.mkstemp(suffix='.mp4')
        os.close(fd)
        logger.info(f"Creating final video with audio at {output_path}")

        make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)
        torch.cuda.empty_cache()

        # The temp file exists from creation, so an existence check alone always
        # passed; an empty file means make_video wrote nothing.
        if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
            logger.error("Failed to create output video")
            _discard_files([output_path])
            return video_path

        logger.info(f'Successfully saved video with audio to {output_path}')
        return output_path

    except Exception as e:
        logger.exception(f"Error in video_to_audio: {str(e)}")
        torch.cuda.empty_cache()
        _discard_files([output_path])
        return video_path


def upload_to_catbox(file_path):
    """Upload an image to catbox.moe and return its public URL, or None on failure.

    A file whose extension is not a recognised image type is first converted to
    an RGB PNG next to the original (same name, ``.png`` suffix). That PNG is
    uploaded instead and deleted afterwards.
    """
    converted_path = None
    try:
        logger.info(f"Preparing to upload file: {file_path}")

        url = "https://catbox.moe/user/api.php"

        mime_types = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
            '.jfif': 'image/jpeg',
        }

        file_extension = Path(file_path).suffix.lower()

        if file_extension not in mime_types:
            try:
                with Image.open(file_path) as img:
                    rgb = img if img.mode == 'RGB' else img.convert('RGB')
                    # with_suffix handles dotted directory names and suffix-less files.
                    converted_path = str(Path(file_path).with_suffix('.png'))
                    rgb.save(converted_path, 'PNG')
                file_path = converted_path
                file_extension = '.png'
                logger.info(f"Converted image to PNG: {file_path}")
            except Exception as e:
                logger.error(f"Failed to convert image: {str(e)}")
                return None

        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH,
        }

        # The context manager closes the upload handle, which previously leaked.
        with open(file_path, 'rb') as fh:
            files = {
                'fileToUpload': (
                    os.path.basename(file_path),
                    fh,
                    mime_types.get(file_extension, 'application/octet-stream'),
                )
            }
            response = requests.post(url, files=files, data=data, timeout=_HTTP_TIMEOUT)

        body = response.text.strip()
        if response.status_code == 200 and body.startswith('http'):
            logger.info(f"File uploaded successfully: {body}")
            return body

        logger.error(f"File upload error: Upload failed: {response.text}")
        return None

    except Exception as e:
        logger.error(f"File upload error: {str(e)}")
        return None
    finally:
        _discard_files([converted_path])


def add_watermark(video_path):
    """Burn a "GiniGEN.AI" text watermark into the bottom-right corner of a video.

    Returns:
        Path of the watermarked copy (``<stem>_watermarked.mp4`` next to the
        input, so concurrent requests do not overwrite each other), or
        ``video_path`` unchanged if the video cannot be read/written or an
        error occurs.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Error adding watermark: cannot open {video_path}")
            return video_path

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates such as 29.97; int() truncation made the output
        # drift from the source duration.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            logger.error(f"Error adding watermark: invalid frame rate in {video_path}")
            return video_path

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30  # text height scales with frame height
        thickness = 2
        color = (255, 255, 255)

        (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)

        margin = int(height * 0.02)
        x_pos = width - text_width - margin
        y_pos = height - margin

        source = Path(video_path)
        output_path = str(source.with_name(f"{source.stem}_watermarked.mp4"))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            logger.error(f"Error adding watermark: cannot write {output_path}")
            return video_path

        while True:
            ok, frame = cap.read()
            if not ok:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)

        return output_path

    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path
    finally:
        # Release on every path so the output file is finalized and handles close.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()


def _attach_audio(raw_path, watermarked_path, prompt):
    """Try to add an MMAudio soundtrack to the watermarked video.

    Returns the path of the video with audio on success, otherwise
    ``watermarked_path``. Intermediate files that are not returned are deleted.
    """
    result = watermarked_path
    try:
        logger.info("Starting audio generation process")
        with_audio = video_to_audio(
            watermarked_path,
            prompt=prompt,
            negative_prompt="music",
            seed=-1,
            num_steps=20,
            cfg_strength=4.5,
            target_duration=6.0,
        )
        # video_to_audio signals failure by returning its input path unchanged.
        if with_audio != watermarked_path:
            logger.info("Audio generation successful")
            result = with_audio
        else:
            logger.warning("Audio generation skipped, using original video")
    except Exception as e:
        # Fall back to the watermark-only video.
        logger.error(f"Error in audio processing: {str(e)}")

    _discard_files([raw_path, watermarked_path], keep=result)
    return result


def _download_and_finalize(file_id, auth_headers, image_url, prompt, temp_dir):
    """Fetch a finished MiniMax video, save it under *temp_dir*, watermark it and add audio.

    Returns the final local video path, or an error message string.
    """
    file_response = requests.get(
        f"{_MINIMAX_API_BASE}/files/retrieve",
        headers=auth_headers,
        params={'file_id': file_id},
        timeout=_HTTP_TIMEOUT,
    )
    if not file_response.ok:
        return "Failed to retrieve video file"

    try:
        download_url = file_response.json().get('file', {}).get('download_url')
        if not download_url:
            return "Failed to get download URL"

        result_info = {
            "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
            "input_image": image_url,
            "output_video_url": download_url,
            "prompt": prompt,
        }
        logger.info(f"Video generation result: {json.dumps(result_info, indent=2)}")

        video_response = requests.get(download_url, timeout=_HTTP_TIMEOUT)
        if not video_response.ok:
            return "Failed to download video"

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # mkstemp guarantees a unique name; a bare timestamp collides when two
        # requests finish within the same second.
        fd, output_path = tempfile.mkstemp(prefix=f"output_{timestamp}_", suffix=".mp4", dir=temp_dir)
        with os.fdopen(fd, 'wb') as f:
            f.write(video_response.content)

        return _attach_audio(output_path, add_watermark(output_path), prompt)

    except Exception as e:
        logger.error(f"Error processing video file: {str(e)}")
        return "Error processing video file"


def generate_video(image, prompt):
    """Generate a video with the MiniMax ``video-01`` API, then watermark it and add audio.

    Args:
        image: Path of the first-frame image, or a falsy value for text-only generation.
        prompt: Text prompt for the video (``None``/empty is sent as "").

    Returns:
        On success, the path of the final local video file. On failure, a
        human-readable error message string (never raises).
    """
    logger.info("Starting video generation with API")
    try:
        api_key = os.getenv("API_KEY", "").strip()
        if not api_key:
            return "API key not properly configured"
        auth_headers = {'authorization': f'Bearer {api_key}'}

        temp_dir = "temp_videos"
        os.makedirs(temp_dir, exist_ok=True)

        image_url = None
        if image:
            image_url = upload_to_catbox(image)
            if not image_url:
                return "Failed to upload image"
            logger.info(f"Input image URL: {image_url}")

        payload = {
            "model": "video-01",
            "prompt": prompt if prompt else "",
            "prompt_optimizer": True,
        }
        if image_url:
            payload["first_frame_image"] = image_url

        logger.info(f"Sending request with payload: {payload}")
        response = requests.post(
            f"{_MINIMAX_API_BASE}/video_generation",
            headers={**auth_headers, 'Content-Type': 'application/json'},
            json=payload,
            timeout=_HTTP_TIMEOUT,
        )
        if not response.ok:
            error_msg = f"Failed to create video generation task: {response.text}"
            logger.error(error_msg)
            return error_msg

        task_id = response.json().get('task_id')
        if not task_id:
            return "Failed to get task ID from response"

        # Poll the task status; each attempt waits _POLL_INTERVAL_SEC first.
        query_url = f"{_MINIMAX_API_BASE}/query/video_generation"
        for _ in range(_MAX_POLL_ATTEMPTS):
            time.sleep(_POLL_INTERVAL_SEC)
            try:
                query_response = requests.get(
                    query_url,
                    headers=auth_headers,
                    params={'task_id': task_id},
                    timeout=_HTTP_TIMEOUT,
                )
            except requests.RequestException as e:
                # A transient network error should not abandon a task that may still succeed.
                logger.warning(f"Status query failed, retrying: {str(e)}")
                continue
            if not query_response.ok:
                continue

            status_data = query_response.json()
            status = status_data.get('status')
            if status == 'Success':
                file_id = status_data.get('file_id')
                if not file_id:
                    return "Failed to get file ID"
                return _download_and_finalize(file_id, auth_headers, image_url, prompt, temp_dir)
            if status == 'Fail':
                return "Video generation failed"

        return "Timeout waiting for video generation"

    except Exception as e:
        logger.exception(f"Error in video generation: {str(e)}")
        return f"Error in video generation process: {str(e)}"


css = """
footer {
    visibility: hidden;
}
.gradio-container {max-width: 1200px !important}
"""

with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            video_prompt = gr.Textbox(
                label="Video Description",
                placeholder="Enter video description...",
                lines=3,
            )
            upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
            video_generate_btn = gr.Button("🎬 Generate Video")

        with gr.Column(scale=4):
            video_output = gr.Video(label="Generated Video")

    def process_and_generate_video(image, prompt):
        """Gradio click handler: normalize the inputs and run ``generate_video``.

        Korean prompts (any precomposed Hangul syllable) are translated to
        English, and the image is re-saved as an RGB PNG in a unique temp file.
        Failures are raised as ``gr.Error`` so the UI shows the message; a plain
        string cannot be rendered by the ``gr.Video`` output.
        """
        if image is None:
            raise gr.Error("Please upload an image")

        prompt = prompt or ""
        temp_path = None
        try:
            if any('가' <= ch <= '힣' for ch in prompt):
                translated = translator(prompt)[0]['translation_text']
                logger.info(f"Translated prompt from '{prompt}' to '{translated}'")
                prompt = translated

            with Image.open(image) as img:
                rgb = img if img.mode == 'RGB' else img.convert('RGB')
                # Unique name; a per-second timestamp collided between concurrent users.
                fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=".png")
                os.close(fd)
                rgb.save(temp_path, 'PNG')

            result = generate_video(temp_path, prompt)
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            raise gr.Error("Error processing image") from e
        finally:
            _discard_files([temp_path])

        # generate_video returns either a video path or an error message.
        if not result or not os.path.isfile(result):
            raise gr.Error(result or "Video generation failed")
        return result

    video_generate_btn.click(
        process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output,
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)