JeVoisBase  1.20
JeVois Smart Embedded Machine Vision Toolkit Base Modules
Share this page:
quantize-ort.py
Go to the documentation of this file.
1 # This file is part of OpenCV Zoo project.
2 # It is subject to the license terms in the LICENSE file found in the same directory.
3 #
4 # Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
5 # Third party copyrights are property of their respective owners.
6 
import os
import sys

import cv2 as cv
import numpy as np
import onnx
import onnxruntime
from onnx import version_converter
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType

from transform import Compose, Resize, CenterCrop, Normalize, ColorConvert
18 
class DataReader(CalibrationDataReader):
    """Feed calibration images to onnxruntime's static quantizer.

    Loads every .jpg/.jpeg image found in ``image_dir``, applies the given
    preprocessing ``transforms`` and converts each image into a DNN blob
    keyed by the model's first input name, as ``quantize_static`` expects.
    """

    def __init__(self, model_path, image_dir, transforms):
        # Load the model only to discover the name of its first graph input,
        # which is used as the feed-dict key for every calibration sample.
        model = onnx.load(model_path)
        self.input_name = model.graph.input[0].name
        self.transforms = transforms
        self.data = self.get_calibration_data(image_dir)
        self.enum_data_dicts = iter([{self.input_name: x} for x in self.data])

    def get_next(self):
        """Return the next calibration feed dict, or None when exhausted."""
        return next(self.enum_data_dicts, None)

    def get_calibration_data(self, image_dir):
        """Load, transform and blobify all jpeg images in ``image_dir``.

        Returns a list of 4D blobs (one per readable jpeg image).
        """
        blobs = []
        for image_name in os.listdir(image_dir):
            # Only jpeg images are used for calibration.
            if not image_name.lower().endswith(('.jpg', '.jpeg')):
                continue
            img = cv.imread(os.path.join(image_dir, image_name))
            if img is None:
                # Unreadable/corrupt file: cv.imread returns None; skip it
                # instead of crashing inside the transforms.
                continue
            img = self.transforms(img)
            blobs.append(cv.dnn.blobFromImage(img))
        return blobs
41 
class Quantize:
    """Static int8/uint8 quantization of an ONNX model with onnxruntime.

    Parameters:
        model_path: path to the .onnx model to quantize.
        calibration_image_dir: directory of jpeg images used for calibration.
        transforms: preprocessing pipeline applied to each calibration image.
            NOTE: the default Compose() instance is shared across calls; it is
            only read, never mutated, so this is safe.
        per_channel: quantize weights per channel instead of per tensor.
        act_type: activation quantization type, 'int8' or 'uint8'.
        wt_type: weight quantization type, 'int8' or 'uint8'.
    """

    def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8'):
        # Map the string options to onnxruntime's QuantType enum values.
        self.type_dict = {"uint8": QuantType.QUInt8, "int8": QuantType.QInt8}

        self.model_path = model_path
        self.calibration_image_dir = calibration_image_dir
        self.transforms = transforms
        self.per_channel = per_channel
        self.act_type = act_type
        self.wt_type = wt_type

        # data reader
        # BUGFIX: this assignment was missing, leaving self.dr undefined and
        # making run() fail with AttributeError at quantize_static().
        self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms)

    def check_opset(self, convert=True):
        """Ensure the model targets opset 11; convert and re-save if not.

        The ``convert`` flag is kept for interface compatibility; conversion
        always happens when the model's opset differs from 11. On conversion,
        self.model_path is updated to the newly saved '-opset11' model.
        """
        model = onnx.load(self.model_path)
        if model.opset_import[0].version != 11:
            print('\tmodel opset version: {}. Converting to opset 11'.format(model.opset_import[0].version))
            # convert opset version to 11
            model_opset11 = version_converter.convert_version(model, 11)
            # save converted model ('.onnx' suffix stripped via [:-5])
            output_name = '{}-opset11.onnx'.format(self.model_path[:-5])
            onnx.save_model(model_opset11, output_name)
            # update model_path so quantization runs on the converted model
            self.model_path = output_name

    def run(self):
        """Quantize the model and save it next to the original file."""
        print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
        self.check_opset()
        output_name = '{}-act_{}-wt_{}-quantized.onnx'.format(self.model_path[:-5], self.act_type, self.wt_type)
        quantize_static(self.model_path, output_name, self.dr,
                        per_channel=self.per_channel,
                        weight_type=self.type_dict[self.wt_type],
                        activation_type=self.type_dict[self.act_type])
        # Clean up intermediate files produced by quantize_static. They are
        # not created by every onnxruntime version, so guard the removals
        # instead of letting a missing file abort the run.
        for tmp in ('augmented_model.onnx', '{}-opt.onnx'.format(self.model_path[:-5])):
            if os.path.exists(tmp):
                os.remove(tmp)
        print('\tQuantized model saved to {}'.format(output_name))
