pszemraj committed on
Commit 64c4fde · 1 Parent(s): aa7b890

End of training

README.md ADDED
@@ -0,0 +1,67 @@
 
+ ---
+ tags:
+ - generated_from_trainer
+ model-index:
+ - name: checkpoints
+   results: []
+ ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ # checkpoints
+
+ This model is a fine-tuned version of [gpt2-medium](https://huggingface.co/gpt2-medium) on an unknown dataset.
+ It achieves the following results on the evaluation set:
+ - Loss: 4.3281
+
+ ## Model description
+
+ More information needed
+
+ ## Intended uses & limitations
+
+ More information needed
+
+ ## Training and evaluation data
+
+ More information needed
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training:
+ - learning_rate: 2e-05
+ - train_batch_size: 32
+ - eval_batch_size: 32
+ - seed: 42
+ - distributed_type: multi-GPU
+ - gradient_accumulation_steps: 2
+ - total_train_batch_size: 64
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
+ - lr_scheduler_type: cosine
+ - lr_scheduler_warmup_ratio: 0.05
+ - num_epochs: 10
+
+ ### Training results
+
+ | Training Loss | Epoch | Step | Validation Loss |
+ |:-------------:|:-----:|:----:|:---------------:|
+ | 34.991        | 1.0   | 837  | 14.8359         |
+ | 12.2881       | 2.0   | 1674 | 9.375           |
+ | 8.5071        | 3.0   | 2511 | 7.2148          |
+ | 7.6031        | 4.0   | 3348 | 6.1758          |
+ | 6.4808        | 5.0   | 4185 | 5.5820          |
+ | 5.8562        | 6.0   | 5022 | 5.0977          |
+ | 5.6094        | 7.0   | 5859 | 4.8203          |
+ | 5.2591        | 8.0   | 6696 | 4.5977          |
+ | 5.0031        | 9.0   | 7533 | 4.4219          |
+ | 4.8837        | 10.0  | 8370 | 4.3281          |
+
+
+ ### Framework versions
+
+ - Transformers 4.16.1
+ - Pytorch 1.10.0+cu111
+ - Tokenizers 0.11.0
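
The auto-generated card leaves the usage section empty; the following is a minimal, hypothetical loading sketch. The repo id `pszemraj/checkpoints` is only a guess from the commit author and model-index name, and the sampling settings echo the `task_specific_params` in the config below.

```python
# Hypothetical usage sketch -- the repo id is a placeholder, not confirmed by
# this commit; swap in the actual model repository before running.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "pszemraj/checkpoints"  # assumption: commit author + model-index name
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id)

inputs = tokenizer("The training run finished with", return_tensors="pt")
# do_sample=True and max_length=50 mirror the text-generation defaults in config.json
outputs = model.generate(**inputs, do_sample=True, max_length=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```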
last-checkpoint/config.json DELETED
@@ -1,41 +0,0 @@
- {
-   "_name_or_path": "gpt2-medium",
-   "activation_function": "gelu_new",
-   "architectures": [
-     "GPT2LMHeadModel"
-   ],
-   "attn_pdrop": 0.1,
-   "bos_token_id": 50256,
-   "embd_pdrop": 0.1,
-   "eos_token_id": 50256,
-   "initializer_range": 0.02,
-   "layer_norm_epsilon": 1e-05,
-   "model_type": "gpt2",
-   "n_ctx": 1024,
-   "n_embd": 1024,
-   "n_head": 16,
-   "n_inner": null,
-   "n_layer": 24,
-   "n_positions": 1024,
-   "n_special": 0,
-   "predict_special_tokens": true,
-   "reorder_and_upcast_attn": false,
-   "resid_pdrop": 0.1,
-   "scale_attn_by_inverse_layer_idx": false,
-   "scale_attn_weights": true,
-   "summary_activation": null,
-   "summary_first_dropout": 0.1,
-   "summary_proj_to_labels": true,
-   "summary_type": "cls_index",
-   "summary_use_proj": true,
-   "task_specific_params": {
-     "text-generation": {
-       "do_sample": true,
-       "max_length": 50
-     }
-   },
-   "torch_dtype": "float16",
-   "transformers_version": "4.16.1",
-   "use_cache": false,
-   "vocab_size": 50257
- }
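
The deleted config is the stock gpt2-medium geometry (24 layers, 1024-dim embeddings, 16 heads) saved with `torch_dtype: float16`. As a rough sanity check, not something recorded in the repo, the usual 12·d² per transformer block plus the embedding tables recovers the expected parameter count:

```python
# Back-of-the-envelope parameter count from the config.json values above
# (an approximation that ignores biases and layer norms).
n_layer, n_embd, n_positions, vocab_size = 24, 1024, 1024, 50257

embeddings = vocab_size * n_embd + n_positions * n_embd  # token + position tables
per_block = 12 * n_embd ** 2                              # attention (4*d^2) + MLP (8*d^2)
total = embeddings + n_layer * per_block

print(f"~{total / 1e6:.0f}M parameters")  # ~355M, consistent with gpt2-medium
```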
 
last-checkpoint/global_step837/mp_rank_00_model_states.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:4d75032a46f06a12506ad626da354ac633e118cf229057d30761450f17522b19
- size 734881234
 
last-checkpoint/global_step837/zero_pp_rank_0_mp_rank_00_optim_states.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:3fdb75c45bec8a89d37a7751cac9581b086eee4c475891ec4a9a1a76bc1b9301
- size 4257899299
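
The shard sizes line up with what DeepSpeed keeps per parameter, although the breakdown below is an estimate rather than anything stated in the repo: the optimizer shard holds fp32 master weights plus two Adam moments (~12 bytes per parameter), while the model-states file holds the fp16 weights plus buffers (~2 bytes per parameter).

```python
# Rough accounting only -- the ~355M figure comes from the gpt2-medium config,
# not from these files.
params = 355e6
print(f"optimizer shard ~ {params * 12 / 1e9:.2f} GB")   # vs. the 4,257,899,299-byte file
print(f"fp16 model states ~ {params * 2 / 1e9:.2f} GB")  # vs. the 734,881,234-byte file
```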
 
last-checkpoint/latest DELETED
@@ -1 +0,0 @@
- global_step837
 
last-checkpoint/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:ec811d5bc7649cd93185ea7169ef8965afc472754ba3abbc097011ef40e6d13c
- size 734877906
 
last-checkpoint/rng_state_0.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:e3f11cd80ce64332c081b7126ac7be203b9c47053a7ef285bec9fd4f2ed7739d
- size 14503
 
last-checkpoint/trainer_state.json DELETED
@@ -1,30 +0,0 @@
- {
-   "best_metric": null,
-   "best_model_checkpoint": null,
-   "epoch": 1.0,
-   "global_step": 837,
-   "is_hyper_param_search": false,
-   "is_local_process_zero": true,
-   "is_world_process_zero": true,
-   "log_history": [
-     {
-       "epoch": 0.6,
-       "learning_rate": 2e-05,
-       "loss": 34.991,
-       "step": 500
-     },
-     {
-       "epoch": 1.0,
-       "eval_loss": 14.8359375,
-       "eval_runtime": 48.5288,
-       "eval_samples_per_second": 369.121,
-       "eval_steps_per_second": 11.54,
-       "step": 837
-     }
-   ],
-   "max_steps": 8370,
-   "num_train_epochs": 10,
-   "total_flos": 1.2434141312188416e+16,
-   "trial_name": null,
-   "trial_params": null
- }
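
Since this deleted `trainer_state.json` only covers the first epoch (global_step 837), its `log_history` is short, but the same structure holds for later checkpoints. A small sketch for pulling the logged evaluation losses; the path assumes the pre-deletion layout:

```python
# Minimal sketch: list eval losses recorded in a Trainer checkpoint's state file.
# "last-checkpoint/trainer_state.json" refers to the layout before this deletion.
import json

with open("last-checkpoint/trainer_state.json") as f:
    state = json.load(f)

for entry in state["log_history"]:
    if "eval_loss" in entry:
        print(f"step {entry['step']} (epoch {entry['epoch']}): eval_loss = {entry['eval_loss']}")
```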
 
last-checkpoint/training_args.bin DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:303f966408d16a937e8e91f930a52f03957f1584e34b8f7b4d21b2d6c03951c8
- size 4143
 
last-checkpoint/zero_to_fp32.py DELETED
@@ -1,453 +0,0 @@
- #!/usr/bin/env python
-
- # This script extracts fp32 consolidated weights from ZeRO 2 and 3 DeepSpeed checkpoints. It gets
- # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
- # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
- # application.
- #
- # example: python zero_to_fp32.py . pytorch_model.bin
-
- import argparse
- import torch
- import glob
- import math
- import os
- from collections import OrderedDict
-
- # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
- # DeepSpeed data structures it has to be available in the current python environment.
- import deepspeed
- from deepspeed.utils import logger
-
- debug = 0
-
- # load to cpu
- device = torch.device('cpu')
-
-
- def get_model_state_file(checkpoint_dir, zero_stage):
-     if not os.path.isdir(checkpoint_dir):
-         raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
-
-     # there should be only one file
-     if zero_stage == 2:
-         file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
-     elif zero_stage == 3:
-         file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
-
-     if not os.path.exists(file):
-         raise FileNotFoundError(f"can't find model states file at '{file}'")
-
-     return file
-
-
- def get_optim_files(checkpoint_dir):
-     # XXX: need to test that this simple glob rule works for multi-node setup too
-     optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*_optim_states.pt")))
-
-     if len(optim_files) == 0:
-         raise FileNotFoundError(
-             f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")
-
-     return optim_files
-
-
- def parse_model_state(file):
-     state_dict = torch.load(file, map_location=device)
-
-     if "buffer_names" not in state_dict:
-         raise ValueError(f"{file} is not a model state checkpoint")
-     buffer_names = state_dict["buffer_names"]
-     if debug:
-         print("Found buffers:", buffer_names)
-
-     # recover just the buffers while restoring them to fp32 if they were saved in fp16
-     buffers = {
-         k: v.float()
-         for k,
-         v in state_dict["module"].items() if k in buffer_names
-     }
-     return buffers
-
-
- def parse_optim_states(files, ds_checkpoint_dir):
-
-     total_files = len(files)
-     state_dicts = []
-     for f in files:
-         state_dicts.append(torch.load(f, map_location=device))
-
-     if not "zero_stage" in state_dicts[0]['optimizer_state_dict']:
-         raise ValueError(f"{files[0]} is not a zero checkpoint")
-     zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"]
-     world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]
-     param_shapes = state_dicts[0]["param_shapes"]
-     # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
-     # parameters can be different from data parallelism for non-expert parameters. So we can just
-     # use the max of the partition_count to get the dp world_size.
-
-     if type(world_size) is list:
-         world_size = max(world_size)
-
-     if world_size != total_files:
-         raise ValueError(
-             f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
-             "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
-         )
-
-     # the groups are named differently in each stage
-     if zero_stage == 2:
-         fp32_groups_key = "single_partition_of_fp32_groups"
-     elif zero_stage == 3:
-         fp32_groups_key = "fp32_flat_groups"
-     else:
-         raise ValueError(f"unknown zero stage {zero_stage}")
-
-     if zero_stage == 2:
-         fp32_flat_groups = [
-             state_dicts[i]['optimizer_state_dict'][fp32_groups_key]
-             for i in range(len(state_dicts))
-         ]
-     elif zero_stage == 3:
-         # if there is more than one param group, there will be multiple flattened tensors - one
-         # flattened tensor per group - for simplicity merge them into a single tensor
-         #
-         # XXX: could make the script more memory efficient for when there are multiple groups - it
-         # will require matching the sub-lists of param_shapes for each param group flattened tensor
-
-         fp32_flat_groups = [
-             torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key],
-                       0) for i in range(len(state_dicts))
-         ]
-
-     return zero_stage, world_size, param_shapes, fp32_flat_groups
-
-
- def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
-     """
-     Returns fp32 state_dict reconstructed from ds checkpoint
-
-     Args:
-         - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
-
-     """
-     print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
-
-     optim_files = get_optim_files(ds_checkpoint_dir)
-     zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
-     print(
-         f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
-
-     model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
-     buffers = parse_model_state(model_file)
-
-     if zero_stage == 2:
-         return _get_fp32_state_dict_from_zero2_checkpoint(world_size,
-                                                           param_shapes,
-                                                           fp32_flat_groups,
-                                                           buffers)
-     elif zero_stage == 3:
-         return _get_fp32_state_dict_from_zero3_checkpoint(world_size,
-                                                           param_shapes,
-                                                           fp32_flat_groups,
-                                                           buffers)
-
-
- def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
-                                                param_shapes,
-                                                fp32_flat_groups,
-                                                buffers):
-
-     # Reconstruction protocol:
-     #
-     # XXX: document this
-
-     if debug:
-         for i in range(world_size):
-             for j in range(len(fp32_flat_groups[0])):
-                 print(f"fp32_flat_groups[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
-
-     # XXX: memory usage doubles here (zero2)
-     num_param_groups = len(fp32_flat_groups[0])
-     merged_single_partition_of_fp32_groups = []
-     for i in range(num_param_groups):
-         merged_partitions = [sd[i] for sd in fp32_flat_groups]
-         full_single_fp32_vector = torch.cat(merged_partitions, 0)
-         merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
-     avail_numel = sum([
-         full_single_fp32_vector.numel()
-         for full_single_fp32_vector in merged_single_partition_of_fp32_groups
-     ])
-
-     if debug:
-         wanted_params = sum([len(shapes) for shapes in param_shapes])
-         wanted_numel = sum(
-             [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
-         # not asserting if there is a mismatch due to possible padding
-         print(f"Have {avail_numel} numels to process.")
-         print(f"Need {wanted_numel} numels in {wanted_params} params.")
-
-     state_dict = OrderedDict()
-
-     # buffers
-     state_dict.update(buffers)
-     if debug:
-         print(f"added {len(buffers)} buffers")
-
-     # params
-     # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
-     # out-of-core computing solution
-     total_numel = 0
-     total_params = 0
-     for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
-         offset = 0
-         avail_numel = full_single_fp32_vector.numel()
-         for name, shape in shapes.items():
-
-             unpartitioned_numel = shape.numel()
-             total_numel += unpartitioned_numel
-             total_params += 1
-
-             if debug:
-                 print(
-                     f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} "
-                 )
-             state_dict[name] = full_single_fp32_vector.narrow(
-                 0,
-                 offset,
-                 unpartitioned_numel).view(shape)
-             offset += unpartitioned_numel
-
-         # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
-         # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
-         # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
-         # live optimizer object, so we are checking that the numbers are within the right range
-         align_to = 2 * world_size
-
-         def zero2_align(x):
-             return align_to * math.ceil(x / align_to)
-
-         if debug:
-             print(f"original offset={offset}, avail_numel={avail_numel}")
-
-         offset = zero2_align(offset)
-         avail_numel = zero2_align(avail_numel)
-
-         if debug:
-             print(f"aligned offset={offset}, avail_numel={avail_numel}")
-
-         # Sanity check
-         if offset != avail_numel:
-             raise ValueError(
-                 f"consumed {offset} numels out of {avail_numel} - something is wrong")
-
-     print(
-         f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
-     )
-
-     return state_dict
-
-
- def zero3_partitioned_param_info(unpartitioned_numel, world_size):
-     remainder = unpartitioned_numel % world_size
-     padding_numel = (world_size - remainder) if remainder else 0
-     partitioned_numel = math.ceil(unpartitioned_numel / world_size)
-     return partitioned_numel, padding_numel
-
-
- def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
-                                                param_shapes,
-                                                fp32_flat_groups,
-                                                buffers):
-
-     # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
-     # param, re-consolidating each param, while dealing with padding if any
-
-     avail_numel = fp32_flat_groups[0].numel() * world_size
-     # merge list of dicts, preserving order
-     param_shapes = {k: v for d in param_shapes for k, v in d.items()}
-
-     if debug:
-         for i in range(world_size):
-             print(f"fp32_flat_groups[{i}].shape={fp32_flat_groups[i].shape}")
-
-     wanted_params = len(param_shapes)
-     wanted_numel = sum(shape.numel() for shape in param_shapes.values())
-     # not asserting if there is a mismatch due to possible padding
-     print(f"Have {avail_numel} numels to process.")
-     print(f"Need {wanted_numel} numels in {wanted_params} params.")
-
-     state_dict = OrderedDict()
-
-     # buffers
-     state_dict.update(buffers)
-     if debug:
-         print(f"added {len(buffers)} buffers")
-
-     # params
-     # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
-     # out-of-core computing solution
-     offset = 0
-     total_numel = 0
-     total_params = 0
-     for name, shape in param_shapes.items():
-
-         unpartitioned_numel = shape.numel()
-         total_numel += unpartitioned_numel
-         total_params += 1
-
-         partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
-
-         if debug:
-             print(
-                 f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
-             )
-
-         # XXX: memory usage doubles here
-         state_dict[name] = torch.cat(
-             tuple(fp32_flat_groups[i].narrow(0,
-                                              offset,
-                                              partitioned_numel)
-                   for i in range(world_size)),
-             0).narrow(0,
-                       0,
-                       unpartitioned_numel).view(shape)
-         offset += partitioned_numel
-
-     offset *= world_size
-
-     # Sanity check
-     if offset != avail_numel:
-         raise ValueError(
-             f"consumed {offset} numels out of {avail_numel} - something is wrong")
-
-     print(
-         f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
-     )
-
-     return state_dict
-
-
- def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
-     """
-     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
-     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
-     via a model hub.
-
-     Args:
-         - ``checkpoint_dir``: path to the desired checkpoint folder
-         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
-
-     Returns:
-         - pytorch ``state_dict``
-
-     Note: this approach may not work if your application doesn't have sufficient free CPU memory and
-     you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
-     the checkpoint.
-
-     A typical usage might be ::
-
-         from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
-         # do the training and checkpoint saving
-         state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
-         model = model.cpu() # move to cpu
-         model.load_state_dict(state_dict)
-         # submit to model hub or save the model to share with others
-
-     In this example the ``model`` will no longer be usable in the deepspeed context of the same
-     application. i.e. you will need to re-initialize the deepspeed engine, since
-     ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
-
-     If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
-
-     """
-     if tag is None:
-         latest_path = os.path.join(checkpoint_dir, 'latest')
-         if os.path.isfile(latest_path):
-             with open(latest_path, 'r') as fd:
-                 tag = fd.read().strip()
-         else:
-             raise ValueError(f"Unable to find 'latest' file at {latest_path}")
-
-     ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
-
-     if not os.path.isdir(ds_checkpoint_dir):
-         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
-
-     return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
-
-
- def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
-     """
-     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
-     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
-
-     Args:
-         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
-         - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
-         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
-     """
-
-     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
-     print(f"Saving fp32 state dict to {output_file}")
-     torch.save(state_dict, output_file)
-
-
- def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
-     """
-     1. Put the provided model to cpu
-     2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
-     3. Load it into the provided model
-
-     Args:
-         - ``model``: the model object to update
-         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
-         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
-
-     Returns:
-         - ``model``: modified model
-
-     Make sure you have plenty of CPU memory available before you call this function. If you don't
-     have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
-     conveniently placed for you in the checkpoint folder.
-
-     A typical usage might be ::
-
-         from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
-         model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
-         # submit to model hub or save the model to share with others
-
-     Note that once this has run, the ``model`` will no longer be usable in the deepspeed context
-     of the same application. i.e. you will need to re-initialize the deepspeed engine, since
-     ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
-
-     """
-     logger.info(f"Extracting fp32 weights")
-     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
-
-     logger.info(f"Overwriting model with fp32 weights")
-     model = model.cpu()
-     model.load_state_dict(state_dict, strict=False)
-
-     return model
-
-
- if __name__ == "__main__":
-
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "checkpoint_dir",
-         type=str,
-         help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
-     parser.add_argument(
-         "output_file",
-         type=str,
-         help=
-         "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
-     )
-     parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
-     args = parser.parse_args()
-
-     debug = args.debug
-
-     convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
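
The deleted checkpoint layout above (a `latest` file containing `global_step837`, plus `mp_rank_00_model_states.pt` and `zero_pp_rank_0_mp_rank_00_optim_states.pt`) is exactly what this script consumes. A hedged sketch of how it would have been used before the deletion, assuming a local clone with the old layout and DeepSpeed installed:

```python
# Offline route (what the script's __main__ does):
#   python last-checkpoint/zero_to_fp32.py last-checkpoint pytorch_model.bin
# It reads last-checkpoint/latest ("global_step837") and consolidates the shards.

# Programmatic route using the same helper; paths are assumptions based on the
# pre-deletion repo layout, not something this commit documents.
import sys

sys.path.append("last-checkpoint")  # make the (now deleted) script importable
from zero_to_fp32 import load_state_dict_from_zero_checkpoint

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
model = load_state_dict_from_zero_checkpoint(model, "last-checkpoint")  # tag read from 'latest'
```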
 
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ec811d5bc7649cd93185ea7169ef8965afc472754ba3abbc097011ef40e6d13c
+ oid sha256:94ae348d639baab6315e0afee97d451e53a81ec71fd4eb04afc63408cc06cfa5
  size 734877906
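
Because a Git LFS pointer's `oid` is simply the SHA-256 of the file contents, a downloaded `pytorch_model.bin` can be checked against the new pointer above; a minimal sketch:

```python
# Verify a local pytorch_model.bin against the new LFS pointer's oid (sha256).
import hashlib

expected = "94ae348d639baab6315e0afee97d451e53a81ec71fd4eb04afc63408cc06cfa5"

h = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)

print("match" if h.hexdigest() == expected else "mismatch")
```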