JeVoisBase  1.22
JeVois Smart Embedded Machine Vision Toolkit Base Modules
Share this page:
Loading...
Searching...
No Matches
download_data.py
Go to the documentation of this file.
1import hashlib
2import os
3import sys
4import tarfile
5import zipfile
6import requests
7import os.path as osp
8
9from urllib.request import urlopen
10from urllib.parse import urlparse
11
12
14 MB = 1024*1024
15 BUFSIZE = 10*MB
16
17 def __init__(self, **kwargs):
18 self._name = kwargs.pop('name')
19 self._url = kwargs.pop('url', None)
20 self._filename = kwargs.pop('filename')
21 self._sha = kwargs.pop('sha', None)
22 self._saveTo = kwargs.pop('saveTo', './data')
23 self._extractTo = kwargs.pop('extractTo', './data')
24
    def __str__(self):
        """Return a short human-readable description of this downloader."""
        return 'Downloader for <{}>'.format(self._name)
27
28 def printRequest(self, r):
29 def getMB(r):
30 d = dict(r.info())
31 for c in ['content-length', 'Content-Length']:
32 if c in d:
33 return int(d[c]) / self.MB
34 return '<unknown>'
35 print(' {} {} [{} Mb]'.format(r.getcode(), r.msg, getMB(r)))
36
37 def verifyHash(self):
38 if not self._sha:
39 return False
40 sha = hashlib.sha1()
41 try:
42 with open(osp.join(self._saveTo, self._filename), 'rb') as f:
43 while True:
44 buf = f.read(self.BUFSIZE)
45 if not buf:
46 break
47 sha.update(buf)
48 if self._sha != sha.hexdigest():
49 print(' actual {}'.format(sha.hexdigest()))
50 print(' expect {}'.format(self._sha))
51 return self._sha == sha.hexdigest()
52 except Exception as e:
53 print(' catch {}'.format(e))
54
55 def get(self):
56 print(' {}: {}'.format(self._name, self._filename))
57 if self.verifyHash():
58 print(' hash match - skipping download')
59 else:
60 basedir = os.path.dirname(self._saveTo)
61 if basedir and not os.path.exists(basedir):
62 print(' creating directory: ' + basedir)
63 os.makedirs(basedir, exist_ok=True)
64
65 print(' hash check failed - downloading')
66 if 'drive.google.com' in self._url:
67 urlquery = urlparse(self._url).query.split('&')
68 for q in urlquery:
69 if 'id=' in q:
70 gid = q[3:]
71 sz = GDrive(gid)(osp.join(self._saveTo, self._filename))
72 print(' size = %.2f Mb' % (sz / (1024.0 * 1024)))
73 else:
74 print(' get {}'.format(self._url))
75 self.download()
76
77 # Verify hash after download
78 print(' done')
79 print(' file {}'.format(self._filename))
80 if self.verifyHash():
81 print(' hash match - extracting')
82 else:
83 print(' hash check failed - exiting')
84
85 # Extract
86 if '.zip' in self._filename:
87 print(' extracting - ', end='')
88 self.extract()
89 print('done')
90
91 return True
92
93 def download(self):
94 try:
95 r = urlopen(self._url, timeout=60)
96 self.printRequest(r)
97 self.save(r)
98 except Exception as e:
99 print(' catch {}'.format(e))
100
101 def extract(self):
102 fileLocation = os.path.join(self._saveTo, self._filename)
103 try:
104 if self._filename.endswith('.zip'):
105 with zipfile.ZipFile(fileLocation) as f:
106 for member in f.namelist():
107 path = osp.join(self._extractTo, member)
108 if osp.exists(path) or osp.isfile(path):
109 continue
110 else:
111 f.extract(member, self._extractTo)
112 except Exception as e:
113 print((' catch {}'.format(e)))
114
115 def save(self, r):
116 with open(self._filename, 'wb') as f:
117 print(' progress ', end='')
118 sys.stdout.flush()
119 while True:
120 buf = r.read(self.BUFSIZE)
121 if not buf:
122 break
123 f.write(buf)
124 print('>', end='')
125 sys.stdout.flush()
126
127
def GDrive(gid):
    """Return a closure that downloads Google Drive file `gid`.

    The returned function takes a destination path, streams the file
    there (handling Google's large-file confirmation token), prints '>'
    every 10 MiB of progress, and returns the number of bytes written.
    """
    def download_gdrive(dst):
        URL = "https://docs.google.com/uc?export=download"
        session = requests.Session() # re-use cookies

        response = session.get(URL, params = { 'id' : gid }, stream = True)

        # Large files trigger a "can't scan for viruses" interstitial; the
        # confirmation token arrives as a 'download_warning*' cookie.
        token = None
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                token = value
                break
        if token:
            response = session.get(URL, params = { 'id' : gid, 'confirm' : token }, stream = True)

        BUFSIZE = 1024 * 1024
        PROGRESS_SIZE = 10 * 1024 * 1024

        written = 0
        next_progress = PROGRESS_SIZE
        with open(dst, "wb") as f:
            for chunk in response.iter_content(BUFSIZE):
                if not chunk:
                    continue # keep-alive
                f.write(chunk)
                written += len(chunk)
                if written >= next_progress:
                    next_progress += PROGRESS_SIZE
                    print('>', end='')
                    sys.stdout.flush()
        print('')
        return written
    return download_gdrive