79 
# Registry of quantization jobs, keyed by the model name accepted on the
# command line. Each entry pairs a model file with its calibration image
# directory and the preprocessing its network expects.
models = {
    'yunet': Quantize(model_path='../../models/face_detection_yunet/face_detection_yunet_2022mar.onnx',
                      calibration_image_dir='../../benchmark/data/face_detection',
                      transforms=Compose([Resize(size=(160, 120))])),
    'sface': Quantize(model_path='../../models/face_recognition_sface/face_recognition_sface_2021dec.onnx',
                      calibration_image_dir='../../benchmark/data/face_recognition',
                      transforms=Compose([Resize(size=(112, 112))])),
    'pphumenseg': Quantize(model_path='../../models/human_segmentation_pphumanseg/human_segmentation_pphumanseg_2021oct.onnx',
                           calibration_image_dir='../../benchmark/data/human_segmentation',
                           transforms=Compose([Resize(size=(192, 192))])),
    'ppresnet50': Quantize(model_path='../../models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx',
                           calibration_image_dir='../../benchmark/data/image_classification',
                           transforms=Compose([Resize(size=(224, 224))])),
    # TBD: DaSiamRPN
    'youtureid': Quantize(model_path='../../models/person_reid_youtureid/person_reid_youtu_2021nov.onnx',
                          calibration_image_dir='../../benchmark/data/person_reid',
                          transforms=Compose([Resize(size=(128, 256))])),
    # TBD: DB-EN & DB-CN
    'crnn_en': Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_EN_2021sep.onnx',
                        calibration_image_dir='../../benchmark/data/text',
                        transforms=Compose([Resize(size=(100, 32)), ColorConvert(ctype=cv.COLOR_BGR2GRAY)])),
    'crnn_cn': Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_CN_2021nov.onnx',
                        calibration_image_dir='../../benchmark/data/text',
                        transforms=Compose([Resize(size=(100, 32))])),
}
105 
if __name__ == '__main__':
    # Model names may be passed on the command line; with no arguments,
    # quantize every registered model.
    selected_models = sys.argv[1:]
    if not selected_models:
        selected_models = list(models.keys())
    print('Models to be quantized: {}'.format(str(selected_models)))

    for selected_model_name in selected_models:
        if selected_model_name not in models:
            # Report unknown names instead of dying on a bare KeyError.
            print('Unknown model: {}. Available: {}'.format(
                selected_model_name, ', '.join(models.keys())))
            continue
        q = models[selected_model_name]
        q.run()
117 
quantize-ort.DataReader.get_calibration_data
def get_calibration_data(self, image_dir)
Definition: quantize-ort.py:30
demo.str
str
Definition: demo.py:35
quantize-ort.Quantize
Definition: quantize-ort.py:42
quantize-ort.Quantize.run
def run(self)
Definition: quantize-ort.py:68
transform.ColorConvert
Definition: transform.py:54
transform.Resize
Definition: transform.py:20
quantize-ort.Quantize.transforms
transforms
Definition: quantize-ort.py:83
transform.Compose
Definition: transform.py:11
quantize-ort.DataReader.enum_data_dicts
enum_data_dicts
Definition: quantize-ort.py:25
quantize-ort.Quantize.model_path
model_path
Definition: quantize-ort.py:81
quantize-ort.DataReader.__init__
def __init__(self, model_path, image_dir, transforms)
Definition: quantize-ort.py:20
quantize-ort.Quantize.act_type
act_type
Definition: quantize-ort.py:50
quantize-ort.DataReader.input_name
input_name
Definition: quantize-ort.py:22
quantize-ort.Quantize.wt_type
wt_type
Definition: quantize-ort.py:51
quantize-ort.Quantize.calibration_image_dir
string calibration_image_dir
Definition: quantize-ort.py:82
quantize-ort.DataReader
Definition: quantize-ort.py:19
quantize-ort.DataReader.get_next
def get_next(self)
Definition: quantize-ort.py:27
quantize-ort.Quantize.type_dict
type_dict
Definition: quantize-ort.py:44
quantize-ort.Quantize.check_opset
def check_opset(self, convert=True)
Definition: quantize-ort.py:56
quantize-ort.DataReader.transforms
transforms
Definition: quantize-ort.py:23
quantize-ort.Quantize.dr
dr
Definition: quantize-ort.py:54
quantize-ort.Quantize.__init__
def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8')
Definition: quantize-ort.py:43
quantize-ort.DataReader.data
data
Definition: quantize-ort.py:24
quantize-ort.Quantize.per_channel
per_channel
Definition: quantize-ort.py:49