GPTfree api committed on
Commit 7038078 · verified · 1 Parent(s): a06c4fc

Update app.py

Files changed (1):
  app.py +79 -57
app.py CHANGED
@@ -1,60 +1,82 @@
 import gradio as gr
-from gradio_client import Client, handle_file
-
-# Client setup
-client = Client("Plachta/Seed-VC")
-
-def process_audio(
-    source,
-    target,
-    diffusion_steps=25,
-    length_adjust=1,
-    inference_cfg_rate=0.7,
-    f0_condition=False,
-    auto_f0_adjust=True,
-    pitch_shift=0
-):
-    # API call
-    result = client.predict(
-        source=handle_file(source.name),
-        target=handle_file(target.name),
-        diffusion_steps=diffusion_steps,
-        length_adjust=length_adjust,
-        inference_cfg_rate=inference_cfg_rate,
-        f0_condition=f0_condition,
-        auto_f0_adjust=auto_f0_adjust,
-        pitch_shift=pitch_shift,
-        api_name="/predict"
-    )
-    return result
-
-# Create the Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown("# Audio Transformation with Seed-VC")
-
-    with gr.Row():
-        source_audio = gr.Audio(label="Source Audio", type="file")
-        target_audio = gr.Audio(label="Reference Audio", type="file")
-
-    diffusion_steps = gr.Slider(1, 50, value=25, label="Diffusion Steps")
-    length_adjust = gr.Slider(0.5, 2, value=1, label="Length Adjust")
-    inference_cfg_rate = gr.Slider(0.1, 1.0, value=0.7, label="Inference CFG Rate")
-    f0_condition = gr.Checkbox(label="Use F0 conditioned model")
-    auto_f0_adjust = gr.Checkbox(label="Auto F0 adjust", value=True)
-    pitch_shift = gr.Slider(-12, 12, value=0, label="Pitch shift")
-
-    output_stream = gr.Audio(label="Stream Output Audio")
-    output_full = gr.Audio(label="Full Output Audio")
-
-    run_button = gr.Button("Transform Audio")
-
-    run_button.click(
-        process_audio,
-        inputs=[
-            source_audio, target_audio, diffusion_steps, length_adjust,
-            inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift
-        ],
-        outputs=[output_stream, output_full]
     )
 
-demo.launch()
+import os
+import torch
+import requests
 import gradio as gr
+import torchaudio
+
+# Model download function
+def download_model(model_url, output_path):
+    if not os.path.exists(output_path):
+        print(f"Downloading model from {model_url} to {output_path}")
+        response = requests.get(model_url, stream=True)
+        if response.status_code == 200:
+            with open(output_path, 'wb') as f:
+                f.write(response.content)
+            print("Model downloaded successfully.")
+        else:
+            raise ValueError(f"Failed to download model: {response.status_code}")
+    else:
+        print(f"Model already exists at {output_path}")
+
+# Model loading function
+def load_model(model_path):
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file not found: {model_path}")
+    print(f"Loading model from {model_path}")
+    try:
+        model = torch.load(model_path, map_location=torch.device('cpu'))
+        model.eval()
+        return model
+    except Exception as e:
+        raise ValueError(f"Failed to load model. Please ensure it is a valid PyTorch model file: {e}")
+
+# Audio processing function
+def process_audio(audio_filepath, model_path):
+    try:
+        # Load the model
+        model = load_model(model_path)
+
+        # Load the input audio as a tensor
+        waveform, sample_rate = torchaudio.load(audio_filepath)
+        print(f"Loaded audio with shape {waveform.shape} and sample rate {sample_rate}")
+
+        # Run the audio through the model
+        with torch.no_grad():
+            processed_waveform = model(waveform)
+
+        # Check the result
+        if processed_waveform is None or processed_waveform.shape[1] == 0:
+            raise ValueError("Model returned empty waveform")
+
+        # Save the output
+        output_path = "processed_audio.wav"
+        torchaudio.save(output_path, processed_waveform, sample_rate)
+        print(f"Processed audio saved to {output_path}")
+
+        return output_path
+    except Exception as e:
+        print(f"Error: {str(e)}")
+        return f"Error: {str(e)}"
+
+# Gradio interface
+def create_interface():
+    model_url = "https://huggingface.co/spaces/adhisetiawan/anime-voice-generator/raw/main/pretrained_models/alice/alice.pth"
+    model_path = "alice.pth"  # Local filename for the downloaded model
+
+    # Download the model
+    download_model(model_url, model_path)
+
+    # Gradio interface
+    interface = gr.Interface(
+        fn=lambda audio_filepath: process_audio(audio_filepath, model_path),
+        inputs=gr.Audio(type="filepath", label="Source Audio"),  # Fix: type="filepath"
+        outputs=gr.Audio(type="filepath", label="Processed Audio"),  # Fix: type="filepath"
+        title="Anime Voice Filter",
+        description="Applies a filter to the audio using the specified model."
     )
+
+    return interface
 
+if __name__ == "__main__":
+    interface = create_interface()
+    interface.launch()