# tools/convert_flax_to_pt.py
# Hub page header (not part of the script): patrickvonplaten · "uP" · d7c590b · 2.1 kB
import argparse
import json
import os
import shutil
import torch
from tempfile import TemporaryDirectory
from typing import List, Optional
from diffusers import StableDiffusionPipeline, ControlNetModel
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
    """Convert a Flax checkpoint on the Hub to PyTorch weights and open a PR with them.

    The repository is loaded with ``from_flax=True``: as a
    ``StableDiffusionPipeline`` when its file list contains
    ``model_index.json``, otherwise as a ``ControlNetModel``. The model is then
    saved four times into a temporary folder (with and without
    ``safe_serialization=True``, first as loaded and then after casting to
    ``float16`` under the ``fp16`` variant) and the folder is uploaded back to
    the same repository as a pull request.

    Args:
        api: Hub client used to list the repository files and to upload the
            converted weights.
        model_id: Repository id of the model to convert.
        force: Currently unused by this function; kept so existing callers
            that pass it keep working.

    Returns:
        The value returned by ``api.upload_folder`` for the created pull
        request.
    """
    info = api.model_info(model_id)
    filenames = {sibling.rfilename for sibling in info.siblings}
    # `model_index.json` selects the pipeline loader; any other repository is
    # loaded as a standalone ControlNet model.
    is_sd = "model_index.json" in filenames

    if is_sd:
        model = StableDiffusionPipeline.from_pretrained(model_id, from_flax=True)
    else:
        model = ControlNetModel.from_pretrained(model_id, from_flax=True)

    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)

        # First pass: weights as loaded, in both serialization formats.
        model.save_pretrained(folder)
        model.save_pretrained(folder, safe_serialization=True)

        # Cast to half precision; the pipeline and the single model expose
        # different methods for this.
        if is_sd:
            model.to(torch_dtype=torch.float16)
        else:
            model.half()

        # Second pass: the same two formats, stored under the `fp16` variant.
        model.save_pretrained(folder, variant="fp16")
        model.save_pretrained(folder, safe_serialization=True, variant="fp16")

        # Upload while the temporary folder still exists; the result is kept so
        # the caller can inspect the created pull request.
        commit_info = api.upload_folder(
            folder_path=folder,
            repo_id=model_id,
            repo_type="model",
            create_pr=True,
        )
        print(model_id)

    return commit_info
if __name__ == "__main__":
    # Command-line entry point: parse the repository id and run the conversion.
    # The description and help text describe what `convert` does (Flax ->
    # PyTorch); the previous wording described a PyTorch -> safetensors tool.
    DESCRIPTION = """
    Simple utility tool to automatically convert Flax weights on the hub to PyTorch.
    It works by loading the Flax weights of a Stable Diffusion pipeline or ControlNet model,
    saving them locally as PyTorch and `safetensors` weights (default and `fp16` variants),
    and uploading them back as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help=(
            "The id of the Flax model repository on the hub to convert "
            "(a Stable Diffusion pipeline or a ControlNet model)."
        ),
    )
    args = parser.parse_args()
    model_id = args.model_id

    api = HfApi()
    convert(api, model_id)