Niv Sardi commited on
Commit
6d7f6ee
·
1 Parent(s): 1778651

python/write_data: support yolo 5 and 6

Browse files
Files changed (1) hide show
  1. python/write_data.py +20 -5
python/write_data.py CHANGED
@@ -4,13 +4,26 @@ import argparse
4
 
5
  from common import defaults
6
 
7
- def gen_data_yaml(bcos, datapath='../data'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  names = [f"{d.name}" for d in bcos.values()]
9
  return f'''
10
- # this file is autogenerated by write_data.py
11
 
12
- train: {datapath}/squares
13
- val: {datapath}/squares
14
 
15
  nc: {len(bcos.keys())}
16
  names: {names}
@@ -20,8 +33,10 @@ if __name__ == '__main__':
20
  parser = argparse.ArgumentParser(description='creates a YOLOv5 data.yaml')
21
  parser.add_argument('csv', metavar='csv', type=str,
22
  help='csv file', default=defaults.MAIN_CSV_PATH)
 
 
23
  parser.add_argument('--data', metavar='data', type=str,
24
  help='data path', default=defaults.DATA_PATH)
25
  args = parser.parse_args()
26
  bcos = entity.read_entities(args.csv)
27
- print(gen_data_yaml(bcos, args.data))
 
4
 
5
  from common import defaults
6
 
7
+ YOLO_TEMPLATES = {
8
+ 5: '''
9
+ train: %%datapath%%/squares
10
+ val: %%datapath%%squares
11
+ ''',
12
+ 6: '''
13
+ train: %%datapath%%/squares/images
14
+ val: %%datapath%%/squares/images
15
+ test: %%datapath%%/squares/images
16
+
17
+ is_coco: False
18
+ '''
19
+ }
20
+
21
+ def gen_data_yaml(bcos, datapath='../data', version=6):
22
  names = [f"{d.name}" for d in bcos.values()]
23
  return f'''
24
+ # this file is autogenerated by write_data.py for YOLO version {version}
25
 
26
+ {YOLO_TEMPLATES[version].replace('%%datapath%%', datapath)}
 
27
 
28
  nc: {len(bcos.keys())}
29
  names: {names}
 
33
  parser = argparse.ArgumentParser(description='creates a YOLOv5 data.yaml')
34
  parser.add_argument('csv', metavar='csv', type=str,
35
  help='csv file', default=defaults.MAIN_CSV_PATH)
36
+ parser.add_argument('--version', metavar='version', type=int,
37
+ help='yolo version to target', default=6)
38
  parser.add_argument('--data', metavar='data', type=str,
39
  help='data path', default=defaults.DATA_PATH)
40
  args = parser.parse_args()
41
  bcos = entity.read_entities(args.csv)
42
+ print(gen_data_yaml(bcos, args.data, args.version))