gyrojeff commited on
Commit
9ea0468
·
1 Parent(s): 938cce3

feat: add tensor core settings in cli

Browse files
Files changed (1) hide show
  1. train.py +10 -2
train.py CHANGED
@@ -9,8 +9,6 @@ from detector.model import *
9
  from utils import get_current_tag
10
 
11
 
12
- torch.set_float32_matmul_precision("high")
13
-
14
  parser = argparse.ArgumentParser()
15
  parser.add_argument(
16
  "-d",
@@ -97,9 +95,19 @@ parser.add_argument(
97
  default=512,
98
  help="Model feature image input size (default: 512)",
99
  )
 
 
 
 
 
 
 
 
100
 
101
  args = parser.parse_args()
102
 
 
 
103
  devices = args.devices
104
  single_batch_size = args.single_batch_size
105
 
 
9
  from utils import get_current_tag
10
 
11
 
 
 
12
  parser = argparse.ArgumentParser()
13
  parser.add_argument(
14
  "-d",
 
95
  default=512,
96
  help="Model feature image input size (default: 512)",
97
  )
98
+ parser.add_argument(
99
+ "-t",
100
+ "--tensor-core",
101
+ type=str,
102
+ choices=["medium", "high", "heighest"],
103
+ default="high",
104
+ help="Tensor core precision (default: high)",
105
+ )
106
 
107
  args = parser.parse_args()
108
 
109
+ torch.set_float32_matmul_precision(args.tensor_core)
110
+
111
  devices = args.devices
112
  single_batch_size = args.single_batch_size
113