SFconvertbot commited on
Commit
0d3afce
·
1 Parent(s): 6f2f71c

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +41 -19
convert.py CHANGED
@@ -10,7 +10,7 @@ import torch
10
 
11
  from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
12
  from huggingface_hub.file_download import repo_folder_name
13
- from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
14
 
15
 
16
  COMMIT_DESCRIPTION = """
@@ -32,6 +32,7 @@ Feel free to ignore this PR.
32
 
33
  ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]
34
 
 
35
  def _remove_duplicate_names(
36
  state_dict: Dict[str, torch.Tensor],
37
  *,
@@ -48,9 +49,7 @@ def _remove_duplicate_names(
48
  shareds = _find_shared_tensors(state_dict)
49
  to_remove = defaultdict(list)
50
  for shared in shareds:
51
- complete_names = set(
52
- [name for name in shared if _is_complete(state_dict[name])]
53
- )
54
  if not complete_names:
55
  if len(shared) == 1:
56
  # Force contiguous
@@ -81,11 +80,13 @@ def _remove_duplicate_names(
81
  to_remove[keep_name].append(name)
82
  return to_remove
83
 
 
84
  def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
85
  try:
86
- import transformers
87
  import json
88
 
 
 
89
  config_filename = hf_hub_download(
90
  model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
91
  )
@@ -98,10 +99,11 @@ def get_discard_names(model_id: str, revision: Optional[str], folder: str, token
98
  # Name for this varible depends on transformers version.
99
  discard_names = getattr(class_, "_tied_weights_keys", [])
100
 
101
- except Exception as e:
102
  discard_names = []
103
  return discard_names
104
 
 
105
  class AlreadyExists(Exception):
106
  pass
107
 
@@ -126,8 +128,12 @@ def rename(pt_filename: str) -> str:
126
  return local
127
 
128
 
129
- def convert_multi(model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
130
- filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
 
 
 
 
131
  with open(filename, "r") as f:
132
  data = json.load(f)
133
 
@@ -157,8 +163,12 @@ def convert_multi(model_id: str, *, revision=Optional[str], folder: str, token:
157
  return operations, errors
158
 
159
 
160
- def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
161
- pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder)
 
 
 
 
162
 
163
  sf_name = "model.safetensors"
164
  sf_filename = os.path.join(folder, sf_name)
@@ -217,20 +227,22 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
217
 
218
  def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]) -> Optional["Discussion"]:
219
  try:
220
- main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
221
- discussions = api.get_repo_discussions(repo_id=model_id, revision=revision)
222
  except Exception:
223
  return None
224
  for discussion in discussions:
225
  if discussion.status in {"open", "closed"} and discussion.is_pull_request and discussion.title == pr_title:
226
  commits = api.list_repo_commits(model_id, revision=discussion.git_reference)
227
 
228
- if main_commit == commits[1].commit_id:
229
  return discussion
230
  return None
231
 
232
 
233
- def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
 
 
234
  operations = []
235
  errors = []
236
 
@@ -238,7 +250,9 @@ def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filen
238
  for filename in filenames:
239
  prefix, ext = os.path.splitext(filename)
240
  if ext in extensions:
241
- pt_filename = hf_hub_download(model_id, revision=revision, filename=filename, token=token, cache_dir=folder)
 
 
242
  dirname, raw_filename = os.path.split(filename)
243
  if raw_filename == "pytorch_model.bin":
244
  # XXX: This is a special case to handle `transformers` and the
@@ -255,7 +269,9 @@ def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filen
255
  return operations, errors
256
 
257
 
258
- def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
 
 
259
  pr_title = "Adding `safetensors` variant of this model"
260
  info = api.model_info(model_id, revision=revision)
261
  filenames = set(s.rfilename for s in info.siblings)
@@ -279,13 +295,19 @@ def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force:
279
 
280
  discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
281
  if "pytorch_model.bin" in filenames:
282
- operations, errors = convert_single(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names)
 
 
283
  elif "pytorch_model.bin.index.json" in filenames:
284
- operations, errors = convert_multi(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names)
 
 
285
  else:
286
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
287
  else:
288
- operations, errors = convert_generic(model_id, revision=revision, folder=folder, filenames=filenames, token=api.token)
 
 
289
 
290
  if operations:
291
  new_pr = api.create_commit(
 
10
 
11
  from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
12
  from huggingface_hub.file_download import repo_folder_name
13
+ from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
14
 
15
 
16
  COMMIT_DESCRIPTION = """
 
