JeVoisBase  1.23
JeVois Smart Embedded Machine Vision Toolkit Base Modules
yolo-jevois-export.py
#!/usr/bin/env python3

# Adapted from https://github.com/ibaiGorordo/ONNX-YOLO-World-Open-Vocabulary-Object-Detection
# MIT License

from copy import deepcopy
import torch
import onnx
import os
from argparse import ArgumentParser
from ultralytics import YOLOWorld
from torchviz import make_dot
import subprocess

class ModelExporter(torch.nn.Module):
    def __init__(self, yoloModel, device='cpu'):
        super(ModelExporter, self).__init__()
        model = deepcopy(yoloModel).to(device)
        for p in model.parameters():
            p.requires_grad = False
        model.eval()
        model.float()
        model = model.fuse()

        self.model = model
        self.device = device

    def forward(self, x, txt_feats):
        return self.model.predict(x, txt_feats=txt_feats)

    def export(self, output_dir, model_name, img_width, img_height, num_classes):
        print(f"JEVOIS: exporting {model_name} to ONNX {img_width}x{img_height}-{num_classes}c")
        x = torch.randn(1, 3, img_height, img_width, requires_grad=False).to(self.device)
        txt_feats = torch.randn(1, num_classes, 512, requires_grad=False).to(self.device)
        prefix = model_name.split('-')[0] # yolov8s, yolov8l, etc

        # First export full model:
        with torch.no_grad():
            torch.onnx.export(self,
                              (x, txt_feats),
                              "tmp_fullmodel.onnx",
                              do_constant_folding=True,
                              opset_version=17,
                              verbose=False,
                              input_names=["images", "txt_feats"],
                              output_names=["output"])

        # JEVOIS: approach 1 (for 8-bit quantization): extract the text-processing portion, to be run
        # only when class names are changed. This network computes 5 tensors from the CLIP embeddings, which will be
        # multiplied with vision tensors as the model runs:
        innames = ["txt_feats"]
        outnames = ["/model.12/attn/Transpose_1_output_0",
                    "/model.15/attn/Transpose_1_output_0",
                    "/model.18/attn/Transpose_1_output_0",
                    "/model.21/attn/Transpose_1_output_0",
                    "/model.22/cv4.0/Div_output_0"]
        outpath = os.path.join(output_dir, f"{prefix}-jevois-{img_width}x{img_height}-{num_classes}c-txt.onnx")
        print(f" JEVOIS: extracting text model to {outpath}")
        onnx.utils.extract_model("tmp_fullmodel.onnx", outpath, innames, outnames)

        # JEVOIS: approach 1: Then extract the vision-processing portion (to be quantized), taking the 5 tensors from
        # text processing as inputs, in addition to the input image:
        innames = ["images"] + outnames
        outnames = ["/model.22/cv2.0/cv2.0.2/Conv_output_0", "/model.22/cv4.0/Add_output_0",
                    "/model.22/cv2.1/cv2.1.2/Conv_output_0", "/model.22/cv4.1/Add_output_0",
                    "/model.22/cv2.2/cv2.2.2/Conv_output_0", "/model.22/cv4.2/Add_output_0"]
        outpath = os.path.join(output_dir, f"{prefix}-jevois-{img_width}x{img_height}-{num_classes}c-img.onnx")
        print(f" JEVOIS: extracting image model to {outpath}")
        onnx.utils.extract_model("tmp_fullmodel.onnx", outpath, innames, outnames)

        # The Div input has variable size; fix it to 1xCx512:
        subprocess.run(["python", "-m", "onnxruntime.tools.make_dynamic_shape_fixed",
                        "--input_name", "/model.22/cv4.0/Div_output_0",
                        "--input_shape", f"1,{num_classes},512", outpath, outpath])

        # JEVOIS: approach 2 (for 16-bit quantization, slower at runtime): only truncate the model to yield split
        # outputs (6 tensors for boxes and class scores at 3 strides), taking image and CLIP embeddings as input:
        innames = ["images", "txt_feats"]
        outpath = os.path.join(output_dir, f"{prefix}-jevois-{img_width}x{img_height}-{num_classes}c.onnx")
        print(f" JEVOIS: extracting combo model to {outpath}")
        onnx.utils.extract_model("tmp_fullmodel.onnx", outpath, innames, outnames)

        os.remove("tmp_fullmodel.onnx")

def main():
    parser = ArgumentParser()
    parser.add_argument("--img_width", type=int, default=512)
    parser.add_argument("--img_height", type=int, default=288)
    parser.add_argument("--num_classes", type=int, default=-1)
    parser.add_argument("--model_name", type=str, default="yolov8s-worldv2.pt")
    parser.add_argument("--output_dir", type=str, default="")
    parser.add_argument("--device", type=str, default=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    args = parser.parse_args()
    img_width = args.img_width
    img_height = args.img_height
    num_classes = args.num_classes
    model_name = args.model_name
    output_dir = args.output_dir
    device = args.device

    if num_classes > 0:
        nclass = [num_classes]
    else:
        nclass = [1, 8, 16, 32, 64]

    for nc in nclass:
        yoloModel = YOLOWorld(model_name)
        yoloModel.set_classes(["person"] * nc)

        #print(yoloModel)

        export_model = ModelExporter(yoloModel.model, device)
        export_model.export(output_dir, model_name, img_width, img_height, nc)

if __name__ == "__main__":
    main()
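Usage note: the command below is an illustrative invocation (the weights file and resolution are the script defaults; the class count of 16 is just an example). With --num_classes left at its default of -1, the script instead exports variants for 1, 8, 16, 32, and 64 classes.

    python3 yolo-jevois-export.py --model_name yolov8s-worldv2.pt --img_width 512 --img_height 288 --num_classes 16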
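To show how the two files from approach 1 are meant to chain together, here is a minimal sanity-check sketch using onnxruntime. It assumes a 512x288, 16-class export of yolov8s-worldv2.pt (hence the hypothetical file names) and feeds random data in place of real image pixels and CLIP text embeddings; it is not part of the export script.

    import numpy as np
    import onnxruntime as ort

    # Hypothetical file names produced by a 512x288, 16-class export of yolov8s-worldv2.pt:
    txt_sess = ort.InferenceSession("yolov8s-jevois-512x288-16c-txt.onnx")
    img_sess = ort.InferenceSession("yolov8s-jevois-512x288-16c-img.onnx")

    # Text model: run only when class names change; input is 1 x num_classes x 512 CLIP
    # embeddings (random here, standing in for real ones). It yields the 5 tensors
    # listed in export():
    txt_feats = np.random.randn(1, 16, 512).astype(np.float32)
    txt_outs = txt_sess.run(None, {"txt_feats": txt_feats})

    # Vision model: runs on every frame; takes the image plus the 5 text tensors as inputs
    # (the text model's output names match the vision model's input names by construction):
    feed = {"images": np.random.randn(1, 3, 288, 512).astype(np.float32)}
    for out_info, val in zip(txt_sess.get_outputs(), txt_outs):
        feed[out_info.name] = val

    img_outs = img_sess.run(None, feed)
    print([o.shape for o in img_outs])  # 6 outputs: box and class-score tensors at 3 strides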