Update convert.py
Browse files- convert.py +40 -5
convert.py
CHANGED
@@ -161,8 +161,11 @@ def check_final_model(model_id: str, folder: str):
|
|
161 |
shutil.copy(config, os.path.join(folder, "config.json"))
|
162 |
config = AutoConfig.from_pretrained(folder)
|
163 |
|
164 |
-
|
165 |
-
|
|
|
|
|
|
|
166 |
|
167 |
if pt_infos != sf_infos:
|
168 |
error_string = create_diff(pt_infos, sf_infos)
|
@@ -199,7 +202,19 @@ def check_final_model(model_id: str, folder: str):
|
|
199 |
sf_model = sf_model.cuda()
|
200 |
kwargs = {k: v.cuda() for k, v in kwargs.items()}
|
201 |
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
sf_logits = sf_model(**kwargs)[0]
|
204 |
|
205 |
torch.testing.assert_close(sf_logits, pt_logits)
|
@@ -246,7 +261,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> Conversi
|
|
246 |
return operations, errors
|
247 |
|
248 |
|
249 |
-
def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List["Exception"]]:
|
250 |
pr_title = "Adding `safetensors` variant of this model"
|
251 |
info = api.model_info(model_id)
|
252 |
filenames = set(s.rfilename for s in info.siblings)
|
@@ -328,6 +343,26 @@ if __name__ == "__main__":
|
|
328 |
" Continue [Y/n] ?"
|
329 |
)
|
330 |
if txt.lower() in {"", "y"}:
|
331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
else:
|
333 |
print(f"Answer was `{txt}` aborting.")
|
|
|
161 |
shutil.copy(config, os.path.join(folder, "config.json"))
|
162 |
config = AutoConfig.from_pretrained(folder)
|
163 |
|
164 |
+
import transformers
|
165 |
+
|
166 |
+
class_ = getattr(transformers, config.architectures[0])
|
167 |
+
(pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
|
168 |
+
(sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)
|
169 |
|
170 |
if pt_infos != sf_infos:
|
171 |
error_string = create_diff(pt_infos, sf_infos)
|
|
|
202 |
sf_model = sf_model.cuda()
|
203 |
kwargs = {k: v.cuda() for k, v in kwargs.items()}
|
204 |
|
205 |
+
try:
|
206 |
+
pt_logits = pt_model(**kwargs)[0]
|
207 |
+
except Exception as e:
|
208 |
+
try:
|
209 |
+
# Musicgen special exception.
|
210 |
+
decoder_input_ids = torch.ones((input_ids.shape[0] * pt_model.decoder.num_codebooks, 1), dtype=torch.long)
|
211 |
+
if torch.cuda.is_available():
|
212 |
+
decoder_input_ids = decoder_input_ids.cuda()
|
213 |
+
|
214 |
+
kwargs["decoder_input_ids"] = decoder_input_ids
|
215 |
+
pt_logits = pt_model(**kwargs)[0]
|
216 |
+
except Exception:
|
217 |
+
raise e
|
218 |
sf_logits = sf_model(**kwargs)[0]
|
219 |
|
220 |
torch.testing.assert_close(sf_logits, pt_logits)
|
|
|
261 |
return operations, errors
|
262 |
|
263 |
|
264 |
+
def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
|
265 |
pr_title = "Adding `safetensors` variant of this model"
|
266 |
info = api.model_info(model_id)
|
267 |
filenames = set(s.rfilename for s in info.siblings)
|
|
|
343 |
" Continue [Y/n] ?"
|
344 |
)
|
345 |
if txt.lower() in {"", "y"}:
|
346 |
+
try:
|
347 |
+
commit_info, errors = convert(api, model_id, force=args.force)
|
348 |
+
string = f"""
|
349 |
+
### Success 🔥
|
350 |
+
Yay! This model was successfully converted and a PR was open using your token, here:
|
351 |
+
[{commit_info.pr_url}]({commit_info.pr_url})
|
352 |
+
"""
|
353 |
+
if errors:
|
354 |
+
string += "\nErrors during conversion:\n"
|
355 |
+
string += "\n".join(
|
356 |
+
f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
|
357 |
+
)
|
358 |
+
print(string)
|
359 |
+
except Exception as e:
|
360 |
+
print(
|
361 |
+
f"""
|
362 |
+
### Error 😢😢😢
|
363 |
+
|
364 |
+
{e}
|
365 |
+
"""
|
366 |
+
)
|
367 |
else:
|
368 |
print(f"Answer was `{txt}` aborting.")
|