# NOTE: page-scrape residue removed from the top of this file (file size banner
# "3,085 Bytes", short hashes 7038078 / e8c63d2, and a 1-83 line-number gutter).
import os
import tempfile

import gradio as gr
import requests
import torch
import torchaudio

# Model download helper
def download_model(model_url, output_path):
    """Download ``model_url`` to ``output_path`` unless the file already exists.

    The body is streamed to ``<output_path>.part`` in chunks and moved into
    place only after the transfer finishes, so an interrupted download never
    leaves a truncated file that a later run would mistake for a complete one.

    Args:
        model_url: HTTP(S) URL of the model file.
        output_path: Local path the model is stored at.

    Raises:
        ValueError: If the server answers with a status code other than 200.
        requests.exceptions.RequestException: On connection errors or timeouts.
    """
    if os.path.exists(output_path):
        print(f"Model already exists at {output_path}")
        return

    print(f"Downloading model from {model_url} to {output_path}")
    partial_path = f"{output_path}.part"
    # (connect timeout, read timeout) in seconds; the read timeout applies
    # between received chunks, not to the whole transfer.
    with requests.get(model_url, stream=True, timeout=(10, 300)) as response:
        if response.status_code != 200:
            raise ValueError(f"Failed to download model: {response.status_code}")
        try:
            with open(partial_path, 'wb') as f:
                # Stream in 1 MiB chunks instead of buffering the whole body
                # in memory via ``response.content``.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
            os.replace(partial_path, output_path)
        finally:
            # After a successful os.replace the partial file is gone; this
            # only removes leftovers from a failed or interrupted transfer.
            if os.path.exists(partial_path):
                os.remove(partial_path)
    print("Model downloaded successfully.")

# Model loading helper
def load_model(model_path):
    """Load a pickled ``torch.nn.Module`` from ``model_path`` in eval mode on CPU.

    Args:
        model_path: Path to a file written with ``torch.save``.

    Returns:
        The loaded module, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        ValueError: If the file cannot be deserialized, or if it does not
            contain a ``torch.nn.Module`` (for example a bare ``state_dict``).
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    print(f"Loading model from {model_path}")
    # SECURITY NOTE(review): torch.load unpickles the file; only load
    # checkpoints from sources you trust. Depending on the installed PyTorch
    # version the default ``weights_only`` setting may refuse fully pickled
    # modules -- confirm against the deployed version.
    try:
        model = torch.load(model_path, map_location=torch.device('cpu'))
    except Exception as e:
        raise ValueError(
            f"Failed to load model. Please ensure it is a valid PyTorch model file: {e}"
        ) from e

    # A checkpoint holding only weights (a dict) cannot be called as a model;
    # report that explicitly instead of failing later on ``.eval()``.
    if not isinstance(model, torch.nn.Module):
        raise ValueError(
            "Failed to load model. Please ensure it is a valid PyTorch model file: "
            f"expected a torch.nn.Module but got {type(model).__name__}"
        )
    model.eval()
    return model

# Audio processing function
def process_audio(audio_filepath, model_path):
    """Run the model at ``model_path`` over an audio file and save the result.

    Args:
        audio_filepath: Path of the input audio file (as supplied by the
            Gradio ``Audio`` input with ``type="filepath"``).
        model_path: Path of the model file passed to ``load_model``.

    Returns:
        Path of a newly created ``.wav`` file containing the model output.

    Raises:
        gr.Error: If any step fails. The original exception is chained and its
            message is shown to the user, rather than being returned as a
            string that the audio output would treat as a file path.
    """
    try:
        if not audio_filepath:
            raise ValueError("No input audio provided")

        # Load the model
        model = load_model(model_path)

        # Read the input audio into a tensor
        waveform, sample_rate = torchaudio.load(audio_filepath)
        print(f"Loaded audio with shape {waveform.shape} and sample rate {sample_rate}")

        # Feed the audio through the model without tracking gradients
        with torch.no_grad():
            processed_waveform = model(waveform)

        # Validate the result; numel() also covers 1-D outputs, where
        # indexing shape[1] would itself raise.
        if processed_waveform is None or processed_waveform.numel() == 0:
            raise ValueError("Model returned empty waveform")
        if processed_waveform.dim() == 1:
            # torchaudio.save expects a 2-D tensor, so add the leading axis.
            processed_waveform = processed_waveform.unsqueeze(0)

        # Save to a unique temp file so concurrent requests do not overwrite
        # each other's output.
        fd, output_path = tempfile.mkstemp(prefix="processed_audio_", suffix=".wav")
        os.close(fd)
        torchaudio.save(output_path, processed_waveform, sample_rate)
        print(f"Processed audio saved to {output_path}")

        return output_path
    except Exception as e:
        print(f"Error: {str(e)}")
        raise gr.Error(f"Error: {str(e)}") from e

# Gradio interface
def create_interface():
    """Download the model if needed and build the Gradio interface.

    Returns:
        A ``gr.Interface`` that takes a source audio file and returns the
        path of the processed audio produced by ``process_audio``.
    """
    # Use the Hub "resolve" route: for LFS-tracked binaries such as .pth
    # files, "raw" serves the small LFS pointer text instead of the weights.
    model_url = "https://huggingface.co/spaces/adhisetiawan/anime-voice-generator/resolve/main/pretrained_models/alice/alice.pth"
    model_path = "alice.pth"  # Local file name the model is saved under

    # Fetch the model (skipped when the file is already present)
    download_model(model_url, model_path)

    def run_filter(audio_filepath):
        """Apply the downloaded model to the uploaded audio file."""
        return process_audio(audio_filepath, model_path)

    # Both components exchange file paths rather than raw sample arrays.
    interface = gr.Interface(
        fn=run_filter,
        inputs=gr.Audio(type="filepath", label="Source Audio"),
        outputs=gr.Audio(type="filepath", label="Processed Audio"),
        title="Anime Voice Filter",
        description="指定されたモデルを使用して音声にフィルターをかけます。"
    )

    return interface

if __name__ == "__main__":
    # Script entry point: build the interface (which first downloads the
    # model when it is not present locally) and start the Gradio app.
    interface = create_interface()
    interface.launch()