Spaces:
Running
Running
Update convert.py
Browse filesKeeping one part of the final test still retaining the no RAM usage feature.
- convert.py +12 -8
convert.py
CHANGED
@@ -163,13 +163,17 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
|
163 |
import transformers
|
164 |
|
165 |
class_ = getattr(transformers, config.architectures[0])
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
|
|
|
|
|
|
|
|
173 |
pt_params = pt_model.state_dict()
|
174 |
sf_params = sf_model.state_dict()
|
175 |
|
@@ -291,7 +295,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
|
|
291 |
operations, errors = convert_multi(model_id, folder, token=api.token)
|
292 |
else:
|
293 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
294 |
-
|
295 |
else:
|
296 |
operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
|
297 |
|
|
|
163 |
import transformers
|
164 |
|
165 |
class_ = getattr(transformers, config.architectures[0])
|
166 |
+
with torch.device("meta"):
|
167 |
+
(pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
|
168 |
+
(sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)
|
169 |
+
|
170 |
+
if pt_infos != sf_infos:
|
171 |
+
error_string = create_diff(pt_infos, sf_infos)
|
172 |
+
raise ValueError(f"Different infos when reloading the model: {error_string}")
|
173 |
+
|
174 |
+
#### XXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
175 |
+
#### SKIPPING THE REST OF THE test to save RAM
|
176 |
+
return
|
177 |
pt_params = pt_model.state_dict()
|
178 |
sf_params = sf_model.state_dict()
|
179 |
|
|
|
295 |
operations, errors = convert_multi(model_id, folder, token=api.token)
|
296 |
else:
|
297 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
298 |
+
check_final_model(model_id, folder, token=api.token)
|
299 |
else:
|
300 |
operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
|
301 |
|