import argparse

import cv2 as cv

from dasiamrpn import DaSiamRPN
class _Str2BoolError(argparse.ArgumentTypeError, NotImplementedError):
    """Raised by ``str2bool`` for a value it does not recognise.

    Deriving from ``argparse.ArgumentTypeError`` makes argparse report a clean
    usage error (exit status 2) instead of a traceback. Deriving from
    ``NotImplementedError`` keeps working any handler written for the
    exception type previously raised here.
    """


# Accepted spellings, compared case-insensitively.
_TRUE_STRINGS = frozenset(['on', 'yes', 'true', 'y', 't'])
_FALSE_STRINGS = frozenset(['off', 'no', 'false', 'n', 'f'])


def str2bool(v):
    """Convert a command-line flag value to ``bool``.

    Intended as an argparse ``type=`` converter.
    ``on/yes/true/y/t`` map to ``True`` and ``off/no/false/n/f`` map to
    ``False``, ignoring case. A value that is already a ``bool`` is returned
    unchanged.

    Args:
        v: The raw string from the command line, or a ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        _Str2BoolError: if ``v`` is not one of the accepted spellings. It is
            both an ``argparse.ArgumentTypeError`` and a ``NotImplementedError``.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise _Str2BoolError(
        'expected a boolean value (on/off, yes/no, true/false, y/n, t/f), got {!r}'.format(v))
# Command-line interface. The arguments are parsed when the module is loaded,
# and the resulting ``args`` is read by this script's ``__main__`` section.
parser = argparse.ArgumentParser(
    description="Distractor-aware Siamese Networks for Visual Object Tracking (https://arxiv.org/abs/1808.06048)")
parser.add_argument('--input', '-i', type=str,
                    help='Path to the input video. Omit for using default camera.')
parser.add_argument('--model_path', type=str,
                    default='object_tracking_dasiamrpn_model_2021nov.onnx',
                    help='Path to dasiamrpn_model.onnx.')
parser.add_argument('--kernel_cls1_path', type=str,
                    default='object_tracking_dasiamrpn_kernel_cls1_2021nov.onnx',
                    help='Path to dasiamrpn_kernel_cls1.onnx.')
parser.add_argument('--kernel_r1_path', type=str,
                    default='object_tracking_dasiamrpn_kernel_r1_2021nov.onnx',
                    help='Path to dasiamrpn_kernel_r1.onnx.')
# Boolean flags take a value (e.g. ``--save yes``), converted by str2bool.
parser.add_argument('--save', '-s', type=str2bool, default=False,
                    help='Set true to save results. This flag is invalid when using camera.')
parser.add_argument('--vis', '-v', type=str2bool, default=True,
                    help='Set true to open a window for result visualization. This flag is invalid when using camera.')
args = parser.parse_args()
def visualize(image, bbox, score, isLocated, fps=None, box_color=(0, 255, 0), text_color=(0, 255, 0), fontScale=1, fontSize=1, score_threshold=0.6):
    """Draw one frame's tracking result on a copy of ``image``.

    Args:
        image: Frame to annotate. Presumably a NumPy image as returned by
            ``cv.VideoCapture.read()``; 2-D (grayscale) and 3-D arrays are both
            accepted. It is not modified.
        bbox: ``(x, y, width, height)`` of the tracked target. Values are
            cast to ``int`` before drawing.
        score: Tracker confidence, printed with two decimals next to the box.
        isLocated: Whether the tracker reports the target as found.
        fps: If not ``None``, drawn as ``FPS: <value>`` in the top-left corner.
        box_color: BGR colour of the bounding box.
        text_color: BGR colour of the FPS and score text.
        fontScale: ``fontScale`` passed to ``cv.putText`` / ``cv.getTextSize``.
        fontSize: Text thickness passed to ``cv.putText`` / ``cv.getTextSize``.
        score_threshold: Minimum score at which the box is drawn. Below it,
            "Target lost!" is shown instead (default 0.6).

    Returns:
        The annotated copy of ``image``.
    """
    output = image.copy()
    # shape[:2] works for both HxW and HxWxC images.
    img_h, img_w = output.shape[:2]

    if fps is not None:
        cv.putText(output, 'FPS: {:.2f}'.format(fps), (0, 30), cv.FONT_HERSHEY_DUPLEX, fontScale, text_color, fontSize)

    if isLocated and score >= score_threshold:
        # Separate names so the image size is not shadowed by the box size.
        x, y, box_w, box_h = (int(v) for v in bbox)
        cv.rectangle(output, (x, y), (x + box_w, y + box_h), box_color, 2)
        cv.putText(output, '{:.2f}'.format(score), (x, y + 20), cv.FONT_HERSHEY_DUPLEX, fontScale, text_color, fontSize)
    else:
        (text_w, text_h), _baseline = cv.getTextSize('Target lost!', cv.FONT_HERSHEY_DUPLEX, fontScale, fontSize)
        text_x = (img_w - text_w) // 2
        # putText's origin is the text's bottom-left corner, so the baseline
        # must sit half the text height below the image centre.
        text_y = (img_h + text_h) // 2
        cv.putText(output, 'Target lost!', (text_x, text_y), cv.FONT_HERSHEY_DUPLEX, fontScale, (0, 0, 255), fontSize)

    return output
if __name__ == '__main__':
    window_name = 'DaSiamRPN Demo'

    # Build the tracker from the three ONNX files given on the command line.
    model = DaSiamRPN(
        model_path=args.model_path,
        kernel_cls1_path=args.kernel_cls1_path,
        kernel_r1_path=args.kernel_r1_path
    )

    # Open the given video file, or camera device 0 when --input is omitted.
    _input = 0 if args.input is None else args.input
    video = cv.VideoCapture(_input)

    try:
        # Let the user select the target on the first frame.
        has_frame, first_frame = video.read()
        if not has_frame:
            print('No frames grabbed!')
            raise SystemExit(1)
        first_frame_copy = first_frame.copy()
        cv.putText(first_frame_copy, "1. Drag a bounding box to track.", (0, 15), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0))
        cv.putText(first_frame_copy, "2. Press ENTER to confirm", (0, 35), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0))
        roi = cv.selectROI(window_name, first_frame_copy)
        print("Selected ROI: {}".format(roi))
        # selectROI yields a zero-size rectangle when the selection is cancelled.
        if roi[2] <= 0 or roi[3] <= 0:
            print('No valid ROI selected!')
            raise SystemExit(1)

        # Initialise the tracker with the selected ROI.
        model.init(first_frame, roi)

        # Track frame by frame until a key is pressed or the video ends.
        # NOTE(review): --save and --vis are parsed but not consulted here.
        tm = cv.TickMeter()
        while cv.waitKey(1) < 0:
            has_frame, frame = video.read()
            if not has_frame:
                break
            tm.start()
            isLocated, bbox, score = model.infer(frame)
            tm.stop()
            frame = visualize(frame, bbox, score, isLocated, fps=tm.getFPS())
            cv.imshow(window_name, frame)
            # Reset so the displayed FPS reflects only the latest inference.
            tm.reset()
    finally:
        # Always release the capture device and close the windows, even on early exit.
        video.release()
        cv.destroyAllWindows()