Narsil HF staff commited on
Commit
53a69cd
·
1 Parent(s): ca6ec81

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +5 -6
convert.py CHANGED
@@ -13,7 +13,6 @@ from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, h
13
  from huggingface_hub.file_download import repo_folder_name
14
  from safetensors.torch import load_file, save_file
15
  from transformers import AutoConfig
16
- from transformers.pipelines.base import infer_framework_load_model
17
 
18
 
19
  COMMIT_DESCRIPTION = """
@@ -72,14 +71,14 @@ def rename(pt_filename: str) -> str:
72
 
73
 
74
  def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
75
- filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token)
76
  with open(filename, "r") as f:
77
  data = json.load(f)
78
 
79
  filenames = set(data["weight_map"].values())
80
  local_filenames = []
81
  for filename in filenames:
82
- pt_filename = hf_hub_download(repo_id=model_id, filename=filename, token=token)
83
 
84
  sf_filename = rename(pt_filename)
85
  sf_filename = os.path.join(folder, sf_filename)
@@ -103,7 +102,7 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> Conversio
103
 
104
 
105
  def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
106
- pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token)
107
 
108
  sf_name = "model.safetensors"
109
  sf_filename = os.path.join(folder, sf_name)
@@ -157,7 +156,7 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
157
 
158
 
159
  def check_final_model(model_id: str, folder: str, token: Optional[str]):
160
- config = hf_hub_download(repo_id=model_id, filename="config.json", token=token)
161
  shutil.copy(config, os.path.join(folder, "config.json"))
162
  config = AutoConfig.from_pretrained(folder)
163
 
@@ -244,7 +243,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Opti
244
  for filename in filenames:
245
  prefix, ext = os.path.splitext(filename)
246
  if ext in extensions:
247
- pt_filename = hf_hub_download(model_id, filename=filename, token=token)
248
  dirname, raw_filename = os.path.split(filename)
249
  if raw_filename == "pytorch_model.bin":
250
  # XXX: This is a special case to handle `transformers` and the
 
13
  from huggingface_hub.file_download import repo_folder_name
14
  from safetensors.torch import load_file, save_file
15
  from transformers import AutoConfig
 
16
 
17
 
18
  COMMIT_DESCRIPTION = """
 
71
 
72
 
73
  def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
74
+ filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
75
  with open(filename, "r") as f:
76
  data = json.load(f)
77
 
78
  filenames = set(data["weight_map"].values())
79
  local_filenames = []
80
  for filename in filenames:
81
+ pt_filename = hf_hub_download(repo_id=model_id, filename=filename, token=token, cache_dir=folder)
82
 
83
  sf_filename = rename(pt_filename)
84
  sf_filename = os.path.join(folder, sf_filename)
 
102
 
103
 
104
  def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
105
+ pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder)
106
 
107
  sf_name = "model.safetensors"
108
  sf_filename = os.path.join(folder, sf_name)
 
156
 
157
 
158
  def check_final_model(model_id: str, folder: str, token: Optional[str]):
159
+ config = hf_hub_download(repo_id=model_id, filename="config.json", token=token, cache_dir=folder)
160
  shutil.copy(config, os.path.join(folder, "config.json"))
161
  config = AutoConfig.from_pretrained(folder)
162
 
 
243
  for filename in filenames:
244
  prefix, ext = os.path.splitext(filename)
245
  if ext in extensions:
246
+ pt_filename = hf_hub_download(model_id, filename=filename, token=token, cache_dir=folder)
247
  dirname, raw_filename = os.path.split(filename)
248
  if raw_filename == "pytorch_model.bin":
249
  # XXX: This is a special case to handle `transformers` and the