chiragjn winglian commited on
Commit
dde02fc
·
unverified ·
1 Parent(s): b9bb169

Pass weakref to model in the SIGINT handler to free up model post train function (#1581)

Browse files

* Pass weakref to model in the SIGINT handler to free up model post train()

* Fix lint issues

* chore: lint

---------

Co-authored-by: Wing Lian <[email protected]>

Files changed (1) hide show
  1. src/axolotl/train.py +12 -5
src/axolotl/train.py CHANGED
@@ -3,6 +3,7 @@
3
  import os
4
  import signal
5
  import sys
 
6
  from dataclasses import dataclass
7
  from pathlib import Path
8
  from typing import Optional, Tuple, Union
@@ -127,14 +128,20 @@ def train(
127
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
128
  if cfg.local_rank == 0:
129
 
130
- def terminate_handler(_, __, model):
131
- if cfg.flash_optimum and BetterTransformer:
132
- model = BetterTransformer.reverse(model)
133
- model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
 
 
 
134
  sys.exit(0)
135
 
 
136
  signal.signal(
137
- signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
 
138
  )
139
 
140
  badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
 
3
  import os
4
  import signal
5
  import sys
6
+ import weakref
7
  from dataclasses import dataclass
8
  from pathlib import Path
9
  from typing import Optional, Tuple, Union
 
128
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
129
  if cfg.local_rank == 0:
130
 
131
+ def terminate_handler(_, __, model_weakref):
132
+ if model_weakref() is not None:
133
+ _model = model_weakref()
134
+ if cfg.flash_optimum and BetterTransformer:
135
+ _model = BetterTransformer.reverse(_model)
136
+ _model.save_pretrained(
137
+ cfg.output_dir, safe_serialization=safe_serialization
138
+ )
139
  sys.exit(0)
140
 
141
+ _model_weakref = weakref.ref(model)
142
  signal.signal(
143
+ signal.SIGINT,
144
+ lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
145
  )
146
 
147
  badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""