ginipick committed on
Commit
e1ec93f
·
verified ·
1 Parent(s): d57c019

Delete app-backup.py

Browse files
Files changed (1) hide show
  1. app-backup.py +0 -473
app-backup.py DELETED
@@ -1,473 +0,0 @@
1
- # 1. 먼저 로깅 설정
2
- import logging
3
- logging.basicConfig(level=logging.INFO)
4
- logger = logging.getLogger(__name__)
5
-
6
- # 2. 나머지 imports
7
- import os
8
- import time
9
- from datetime import datetime
10
- import gradio as gr
11
- import torch
12
- import requests
13
- from pathlib import Path
14
- import cv2
15
- from PIL import Image
16
- import json
17
- import spaces
18
- import torchaudio
19
- import tempfile
20
-
21
try:
    import mmaudio
except ImportError:
    # mmaudio is not installed yet: install the project in the current
    # directory in editable mode, then retry the import.
    import subprocess
    import sys

    # Use the running interpreter's pip so the package lands in the same
    # environment that is executing this script (a bare ``pip`` on PATH may
    # belong to a different Python).
    _pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", "."],
        check=False,
    )
    if _pip_result.returncode != 0:
        logger.error("pip install -e . exited with status %d", _pip_result.returncode)
    import mmaudio
26
-
27
- from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
28
- setup_eval_logging)
29
- from mmaudio.model.flow_matching import FlowMatching
30
- from mmaudio.model.networks import MMAudio, get_my_mmaudio
31
- from mmaudio.model.sequence_config import SequenceConfig
32
- from mmaudio.model.utils.features_utils import FeaturesUtils
33
- # 상단에 번역 모델 import 추가
34
- from transformers import pipeline
35
# Korean -> English translation pipeline, created once at import time.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# 3. API settings
# SECURITY NOTE(review): this value is sent as the ``userhash`` form field on
# every catbox upload and presumably identifies the uploading account. It can
# now be supplied through the CATBOX_USER_HASH environment variable; the
# literal is kept only as a backward-compatible fallback and should be
# rotated/removed from source control.
CATBOX_USER_HASH = os.getenv("CATBOX_USER_HASH", "30f52c895fd9d9cb387eee489")
# Not referenced by any function in the visible part of this file;
# generate_video() reads the API_KEY environment variable itself.
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# 4. Audio model settings
device = 'cuda'
dtype = torch.bfloat16
43
-
44
- # 5. get_model 함수 정의
45
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Instantiate the MMAudio network and its feature extractor.

    Reads the module-level ``model`` configuration together with the
    module-level ``device`` and ``dtype`` settings.

    Returns:
        A ``(net, feature_utils, seq_cfg)`` tuple: the network with its
        weights loaded and switched to eval mode, the feature extractor in
        eval mode, and the sequence configuration taken from ``model``.
    """
    sequence_config = model.seq_cfg

    # Build the network on the target device first, then load the checkpoint.
    network: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    state = torch.load(model.model_path, map_location=device, weights_only=True)
    network.load_weights(state)
    logger.info(f'Loaded weights from {model.model_path}')

    extractor = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        synchformer_ckpt=model.synchformer_ckpt,
        enable_conditions=True,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False,
    )
    extractor = extractor.to(device, dtype).eval()

    return network, extractor, sequence_config
61
-
62
# 6. Model initialisation -- executed once, at import time.
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')

# Configure evaluation logging, then build the shared network, feature
# extractor and sequence config used by video_to_audio().
setup_eval_logging()
net, feature_utils, seq_cfg = get_model()
69
-
70
@spaces.GPU(duration=30)  # limited to 30 seconds
@torch.inference_mode()
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
                   seed: int = -1, num_steps: int = 15,
                   cfg_strength: float = 4.0, target_duration: float = 4.0):
    """Generate audio for a video with MMAudio and write a new mp4 with it.

    Uses the module-level ``net``, ``feature_utils`` and ``seq_cfg`` objects;
    ``seq_cfg.duration`` and the network's sequence lengths are updated in
    place for every call.

    Args:
        video_path: Path of the input video.
        prompt: Text describing the desired audio.
        negative_prompt: Text describing audio to steer away from.
        seed: RNG seed; any negative value draws a fresh random seed.
        num_steps: Number of Euler steps for flow matching.
        cfg_strength: Classifier-free guidance strength passed to ``generate``.
        target_duration: Requested duration in seconds passed to ``load_video``.

    Returns:
        Path of a newly created temporary ``.mp4`` on success. On any failure
        the original ``video_path`` is returned unchanged, so callers detect
        failure by comparing the result with their input path.
    """
    output_path = None
    try:
        logger.info("Starting audio generation process")
        torch.cuda.empty_cache()

        rng = torch.Generator(device=device)
        if seed >= 0:
            rng.manual_seed(seed)
        else:
            rng.seed()

        fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

        # load_video takes the requested duration via ``duration_sec``.
        video_info = load_video(video_path, duration_sec=target_duration)

        if video_info is None:
            logger.error("Failed to load video")
            return video_path

        clip_frames = video_info.clip_frames
        sync_frames = video_info.sync_frames
        actual_duration = video_info.duration_sec

        if clip_frames is None or sync_frames is None:
            logger.error("Failed to extract frames from video")
            return video_path

        # Memory optimisation: cap the number of frames.
        # NOTE(review): both streams are capped using the source video's fps;
        # confirm this matches the frame rates load_video samples them at.
        max_frames = int(actual_duration * video_info.fps)
        clip_frames = clip_frames[:max_frames]
        sync_frames = sync_frames[:max_frames]

        # NOTE(review): frames are cast to float16 while the model is created
        # with the module-level ``dtype`` (bfloat16) -- confirm this is intended.
        clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
        sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)

        seq_cfg.duration = actual_duration
        net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

        logger.info("Generating audio...")
        # torch.autocast('cuda') replaces the deprecated torch.cuda.amp.autocast().
        with torch.autocast(device_type='cuda'):
            audios = generate(clip_frames,
                              sync_frames,
                              [prompt],
                              negative_text=[negative_prompt],
                              feature_utils=feature_utils,
                              net=net,
                              fm=fm,
                              rng=rng,
                              cfg_strength=cfg_strength)

        if audios is None:
            logger.error("Failed to generate audio")
            return video_path

        audio = audios.float().cpu()[0]

        # Reserve a unique output name and close the handle straight away so
        # no file descriptor is leaked; delete=False keeps the file on disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
            output_path = tmp.name
        logger.info(f"Creating final video with audio at {output_path}")

        make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)

        torch.cuda.empty_cache()

        if not os.path.exists(output_path):
            logger.error("Failed to create output video")
            return video_path

        logger.info(f'Successfully saved video with audio to {output_path}')
        return output_path

    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error in video_to_audio: {str(e)}")
        torch.cuda.empty_cache()
        # Best-effort removal of a partially written output file.
        if output_path is not None:
            try:
                os.remove(output_path)
            except OSError:
                pass
        return video_path
148
-
149
def upload_to_catbox(file_path):
    """Upload a file through the catbox.moe API and return its URL.

    A file whose extension is not one of the known image types is first
    re-encoded as an RGB PNG next to the original; that intermediate PNG is
    deleted again before this function returns.

    Args:
        file_path: Path of the local file to upload.

    Returns:
        The URL returned by catbox on success, or ``None`` when conversion or
        upload fails. Failures are logged, never raised.
    """
    url = "https://catbox.moe/user/api.php"
    mime_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp',
        '.jfif': 'image/jpeg',
    }
    converted_path = None  # set only when an intermediate PNG is created

    try:
        logger.info(f"Preparing to upload file: {file_path}")
        file_extension = Path(file_path).suffix.lower()

        if file_extension not in mime_types:
            try:
                # The context manager closes the source image file.
                with Image.open(file_path) as img:
                    rgb = img if img.mode == 'RGB' else img.convert('RGB')
                    # with_suffix only touches the final path component, so a
                    # dot in a directory name cannot corrupt the path.
                    converted_path = str(Path(file_path).with_suffix('.png'))
                    rgb.save(converted_path, 'PNG')
                file_path = converted_path
                file_extension = '.png'
                logger.info(f"Converted image to PNG: {file_path}")
            except Exception as e:
                logger.error(f"Failed to convert image: {str(e)}")
                return None

        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH,
        }

        # Keep the file open only for the duration of the request.
        with open(file_path, 'rb') as handle:
            files = {
                'fileToUpload': (
                    os.path.basename(file_path),
                    handle,
                    mime_types.get(file_extension, 'application/octet-stream'),
                )
            }
            # (connect, read) timeouts so a stalled server cannot hang the app.
            response = requests.post(url, files=files, data=data, timeout=(10, 60))

        file_url = response.text.strip()
        if response.status_code == 200 and file_url.startswith('http'):
            logger.info(f"File uploaded successfully: {file_url}")
            return file_url

        logger.error(f"File upload error: Upload failed: {response.text}")
        return None

    except Exception as e:
        logger.exception(f"File upload error: {str(e)}")
        return None
    finally:
        # Remove the intermediate PNG, if one was created.
        if converted_path is not None:
            try:
                os.remove(converted_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.warning(f"Could not remove temporary file {converted_path}: {e}")
212
-
213
def add_watermark(video_path):
    """Add a "GiniGEN.AI" text watermark to a video using OpenCV.

    The text is drawn in white near the bottom-right corner of every frame
    and the result is written to a new, uniquely named mp4 file.

    Args:
        video_path: Path of the video to watermark.

    Returns:
        Path of the watermarked video, or ``video_path`` unchanged if the
        video cannot be read or written (the failure is logged).
    """
    text = "GiniGEN.AI"
    font = cv2.FONT_HERSHEY_SIMPLEX
    thickness = 2
    color = (255, 255, 255)

    cap = None
    out = None
    output_path = None
    succeeded = False
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Error adding watermark: cannot open {video_path}")
            return video_path

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the fractional frame rate (e.g. 29.97); truncating it to an
        # int would change the playback speed of the output.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if width <= 0 or height <= 0 or not fps > 0:
            logger.error(f"Error adding watermark: invalid video properties for {video_path}")
            return video_path

        font_scale = height * 0.05 / 30
        (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        # Clamp so the text origin never falls left of the frame.
        x_pos = max(width - text_width - margin, 0)
        y_pos = height - margin

        # A unique output name avoids returning a stale file from an earlier
        # call and collisions between overlapping calls.
        fd, output_path = tempfile.mkstemp(prefix="watermarked_", suffix=".mp4")
        os.close(fd)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            logger.error("Error adding watermark: cannot open video writer")
            return video_path

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)

        succeeded = True
        return output_path

    except Exception as e:
        logger.exception(f"Error adding watermark: {str(e)}")
        return video_path
    finally:
        # Always release OpenCV handles; this also finalises the output file
        # before the caller sees the returned path.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        # Do not leave a partial output file behind on failure.
        if not succeeded and output_path is not None:
            try:
                os.remove(output_path)
            except OSError:
                pass
251
-
252
# Base URL of the MiniMax video API used by generate_video().
_MINIMAX_API = "https://api.minimaxi.chat/v1"
# (connect, read) timeouts in seconds applied to every HTTP call so that a
# stalled connection cannot block the request handler forever.
_HTTP_TIMEOUT = (10, 120)
# Polling schedule for the generation task: 30 attempts, 10 seconds apart.
_POLL_ATTEMPTS = 30
_POLL_INTERVAL_SEC = 10


def _remove_file(path):
    """Delete ``path``, logging instead of raising if that fails."""
    try:
        os.remove(path)
    except OSError as e:
        logger.warning(f"Error cleaning up temporary files: {str(e)}")


def _attach_audio(video_path, prompt):
    """Add generated audio to ``video_path``; return the best available path."""
    try:
        logger.info("Starting audio generation process")
        with_audio = video_to_audio(
            video_path,
            prompt=prompt,
            negative_prompt="music",
            seed=-1,
            num_steps=20,
            cfg_strength=4.5,
            target_duration=6.0,
        )
    except Exception as e:
        logger.exception(f"Error in audio processing: {str(e)}")
        return video_path  # fall back to the silent (watermarked) video

    # video_to_audio hands back its input path when it could not add audio.
    if with_audio == video_path:
        logger.warning("Audio generation skipped, using original video")
        return video_path

    logger.info("Audio generation successful")
    _remove_file(video_path)  # superseded by the version with audio
    return with_audio


def _deliver_video(status_data, api_key, image_url, prompt, temp_dir):
    """Download the finished video, post-process it, return a path or error text."""
    file_id = status_data.get('file_id')
    if not file_id:
        return "Failed to get file ID"

    file_response = requests.get(
        f"{_MINIMAX_API}/files/retrieve",
        headers={'authorization': f'Bearer {api_key}'},
        params={'file_id': file_id},
        timeout=_HTTP_TIMEOUT,
    )
    if not file_response.ok:
        return "Failed to retrieve video file"

    try:
        file_data = file_response.json()
        # ``or {}`` also covers a response where "file" is present but null.
        download_url = (file_data.get('file') or {}).get('download_url')
        if not download_url:
            return "Failed to get download URL"

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        result_info = {
            "timestamp": timestamp,
            "input_image": image_url,
            "output_video_url": download_url,
            "prompt": prompt,
        }
        logger.info(f"Video generation result: {json.dumps(result_info, indent=2)}")

        video_response = requests.get(download_url, timeout=_HTTP_TIMEOUT)
        if not video_response.ok:
            return "Failed to download video"

        output_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
        with open(output_path, 'wb') as f:
            f.write(video_response.content)

        final_path = add_watermark(output_path)
        # add_watermark returns its input on failure; only drop the raw
        # download when a separate watermarked file was really produced.
        if final_path != output_path:
            _remove_file(output_path)

        # Audio step: on failure the watermarked video is returned as-is.
        return _attach_audio(final_path, prompt)

    except Exception as e:
        logger.exception(f"Error processing video file: {str(e)}")
        return "Error processing video file"


def generate_video(image, prompt):
    """Generate a video through the MiniMax API, then watermark it and add audio.

    Steps: upload the optional first-frame image to catbox, create a
    generation task, poll it until it finishes, download the result, add the
    watermark and finally try to add generated audio.

    Args:
        image: Path of the first-frame image, or a falsy value for none.
        prompt: Text prompt for the video (an empty string is sent if falsy).

    Returns:
        The path of the final video file on success. On failure a
        human-readable error message string is returned instead; nothing is
        raised, so callers must tell the two apart themselves.
    """
    logger.info("Starting video generation with API")
    try:
        api_key = os.getenv("API_KEY", "").strip()
        if not api_key:
            return "API key not properly configured"

        temp_dir = "temp_videos"
        os.makedirs(temp_dir, exist_ok=True)

        image_url = None
        if image:
            image_url = upload_to_catbox(image)
            if not image_url:
                return "Failed to upload image"
            logger.info(f"Input image URL: {image_url}")

        payload = {
            "model": "video-01",
            "prompt": prompt if prompt else "",
            "prompt_optimizer": True,
        }
        if image_url:
            payload["first_frame_image"] = image_url

        logger.info(f"Sending request with payload: {payload}")

        response = requests.post(
            f"{_MINIMAX_API}/video_generation",
            headers={
                'authorization': f'Bearer {api_key}',
                'Content-Type': 'application/json',
            },
            json=payload,
            timeout=_HTTP_TIMEOUT,
        )
        if not response.ok:
            error_msg = f"Failed to create video generation task: {response.text}"
            logger.error(error_msg)
            return error_msg

        task_id = response.json().get('task_id')
        if not task_id:
            return "Failed to get task ID from response"

        for _ in range(_POLL_ATTEMPTS):
            time.sleep(_POLL_INTERVAL_SEC)
            try:
                query_response = requests.get(
                    f"{_MINIMAX_API}/query/video_generation",
                    params={'task_id': task_id},
                    headers={'authorization': f'Bearer {api_key}'},
                    timeout=_HTTP_TIMEOUT,
                )
            except requests.RequestException as e:
                # A transient network error only costs one polling attempt
                # instead of abandoning an already-submitted task.
                logger.warning(f"Status query failed, will retry: {e}")
                continue

            if not query_response.ok:
                logger.warning(f"Status query returned HTTP {query_response.status_code}, will retry")
                continue

            status_data = query_response.json()
            status = status_data.get('status')

            if status == 'Success':
                return _deliver_video(status_data, api_key, image_url, prompt, temp_dir)
            if status == 'Fail':
                return "Video generation failed"

        return "Timeout waiting for video generation"

    except Exception as e:
        logger.exception(f"Error in video generation: {str(e)}")
        return f"Error in video generation process: {str(e)}"
405
-
406
# Hide the Gradio footer and cap the container width.
css = """
footer {
visibility: hidden;
}
.gradio-container {max-width: 1200px !important}
"""

with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:

    with gr.Row():
        with gr.Column(scale=3):
            video_prompt = gr.Textbox(
                label="Video Description",
                placeholder="Enter video description...",
                lines=3
            )
            upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
            video_generate_btn = gr.Button("🎬 Generate Video")

        with gr.Column(scale=4):
            video_output = gr.Video(label="Generated Video")

    def process_and_generate_video(image, prompt):
        """Gradio callback: normalise the inputs and run generate_video().

        A prompt containing Hangul syllables is translated to English first.
        The uploaded image is re-encoded as an RGB PNG in a unique temporary
        file, which is removed again once generation has finished.

        Args:
            image: File path of the uploaded first-frame image, or ``None``.
            prompt: Video description entered by the user.

        Returns:
            Path of the generated video file.

        Raises:
            gr.Error: If no image was uploaded or any step failed. The video
                output component cannot display an error string, so failures
                are surfaced through Gradio's error mechanism instead.
        """
        if image is None:
            raise gr.Error("Please upload an image")

        prompt = prompt or ""
        temp_path = None
        try:
            # Detect a Korean prompt (Hangul syllable range) and translate it.
            contains_korean = any('가' <= char <= '힣' for char in prompt)
            if contains_korean:
                translated = translator(prompt)[0]['translation_text']
                logger.info(f"Translated prompt from '{prompt}' to '{translated}'")
                prompt = translated

            with Image.open(image) as img:
                rgb = img if img.mode == 'RGB' else img.convert('RGB')
                # Unique name: a timestamp-based name collides when two
                # requests arrive within the same second.
                fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=".png")
                os.close(fd)
                rgb.save(temp_path, 'PNG')

            result = generate_video(temp_path, prompt)

        except Exception as e:
            logger.exception(f"Error processing image: {str(e)}")
            result = "Error processing image"
        finally:
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError as e:
                    logger.warning(f"Could not remove temporary file {temp_path}: {e}")

        # generate_video returns either a file path or an error message.
        if not os.path.isfile(result):
            raise gr.Error(result)
        return result

    video_generate_btn.click(
        process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)