32
 
33
  ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]
34
 
35
+
36
  def _remove_duplicate_names(
37
  state_dict: Dict[str, torch.Tensor],
38
  *,
 
49
  shareds = _find_shared_tensors(state_dict)
50
  to_remove = defaultdict(list)
51
  for shared in shareds:
52
+ complete_names = set([name for name in shared if _is_complete(state_dict[name])])
 
 
53
  if not complete_names:
54
  if len(shared) == 1:
55
  # Force contiguous
 
80
  to_remove[keep_name].append(name)
81
  return to_remove
82
 
83
+
84
  def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
85
  try:
 
86
  import json
87
 
88
+ import transformers
89
+
90
  config_filename = hf_hub_download(
91
  model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
92
  )
 
99
  # Name for this varible depends on transformers version.
100
  discard_names = getattr(class_, "_tied_weights_keys", [])
101
 
102
+ except Exception:
103
  discard_names = []
104
  return discard_names
105
 
106
+
107
  class AlreadyExists(Exception):
108
  pass
109
 
 
128
  return local
129
 
130
 
131
+ def convert_multi(
132
+ model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]
133
+ ) -> ConversionResult:
134
+ filename = hf_hub_download(
135
+ repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder
136
+ )
137
  with open(filename, "r") as f:
138
  data = json.load(f)
139
 
 
163
  return operations, errors
164
 
165
 
166
+ def convert_single(
167
+ model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]
168
+ ) -> ConversionResult:
169
+ pt_filename = hf_hub_download(
170
+ repo_id=model_id, revision=revision, filename="pytorch_model.bin", token=token, cache_dir=folder
171
+ )
172
 
173
  sf_name = "model.safetensors"
174
  sf_filename = os.path.join(folder, sf_name)
 
227
 
228
  def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]) -> Optional["Discussion"]:
229
  try:
230
+ revision_commit = api.model_info(model_id, revision=revision).sha
231
+ discussions = api.get_repo_discussions(repo_id=model_id)
232
  except Exception:
233
  return None
234
  for discussion in discussions:
235
  if discussion.status in {"open", "closed"} and discussion.is_pull_request and discussion.title == pr_title:
236
  commits = api.list_repo_commits(model_id, revision=discussion.git_reference)
237
 
238
+ if revision_commit == commits[1].commit_id:
239
  return discussion
240
  return None
241
 
242
 
243
+ def convert_generic(
244
+ model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]
245
+ ) -> ConversionResult:
246
  operations = []
247
  errors = []
248
 
 
250
  for filename in filenames:
251
  prefix, ext = os.path.splitext(filename)
252
  if ext in extensions:
253
+ pt_filename = hf_hub_download(
254
+ model_id, revision=revision, filename=filename, token=token, cache_dir=folder
255
+ )
256
  dirname, raw_filename = os.path.split(filename)
257
  if raw_filename == "pytorch_model.bin":
258
  # XXX: This is a special case to handle `transformers` and the
 
269
  return operations, errors
270
 
271
 
272
+ def convert(
273
+ api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False
274
+ ) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
275
  pr_title = "Adding `safetensors` variant of this model"
276
  info = api.model_info(model_id, revision=revision)
277
  filenames = set(s.rfilename for s in info.siblings)
 
295
 
296
  discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
297
  if "pytorch_model.bin" in filenames:
298
+ operations, errors = convert_single(
299
+ model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
300
+ )
301
  elif "pytorch_model.bin.index.json" in filenames:
302
+ operations, errors = convert_multi(
303
+ model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
304
+ )
305
  else:
306
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
307
  else:
308
+ operations, errors = convert_generic(
309
+ model_id, revision=revision, folder=folder, filenames=filenames, token=api.token
310
+ )
311
 
312
  if operations:
313
  new_pr = api.create_commit(