Update convert.py
Browse files- convert.py +9 -9
convert.py
CHANGED
@@ -71,8 +71,8 @@ def rename(pt_filename: str) -> str:
|
|
71 |
return local
|
72 |
|
73 |
|
74 |
-
def convert_multi(model_id: str, folder: str) -> ConversionResult:
|
75 |
-
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
|
76 |
with open(filename, "r") as f:
|
77 |
data = json.load(f)
|
78 |
|
@@ -102,8 +102,8 @@ def convert_multi(model_id: str, folder: str) -> ConversionResult:
|
|
102 |
return operations, errors
|
103 |
|
104 |
|
105 |
-
def convert_single(model_id: str, folder: str) -> ConversionResult:
|
106 |
-
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
|
107 |
|
108 |
sf_name = "model.safetensors"
|
109 |
sf_filename = os.path.join(folder, sf_name)
|
@@ -236,7 +236,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
|
|
236 |
return None
|
237 |
|
238 |
|
239 |
-
def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> ConversionResult:
|
240 |
operations = []
|
241 |
errors = []
|
242 |
|
@@ -244,7 +244,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> Conversi
|
|
244 |
for filename in filenames:
|
245 |
prefix, ext = os.path.splitext(filename)
|
246 |
if ext in extensions:
|
247 |
-
pt_filename = hf_hub_download(model_id, filename=filename)
|
248 |
dirname, raw_filename = os.path.split(filename)
|
249 |
if raw_filename == "pytorch_model.bin":
|
250 |
# XXX: This is a special case to handle `transformers` and the
|
@@ -283,14 +283,14 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
|
|
283 |
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
|
284 |
elif library_name == "transformers":
|
285 |
if "pytorch_model.bin" in filenames:
|
286 |
-
operations, errors = convert_single(model_id, folder)
|
287 |
elif "pytorch_model.bin.index.json" in filenames:
|
288 |
-
operations, errors = convert_multi(model_id, folder)
|
289 |
else:
|
290 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
291 |
check_final_model(model_id, folder)
|
292 |
else:
|
293 |
-
operations, errors = convert_generic(model_id, folder, filenames)
|
294 |
|
295 |
if operations:
|
296 |
new_pr = api.create_commit(
|
|
|
71 |
return local
|
72 |
|
73 |
|
74 |
+
def convert_multi(model_id: str, folder: str, token: str) -> ConversionResult:
|
75 |
+
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token)
|
76 |
with open(filename, "r") as f:
|
77 |
data = json.load(f)
|
78 |
|
|
|
102 |
return operations, errors
|
103 |
|
104 |
|
105 |
+
def convert_single(model_id: str, folder: str, token: str) -> ConversionResult:
|
106 |
+
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token)
|
107 |
|
108 |
sf_name = "model.safetensors"
|
109 |
sf_filename = os.path.join(folder, sf_name)
|
|
|
236 |
return None
|
237 |
|
238 |
|
239 |
+
def convert_generic(model_id: str, folder: str, filenames: Set[str], token: str) -> ConversionResult:
|
240 |
operations = []
|
241 |
errors = []
|
242 |
|
|
|
244 |
for filename in filenames:
|
245 |
prefix, ext = os.path.splitext(filename)
|
246 |
if ext in extensions:
|
247 |
+
pt_filename = hf_hub_download(model_id, filename=filename, token=token)
|
248 |
dirname, raw_filename = os.path.split(filename)
|
249 |
if raw_filename == "pytorch_model.bin":
|
250 |
# XXX: This is a special case to handle `transformers` and the
|
|
|
283 |
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
|
284 |
elif library_name == "transformers":
|
285 |
if "pytorch_model.bin" in filenames:
|
286 |
+
operations, errors = convert_single(model_id, folder, token=api.token)
|
287 |
elif "pytorch_model.bin.index.json" in filenames:
|
288 |
+
operations, errors = convert_multi(model_id, folder, token=api.token)
|
289 |
else:
|
290 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
291 |
check_final_model(model_id, folder)
|
292 |
else:
|
293 |
+
operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
|
294 |
|
295 |
if operations:
|
296 |
new_pr = api.create_commit(
|