sachin commited on
Commit
69fda24
·
1 Parent(s): c6fe3c5

Uploading models to hub

Browse files
Files changed (2) hide show
  1. src/config.py +2 -0
  2. src/trainer.py +25 -0
src/config.py CHANGED
@@ -7,9 +7,11 @@ MAX_DOWNLOAD_TIME = 0.2
7
 
8
  IMAGE_DOWNLOAD_PATH = pathlib.Path("./data/images")
9
  WANDB_LOG_PATH = pathlib.Path("/tmp/wandb_logs")
 
10
 
11
  IMAGE_DOWNLOAD_PATH.mkdir(parents=True, exist_ok=True)
12
  WANDB_LOG_PATH.mkdir(parents=True, exist_ok=True)
 
13
 
14
  MODEL_NAME = "tiny_clip"
15
 
 
7
 
8
  IMAGE_DOWNLOAD_PATH = pathlib.Path("./data/images")
9
  WANDB_LOG_PATH = pathlib.Path("/tmp/wandb_logs")
10
+ MODEL_PATH = pathlib.Path("/tmp/models")
11
 
12
  IMAGE_DOWNLOAD_PATH.mkdir(parents=True, exist_ok=True)
13
  WANDB_LOG_PATH.mkdir(parents=True, exist_ok=True)
14
+ MODEL_PATH.mkdir(parents=True, exist_ok=True)
15
 
16
  MODEL_NAME = "tiny_clip"
17
 
src/trainer.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from src import config
2
  from src import data
3
  from src import loss
@@ -8,7 +10,28 @@ from src import utils
8
  from src.lightning_module import LightningModule
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def train(trainer_config: config.TrainerConfig):
 
 
12
  transform = vision_model.get_vision_transform(trainer_config._model_config.vision_config)
13
  tokenizer = tk.Tokenizer(trainer_config._model_config.text_config)
14
  train_dl, valid_dl = data.get_dataset(
@@ -28,6 +51,8 @@ def train(trainer_config: config.TrainerConfig):
28
  trainer = utils.get_trainer(trainer_config)
29
  trainer.fit(lightning_module, train_dl, valid_dl)
30
 
 
 
31
 
32
  if __name__ == "__main__":
33
  trainer_config = config.TrainerConfig(debug=True)
 
1
+ import os
2
+
3
  from src import config
4
  from src import data
5
  from src import loss
 
10
  from src.lightning_module import LightningModule
11
 
12
 
13
+ def _upload_model_to_hub(
14
+ vision_encoder: models.TinyCLIPVisionEncoder, text_encoder: models.TinyCLIPTextEncoder
15
+ ):
16
+ vision_encoder.save_pretrained(
17
+ str(config.MODEL_PATH),
18
+ variant="vision_encoder",
19
+ safe_serialization=True,
20
+ push_to_hub=True,
21
+ repo_id="debug-clip-model",
22
+ )
23
+ text_encoder.save_pretrained(
24
+ str(config.MODEL_PATH),
25
+ variant="text_encoder",
26
+ safe_serialization=True,
27
+ push_to_hub=True,
28
+ repo_id="debug-clip-model",
29
+ )
30
+
31
+
32
  def train(trainer_config: config.TrainerConfig):
33
+ if "HF_TOKEN" not in os.environ:
34
+ raise ValueError("Please set the HF_TOKEN environment variable.")
35
  transform = vision_model.get_vision_transform(trainer_config._model_config.vision_config)
36
  tokenizer = tk.Tokenizer(trainer_config._model_config.text_config)
37
  train_dl, valid_dl = data.get_dataset(
 
51
  trainer = utils.get_trainer(trainer_config)
52
  trainer.fit(lightning_module, train_dl, valid_dl)
53
 
54
+ _upload_model_to_hub(vision_encoder, text_encoder)
55
+
56
 
57
  if __name__ == "__main__":
58
  trainer_config = config.TrainerConfig(debug=True)