import functools
import os
import tempfile

import gradio as gr
import requests
import torch
import torchaudio

# First bytes of a Git LFS pointer file (what a repo's /raw/ URL returns for
# LFS-tracked files instead of the real binary content).
_LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/"


# Model download
def download_model(model_url, output_path, timeout=60, chunk_size=1 << 20):
    """Download the file at *model_url* to *output_path* unless it already exists.

    The response body is streamed in chunks to ``<output_path>.part`` and only
    renamed to *output_path* after it has been written completely. An
    interrupted or failed download therefore never leaves a truncated file that
    a later run would mistake for a valid cached model.

    Args:
        model_url: URL of the model file.
        output_path: Local destination path.
        timeout: Seconds passed to ``requests.get`` as its timeout.
        chunk_size: Number of bytes per streamed chunk.

    Raises:
        ValueError: If the server answers with a status code other than 200.
        requests.RequestException: On connection errors or timeouts.
    """
    if os.path.exists(output_path):
        print(f"Model already exists at {output_path}")
        return

    print(f"Downloading model from {model_url} to {output_path}")
    tmp_path = output_path + ".part"
    try:
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise ValueError(f"Failed to download model: {response.status_code}")
            with open(tmp_path, "wb") as f:
                # iter_content keeps memory bounded; response.content would
                # buffer the entire file in memory even with stream=True.
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
        os.replace(tmp_path, output_path)
    finally:
        # The .part file only still exists if something failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Model downloaded successfully.")


def _is_git_lfs_pointer(path):
    """Return True if the file at *path* starts like a Git LFS pointer file."""
    with open(path, "rb") as f:
        return f.read(len(_LFS_POINTER_PREFIX)) == _LFS_POINTER_PREFIX


# Model loading
def load_model(model_path):
    """Load a pickled ``torch.nn.Module`` from *model_path* onto the CPU in eval mode.

    Raises:
        FileNotFoundError: If *model_path* does not exist.
        ValueError: If the file is a Git LFS pointer, cannot be loaded by
            ``torch.load``, or does not contain a ``torch.nn.Module`` (e.g. it
            is a bare state_dict).
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    if _is_git_lfs_pointer(model_path):
        raise ValueError(
            "Failed to load model. Please ensure it is a valid PyTorch model file: "
            f"{model_path} is a Git LFS pointer, not the model weights. "
            "Delete it and download the file via a '/resolve/' URL."
        )

    print(f"Loading model from {model_path}")
    try:
        # SECURITY NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from sources you trust.
        model = torch.load(model_path, map_location=torch.device('cpu'))
    except Exception as e:
        raise ValueError(f"Failed to load model. Please ensure it is a valid PyTorch model file: {e}") from e

    if not isinstance(model, torch.nn.Module):
        # e.g. a state_dict (dict), which must be loaded into its model class first.
        raise ValueError(
            "Failed to load model. Please ensure it is a valid PyTorch model file: "
            f"expected a torch.nn.Module, got {type(model).__name__}"
        )
    model.eval()
    return model


@functools.lru_cache(maxsize=4)
def _get_cached_model(model_path):
    """Load the model at *model_path* once and reuse it across requests."""
    # Failed loads raise and are therefore not cached; a later call retries.
    return load_model(model_path)


# Audio processing
def process_audio(audio_filepath, model_path):
    """Run the audio file at *audio_filepath* through the model at *model_path*.

    The processed audio is saved at the input sample rate.

    Returns:
        Path to a newly created ``.wav`` file holding the processed audio. A
        unique temporary file is created per call, so concurrent requests
        cannot overwrite each other's output.

    Raises:
        gr.Error: On any failure (missing input, model load error, empty or
            malformed model output, ...). Gradio shows the message to the user.
            Returning an error string instead would be misread as an output file
            path by the ``gr.Audio(type="filepath")`` output component.
    """
    if audio_filepath is None:
        # Gradio passes None when the form is submitted without audio.
        raise gr.Error("Error: No input audio provided.")

    output_path = None
    try:
        model = _get_cached_model(model_path)

        # Convert the input audio into a tensor.
        waveform, sample_rate = torchaudio.load(audio_filepath)
        print(f"Loaded audio with shape {waveform.shape} and sample rate {sample_rate}")

        # NOTE(review): the waveform is passed exactly as torchaudio.load returns
        # it; confirm this layout matches what the model expects.
        with torch.no_grad():
            processed_waveform = model(waveform)

        # numel() also covers 1-D outputs, where shape[1] would raise IndexError.
        if processed_waveform is None or processed_waveform.numel() == 0:
            raise ValueError("Model returned empty waveform")

        processed_waveform = processed_waveform.detach().cpu()
        if processed_waveform.dim() == 1:
            # torchaudio.save expects a 2-D (channels, frames) tensor.
            processed_waveform = processed_waveform.unsqueeze(0)

        # NOTE(review): output is saved at the input sample rate; confirm the
        # model does not change the sample rate.
        fd, output_path = tempfile.mkstemp(prefix="processed_audio_", suffix=".wav")
        os.close(fd)
        torchaudio.save(output_path, processed_waveform, sample_rate)
        print(f"Processed audio saved to {output_path}")
        return output_path
    except Exception as e:
        print(f"Error: {str(e)}")
        # Don't leave an empty/partial temp file behind on failure.
        if output_path is not None and os.path.exists(output_path):
            os.remove(output_path)
        raise gr.Error(f"Error: {str(e)}") from e


# Gradio interface
def create_interface():
    """Download the model if needed and build the Gradio UI around process_audio."""
    # /resolve/ serves the actual file content. /raw/ returns the Git LFS
    # pointer text for LFS-tracked files such as .pth checkpoints, which
    # torch.load cannot read.
    model_url = "https://huggingface.co/spaces/adhisetiawan/anime-voice-generator/resolve/main/pretrained_models/alice/alice.pth"
    model_path = "alice.pth"  # local file name for the downloaded model

    download_model(model_url, model_path)

    interface = gr.Interface(
        fn=lambda audio_filepath: process_audio(audio_filepath, model_path),
        # type="filepath": the function receives and returns paths to audio files.
        inputs=gr.Audio(type="filepath", label="Source Audio"),
        outputs=gr.Audio(type="filepath", label="Processed Audio"),
        title="Anime Voice Filter",
        description="指定されたモデルを使用して音声にフィルターをかけます。",
    )
    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch()