Dirake commited on
Commit
9c29da9
·
verified ·
1 Parent(s): 982c544

Create convert.py

Browse files
Files changed (1) hide show
  1. convert.py +77 -0
convert.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from xml.dom import minidom
4
+ import os
5
+ import glob
6
+
7
+ lut={}
8
+ lut["pikachu"]=0
9
+ lut["charmander"]=1
10
+ lut["bulbasaur"]=2
11
+ lut["squirtle"]=3
12
+ lut["eevee"]=4
13
+ lut["other"]=5
14
+ lut["jigglypuff"]=6
15
+
16
+
17
+ def convert_coordinates(size, box):
18
+ if(size[0]==0 or size[1]==0):
19
+ return (0,0,0,0)
20
+ dw = 1.0/size[0]
21
+ dh = 1.0/size[1]
22
+ x = (box[0]+box[1])/2.0
23
+ y = (box[2]+box[3])/2.0
24
+ w = box[1]-box[0]
25
+ h = box[3]-box[2]
26
+ x = x*dw
27
+ w = w*dw
28
+ y = y*dh
29
+ h = h*dh
30
+ return (x,y,w,h)
31
+
32
+
33
+ def convert_xml2yolo(lut):
34
+
35
+ for fname in glob.glob("*.xml"):
36
+
37
+ xmldoc = minidom.parse(fname)
38
+
39
+ fname_out = (fname[:-4]+'.txt')
40
+
41
+ with open(fname_out, "w") as f:
42
+
43
+ itemlist = xmldoc.getElementsByTagName('object')
44
+ size = xmldoc.getElementsByTagName('size')[0]
45
+ width = int((size.getElementsByTagName('width')[0]).firstChild.data)
46
+ height = int((size.getElementsByTagName('height')[0]).firstChild.data)
47
+
48
+ for item in itemlist:
49
+ # get class label
50
+ classid = (item.getElementsByTagName('name')[0]).firstChild.data
51
+ if classid in lut:
52
+ label_str = str(lut[classid])
53
+ else:
54
+ label_str = "-1"
55
+ print ("warning: label '%s' not in look-up table" % classid)
56
+
57
+ # get bbox coordinates
58
+ xmin = ((item.getElementsByTagName('bndbox')[0]).getElementsByTagName('xmin')[0]).firstChild.data
59
+ ymin = ((item.getElementsByTagName('bndbox')[0]).getElementsByTagName('ymin')[0]).firstChild.data
60
+ xmax = ((item.getElementsByTagName('bndbox')[0]).getElementsByTagName('xmax')[0]).firstChild.data
61
+ ymax = ((item.getElementsByTagName('bndbox')[0]).getElementsByTagName('ymax')[0]).firstChild.data
62
+ b = (float(xmin), float(xmax), float(ymin), float(ymax))
63
+ bb = convert_coordinates((width,height), b)
64
+ #print(bb)
65
+
66
+ f.write(label_str + " " + " ".join([("%.6f" % a) for a in bb]) + '\n')
67
+
68
+ print ("wrote %s" % fname_out)
69
+
70
+
71
+
72
+ def main():
73
+ convert_xml2yolo(lut)
74
+
75
+
76
+ if __name__ == '__main__':
77
+ main()