165
# Data will be downloaded and extracted to ./data by default
# Registry of named Downloader instances, keyed by dataset name; keys are
# what the command line accepts. All entries are Google Drive archives with
# a pinned SHA-1 digest.
data_downloaders = {
    'face_detection': Downloader(
        name='face_detection',
        url='https://drive.google.com/u/0/uc?id=1lOAliAIeOv4olM65YDzE55kn6XjiX2l6&export=download',
        sha='0ba67a9cfd60f7fdb65cdb7c55a1ce76c1193df1',
        filename='face_detection.zip'),
    'face_recognition': Downloader(
        name='face_recognition',
        url='https://drive.google.com/u/0/uc?id=1BRIozREIzqkm_aMQ581j93oWoS-6TLST&export=download',
        sha='03892b9036c58d9400255ff73858caeec1f46609',
        filename='face_recognition.zip'),
    'text': Downloader(
        name='text',
        url='https://drive.google.com/u/0/uc?id=1lTQdZUau7ujHBqp0P6M1kccnnJgO-dRj&export=download',
        sha='a40cf095ceb77159ddd2a5902f3b4329696dd866',
        filename='text.zip'),
    'image_classification': Downloader(
        name='image_classification',
        url='https://drive.google.com/u/0/uc?id=1qcsrX3CIAGTooB-9fLKYwcvoCuMgjzGU&export=download',
        sha='987546f567f9f11d150eea78951024b55b015401',
        filename='image_classification.zip'),
    'human_segmentation': Downloader(
        name='human_segmentation',
        url='https://drive.google.com/u/0/uc?id=1Kh0qXcAZCEaqwavbUZubhRwrn_8zY7IL&export=download',
        sha='ac0eedfd8568570cad135acccd08a134257314d0',
        filename='human_segmentation.zip'),
    'qrcode': Downloader(
        name='qrcode',
        url='https://drive.google.com/u/0/uc?id=1_OXB7eiCIYO335ewkT6EdAeXyriFlq_H&export=download',
        sha='ac01c098934a353ca1545b5266de8bb4f176d1b3',
        filename='qrcode.zip'),
    'object_tracking': Downloader(
        name='object_tracking',
        url='https://drive.google.com/u/0/uc?id=1_cw5pUmTF-XmQVcQAI8fIp-Ewi2oMYIn&export=download',
        sha='0bdb042632a245270013713bc48ad35e9221f3bb',
        filename='object_tracking.zip'),
    'person_reid': Downloader(
        name='person_reid',
        url='https://drive.google.com/u/0/uc?id=1G8FkfVo5qcuyMkjSs4EA6J5e16SWDGI2&export=download',
        sha='5b741fbf34c1fbcf59cad8f2a65327a5899e66f1',
        filename='person_reid.zip'),
    'palm_detection': Downloader(
        name='palm_detection',
        url='https://drive.google.com/u/0/uc?id=1qScOzehV8OIzJJLuD_LMvZq15YcWd_VV&export=download',
        sha='c0d4f811d38c6f833364b9196a719307598213a1',
        filename='palm_detection.zip'),
    'license_plate_detection': Downloader(
        name='license_plate_detection',
        url='https://drive.google.com/u/0/uc?id=1cf9MEyUqMMy8lLeDGd1any6tM_SsSmny&export=download',
        sha='997acb143ddc4531e6e41365fb7ad4722064564c',
        filename='license_plate_detection.zip'),
}
209
if __name__ == '__main__':
    # Dataset names on the command line select a subset; no arguments
    # means "download everything".
    selected_data_names = sys.argv[1:]
    if not selected_data_names:
        selected_data_names = list(data_downloaders.keys())
    # ROBUSTNESS FIX: an unknown name previously raised a bare KeyError
    # mid-run; report it and continue with the valid names instead.
    unknown_names = [n for n in selected_data_names if n not in data_downloaders]
    if unknown_names:
        print('Unknown data names (skipped): {}'.format(str(unknown_names)))
        selected_data_names = [n for n in selected_data_names if n in data_downloaders]
    print('Data will be downloaded: {}'.format(str(selected_data_names)))

    download_failed = []
    for selected_data_name in selected_data_names:
        downloader = data_downloaders[selected_data_name]
        if not downloader.get():
            download_failed.append(downloader._name)

    if download_failed:
        print('Data have not been downloaded: {}'.format(str(download_failed)))
__init__(self, **kwargs)