"""Gradio demo that colorizes grayscale images with a Pix2Pix generator.

The generator weights are downloaded from the Hugging Face Hub at import
time and inference always runs on the CPU.
"""

import torch
from huggingface_hub import hf_hub_download
from torchvision import transforms
from PIL import Image
import gradio as gr

# Generator factory from the Pix2Pix dependencies
from models.networks import define_G

# Force CPU usage: make every later `torch.cuda.is_available()` call report
# that CUDA is unavailable.
torch.cuda.is_available = lambda: False


def download_model(
    repo_id: str = "Matharrr/pix2pix-coloring-manga-model",
    filename: str = "latest_net_G.pth",
) -> str:
    """Download the generator checkpoint from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Name of the checkpoint file inside the repository.

    Returns:
        The local path returned by ``hf_hub_download``.
    """
    print("Mengunduh model dari Hugging Face Hub...")
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f"Model berhasil diunduh ke: {model_path}")
    return model_path


class Pix2PixModel:
    """Loads the Pix2Pix generator and runs grayscale-to-RGB inference."""

    def __init__(self):
        """Download the checkpoint, build the generator and load its weights."""
        # Path of the downloaded checkpoint
        self.model_path = download_model()
        self.device = torch.device("cpu")  # always run on CPU

        # Define the generator architecture
        print("Membuat arsitektur model generator...")
        self.netG = define_G(
            input_nc=1,  # 1-channel grayscale input (must match the state_dict)
            output_nc=3,  # 3-channel RGB output
            ngf=64,
            netG="unet_256",
            norm="batch",
            use_dropout=False,
            init_type="normal",
            init_gain=0.02,
            gpu_ids=[],
        )
        self.netG.to(self.device)

        # Load the state_dict
        print("Memuat model dari state_dict...")
        # NOTE(review): torch.load unpickles the downloaded file, which can
        # execute arbitrary code if the checkpoint is not trusted. Consider
        # `weights_only=True` once the minimum supported torch version and
        # the checkpoint contents are confirmed to allow it.
        state_dict = torch.load(self.model_path, map_location=self.device)
        # Default strict loading: the architecture above is expected to
        # match the checkpoint exactly (same input_nc).
        self.netG.load_state_dict(state_dict)
        self.netG.eval()
        print("Model berhasil dimuat dan siap untuk inference.")

        # Input preprocessing pipeline
        self.transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),  # to 1-channel grayscale
            transforms.Resize((256, 256)),  # resize to 256x256
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),  # grayscale normalization
        ])

    def predict(self, image: Image.Image) -> Image.Image:
        """Colorize ``image`` and return the result at the input's size.

        Args:
            image: Input PIL image; it is converted to grayscale and resized
                to 256x256 before being fed to the generator.

        Returns:
            The generator output as a PIL image, resized back to the
            original ``image.size`` with bicubic resampling.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("predict() requires an image, got None")

        # Remember the original size so the output can be resized back
        original_size = image.size

        # Preprocess the input (1-channel grayscale) and add a batch axis
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Inference
        with torch.no_grad():
            output = self.netG(image_tensor)

        # Postprocess: drop only the batch axis (a bare squeeze() would also
        # remove any other size-1 dimension), clamp to [-1, 1] and rescale
        # to [0, 1] before converting to a PIL image.
        output_image = (output.squeeze(0).cpu().clamp(-1, 1) + 1) / 2.0
        output_image = transforms.ToPILImage()(output_image)
        return output_image.resize(original_size, Image.BICUBIC)


# Model initialisation (runs at import time: downloads and loads the weights)
print("Inisialisasi model Pix2Pix...")
model = Pix2PixModel()


def colorize_image(input_image):
    """Gradio callback: colorize the uploaded image.

    Args:
        input_image: PIL image supplied by the Gradio input component, or
            ``None`` when nothing was uploaded.

    Returns:
        The colorized PIL image produced by ``model.predict``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an ``AttributeError`` traceback.
    """
    print("Menerima input gambar...")
    if input_image is None:
        raise gr.Error("Silakan unggah gambar terlebih dahulu.")
    return model.predict(input_image)


# Gradio interface
interface = gr.Interface(
    fn=colorize_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Pix2Pix Image Translation Model",
    description="Upload gambar input grayscale untuk dikonversi menjadi gambar berwarna.",
)

if __name__ == "__main__":
    print("Meluncurkan antarmuka Gradio...")
    interface.launch()