JeVoisBase  1.22
JeVois Smart Embedded Machine Vision Toolkit Base Modules
Share this page:
Loading...
Searching...
No Matches
quantize-ort.py
Go to the documentation of this file.
1# This file is part of OpenCV Zoo project.
2# It is subject to the license terms in the LICENSE file found in the same directory.
3#
4# Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
5# Third party copyrights are property of their respective owners.
6
7import os
8import sys
9import numpy as ny
10import cv2 as cv
11
12import onnx
13from onnx import version_converter
14import onnxruntime
15from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
16
17from transform import Compose, Resize, CenterCrop, Normalize, ColorConvert
18
class DataReader(CalibrationDataReader):
    """Calibration data reader for onnxruntime static quantization.

    Every ``.jpg``/``.jpeg`` file in ``image_dir`` (visited in sorted filename
    order) is read with OpenCV, passed through ``transforms`` and converted to a
    blob with ``cv.dnn.blobFromImage``. ``get_next`` then hands the blobs out one
    at a time as ``{input_name: blob}`` dicts, where ``input_name`` is the name
    of the model's first graph input.
    """

    # Lower-cased file extensions accepted as calibration images.
    IMAGE_SUFFIXES = ('jpg', 'jpeg')

    def __init__(self, model_path, image_dir, transforms):
        model = onnx.load(model_path)
        # Calibration samples are fed to the first graph input only.
        self.input_name = model.graph.input[0].name
        self.transforms = transforms
        self.data = self.get_calibration_data(image_dir)
        self.enum_data_dicts = iter([{self.input_name: x} for x in self.data])

    def get_next(self):
        """Return the next ``{input_name: blob}`` dict, or ``None`` once exhausted."""
        return next(self.enum_data_dicts, None)

    def get_calibration_data(self, image_dir):
        """Load, transform and blob-ify the JPEG images found in ``image_dir``.

        Files are visited in sorted order so calibration is reproducible
        (``os.listdir`` order is arbitrary). Files OpenCV cannot decode are
        reported and skipped instead of crashing inside ``transforms``.

        Returns:
            list of blobs produced by ``cv.dnn.blobFromImage``.
        """
        blobs = []
        for image_name in sorted(os.listdir(image_dir)):
            if image_name.split('.')[-1].lower() not in self.IMAGE_SUFFIXES:
                continue
            image_path = os.path.join(image_dir, image_name)
            img = cv.imread(image_path)
            if img is None:
                # cv.imread returns None when it cannot read/decode the file.
                print('\tskipping unreadable image: {}'.format(image_path))
                continue
            blobs.append(cv.dnn.blobFromImage(self.transforms(img)))
        return blobs

43 def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8'):
44 self.type_dict = {"uint8" : QuantType.QUInt8, "int8" : QuantType.QInt8}
45
46 self.model_path = model_path
47 self.calibration_image_dir = calibration_image_dir
48 self.transforms = transforms
49 self.per_channel = per_channel
50 self.act_type = act_type
51 self.wt_type = wt_type
52
53 # data reader
55
56 def check_opset(self, convert=True):
57 model = onnx.load(self.model_path)
58 if model.opset_import[0].version != 11:
59 print('\tmodel opset version: {}. Converting to opset 11'.format(model.opset_import[0].version))
60 # convert opset version to 11
61 model_opset11 = version_converter.convert_version(model, 11)
62 # save converted model
63 output_name = '{}-opset11.onnx'.format(self.model_path[:-5])
64 onnx.save_model(model_opset11, output_name)
65 # update model_path for quantization
66 self.model_path = output_name
67
68 def run(self):
69 print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
70 self.check_opset()
71 output_name = '{}-act_{}-wt_{}-quantized.onnx'.format(self.model_path[:-5], self.act_type, self.wt_type)
72 quantize_static(self.model_path, output_name, self.dr,
73 per_channel=self.per_channel,
74 weight_type=self.type_dict[self.wt_type],
75 activation_type=self.type_dict[self.act_type])
76 os.remove('augmented_model.onnx')
77 os.remove('{}-opt.onnx'.format(self.model_path[:-5]))
78 print('\tQuantized model saved to {}'.format(output_name))
79
# Registry of quantizable models, keyed by the name accepted on the command line.
# Each entry pairs an ONNX model with a directory of calibration images and the
# preprocessing applied to those images before they are fed to the model.
models = {
    'yunet': Quantize(
        model_path='../../models/face_detection_yunet/face_detection_yunet_2022mar.onnx',
        calibration_image_dir='../../benchmark/data/face_detection',
        transforms=Compose([Resize(size=(160, 120))]),
    ),
    'sface': Quantize(
        model_path='../../models/face_recognition_sface/face_recognition_sface_2021dec.onnx',
        calibration_image_dir='../../benchmark/data/face_recognition',
        transforms=Compose([Resize(size=(112, 112))]),
    ),
    # NOTE: the CLI key is spelled 'pphumenseg' while the model directory is
    # 'human_segmentation_pphumanseg'.
    'pphumenseg': Quantize(
        model_path='../../models/human_segmentation_pphumanseg/human_segmentation_pphumanseg_2021oct.onnx',
        calibration_image_dir='../../benchmark/data/human_segmentation',
        transforms=Compose([Resize(size=(192, 192))]),
    ),
    'ppresnet50': Quantize(
        model_path='../../models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx',
        calibration_image_dir='../../benchmark/data/image_classification',
        transforms=Compose([Resize(size=(224, 224))]),
    ),
    # TBD: DaSiamRPN
    'youtureid': Quantize(
        model_path='../../models/person_reid_youtureid/person_reid_youtu_2021nov.onnx',
        calibration_image_dir='../../benchmark/data/person_reid',
        transforms=Compose([Resize(size=(128, 256))]),
    ),
    # TBD: DB-EN & DB-CN
    # The English CRNN is calibrated on grayscale input; the Chinese one keeps
    # the images in color.
    'crnn_en': Quantize(
        model_path='../../models/text_recognition_crnn/text_recognition_CRNN_EN_2021sep.onnx',
        calibration_image_dir='../../benchmark/data/text',
        transforms=Compose([Resize(size=(100, 32)), ColorConvert(ctype=cv.COLOR_BGR2GRAY)]),
    ),
    'crnn_cn': Quantize(
        model_path='../../models/text_recognition_crnn/text_recognition_CRNN_CN_2021nov.onnx',
        calibration_image_dir='../../benchmark/data/text',
        transforms=Compose([Resize(size=(100, 32))]),
    ),
}

if __name__ == '__main__':
    # Model names come from the command line; with no arguments, quantize all.
    selected_models = sys.argv[1:] or list(models.keys())
    # Validate every requested name before starting any (slow) quantization,
    # so a typo does not abort the run halfway through with a KeyError.
    unknown_models = [name for name in selected_models if name not in models]
    if unknown_models:
        sys.exit('Unknown model(s): {}. Available models: {}'.format(
            ', '.join(unknown_models), ', '.join(models)))
    print('Models to be quantized: {}'.format(str(selected_models)))

    for selected_model_name in selected_models:
        models[selected_model_name].run()

__init__(self, model_path, image_dir, transforms)
get_calibration_data(self, image_dir)
check_opset(self, convert=True)
__init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8')