# Export this model to ONNX and split the exported graph into three ONNX files
# written to output_dir, each named "{prefix}-jevois-{img_width}x{img_height}-{num_classes}c...":
#   *-txt.onnx : text sub-model, input "txt_feats" -> five intermediate tensors
#   *-img.onnx : image sub-model, inputs "images" + those five tensors -> six
#                /model.22 head tensors; its "/model.22/cv4.0/Div_output_0"
#                input shape is afterwards fixed to 1,num_classes,512
#   *.onnx     : combined model, inputs "images" + "txt_feats" -> the same six tensors
# All three are cut out of "tmp_fullmodel.onnx" (a path relative to the current
# working directory), which is deleted at the end.
#
# Parameters:
#   output_dir  : directory the three .onnx files are written to
#   model_name  : only the text before the first '-' is used, as the filename prefix
#   img_width, img_height : spatial size of the dummy image input (also used in filenames)
#   num_classes : dim 1 of the dummy txt_feats tensor (also used in filenames and in
#                 the fixed input shape of the image sub-model)
#
# NOTE(review): this listing omits several original lines (e.g. 36-38, and 40-41 /
# 43-44 inside the torch.onnx.export call), so the full export arguments are not
# visible here. Presumably they include (x, txt_feats) and "tmp_fullmodel.onnx",
# since every extraction below reads that file — confirm against the full source.
31 def export(self, output_dir, model_name, img_width, img_height, num_classes):
32 print(f
"JEVOIS: exporting {model_name} to ONNX {img_width}x{img_height}-{num_classes}c")
# Dummy tracing inputs on self.device: an image tensor of shape
# (1, 3, img_height, img_width) and txt_feats of shape (1, num_classes, 512).
# NOTE(review): 512 is presumably the text-embedding width the model expects — confirm.
33 x = torch.randn(1, 3, img_height, img_width, requires_grad=
False).to(self.
device)
34 txt_feats = torch.randn(1, num_classes, 512, requires_grad=
False).to(self.
device)
# Filename prefix: the part of model_name before the first '-'.
35 prefix = model_name.split(
'-')[0]
# Export the whole model with constant folding; graph inputs are named
# "images" and "txt_feats", and the graph output is named "output".
# (Some arguments of this call are missing from this listing — see header note.)
39 torch.onnx.export(self,
42 do_constant_folding=
True,
45 input_names=[
"images",
"txt_feats"],
46 output_names=[
"output"])
# 1) Text sub-model: from the "txt_feats" input up to the four attention
# Transpose outputs plus the cv4.0 Div output.
# NOTE(review): these are internal node names of the exported graph; they
# presumably change with the model architecture or exporter version — verify.
51 innames = [
"txt_feats"]
52 outnames = [
"/model.12/attn/Transpose_1_output_0",
53 "/model.15/attn/Transpose_1_output_0",
54 "/model.18/attn/Transpose_1_output_0",
55 "/model.21/attn/Transpose_1_output_0",
56 "/model.22/cv4.0/Div_output_0"]
57 outpath = os.path.join(output_dir, f
"{prefix}-jevois-{img_width}x{img_height}-{num_classes}c-txt.onnx")
58 print(f
" JEVOIS: extracting text model to {outpath}")
59 onnx.utils.extract_model(
"tmp_fullmodel.onnx", outpath, innames, outnames)
# 2) Image sub-model: inputs are "images" plus the five tensors the text
# sub-model outputs. Its outputs are the six /model.22 head tensors
# (cv2.i conv output and cv4.i Add output, for i = 0..2).
63 innames = [
"images"] + outnames
64 outnames = [
"/model.22/cv2.0/cv2.0.2/Conv_output_0",
"/model.22/cv4.0/Add_output_0",
65 "/model.22/cv2.1/cv2.1.2/Conv_output_0",
"/model.22/cv4.1/Add_output_0",
66 "/model.22/cv2.2/cv2.2.2/Conv_output_0",
"/model.22/cv4.2/Add_output_0"]
67 outpath = os.path.join(output_dir, f
"{prefix}-jevois-{img_width}x{img_height}-{num_classes}c-img.onnx")
68 print(f
" JEVOIS: extracting image model to {outpath}")
69 onnx.utils.extract_model(
"tmp_fullmodel.onnx", outpath, innames, outnames)
# Fix the shape of the image model's "/model.22/cv4.0/Div_output_0" input to
# (1, num_classes, 512) with onnxruntime's make_dynamic_shape_fixed tool.
# The file at outpath is rewritten in place (the same path is both input and output).
# NOTE(review): "python" is looked up on PATH and may not be the interpreter that
# is running this code; sys.executable would be safer. Without check=True, a
# non-zero exit of the tool is silently ignored. Only this one input is fixed; the
# four Transpose inputs are presumably already static — confirm.
72 subprocess.run([
"python",
"-m",
"onnxruntime.tools.make_dynamic_shape_fixed",
73 "--input_name",
"/model.22/cv4.0/Div_output_0",
74 "--input_shape", f
"1,{num_classes},512", outpath, outpath])
# 3) Combined model: from "images" and "txt_feats" directly to the same six
# head tensors (outnames is reused from the image sub-model above).
78 innames = [
"images",
"txt_feats"]
79 outpath = os.path.join(output_dir, f
"{prefix}-jevois-{img_width}x{img_height}-{num_classes}c.onnx")
80 print(f
" JEVOIS: extracting combo model to {outpath}")
81 onnx.utils.extract_model(
"tmp_fullmodel.onnx", outpath, innames, outnames)
# Delete the temporary full export.
# NOTE(review): no try/finally is visible here, so an exception in any step above
# leaves tmp_fullmodel.onnx behind in the current working directory.
83 os.remove(
"tmp_fullmodel.onnx")