mirror of https://github.com/open-mmlab/mmpose
303 lines
9.6 KiB
Python
303 lines
9.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import logging
|
|
import mimetypes
|
|
import os
|
|
import time
|
|
from argparse import ArgumentParser
|
|
|
|
import cv2
|
|
import json_tricks as json
|
|
import mmcv
|
|
import mmengine
|
|
import numpy as np
|
|
from mmengine.logging import print_log
|
|
|
|
from mmpose.apis import inference_topdown
|
|
from mmpose.apis import init_model as init_pose_estimator
|
|
from mmpose.evaluation.functional import nms
|
|
from mmpose.registry import VISUALIZERS
|
|
from mmpose.structures import merge_data_samples, split_instances
|
|
from mmpose.utils import adapt_mmdet_pipeline
|
|
|
|
try:
|
|
from mmdet.apis import inference_detector, init_detector
|
|
has_mmdet = True
|
|
except (ImportError, ModuleNotFoundError):
|
|
has_mmdet = False
|
|
|
|
|
|
def process_one_image(args,
|
|
img,
|
|
detector,
|
|
pose_estimator,
|
|
visualizer=None,
|
|
show_interval=0):
|
|
"""Visualize predicted keypoints (and heatmaps) of one image."""
|
|
|
|
# predict bbox
|
|
det_result = inference_detector(detector, img)
|
|
pred_instance = det_result.pred_instances.cpu().numpy()
|
|
bboxes = np.concatenate(
|
|
(pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
|
|
bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
|
|
pred_instance.scores > args.bbox_thr)]
|
|
bboxes = bboxes[nms(bboxes, args.nms_thr), :4]
|
|
|
|
# predict keypoints
|
|
pose_results = inference_topdown(pose_estimator, img, bboxes)
|
|
data_samples = merge_data_samples(pose_results)
|
|
|
|
# show the results
|
|
if isinstance(img, str):
|
|
img = mmcv.imread(img, channel_order='rgb')
|
|
elif isinstance(img, np.ndarray):
|
|
img = mmcv.bgr2rgb(img)
|
|
|
|
if visualizer is not None:
|
|
visualizer.add_datasample(
|
|
'result',
|
|
img,
|
|
data_sample=data_samples,
|
|
draw_gt=False,
|
|
draw_heatmap=args.draw_heatmap,
|
|
draw_bbox=args.draw_bbox,
|
|
show_kpt_idx=args.show_kpt_idx,
|
|
skeleton_style=args.skeleton_style,
|
|
show=args.show,
|
|
wait_time=show_interval,
|
|
kpt_thr=args.kpt_thr)
|
|
|
|
# if there is no instance detected, return None
|
|
return data_samples.get('pred_instances', None)
|
|
|
|
|
|
def main():
|
|
"""Visualize the demo images.
|
|
|
|
Using mmdet to detect the human.
|
|
"""
|
|
parser = ArgumentParser()
|
|
parser.add_argument('det_config', help='Config file for detection')
|
|
parser.add_argument('det_checkpoint', help='Checkpoint file for detection')
|
|
parser.add_argument('pose_config', help='Config file for pose')
|
|
parser.add_argument('pose_checkpoint', help='Checkpoint file for pose')
|
|
parser.add_argument(
|
|
'--input', type=str, default='', help='Image/Video file')
|
|
parser.add_argument(
|
|
'--show',
|
|
action='store_true',
|
|
default=False,
|
|
help='whether to show img')
|
|
parser.add_argument(
|
|
'--output-root',
|
|
type=str,
|
|
default='',
|
|
help='root of the output img file. '
|
|
'Default not saving the visualization images.')
|
|
parser.add_argument(
|
|
'--save-predictions',
|
|
action='store_true',
|
|
default=False,
|
|
help='whether to save predicted results')
|
|
parser.add_argument(
|
|
'--device', default='cuda:0', help='Device used for inference')
|
|
parser.add_argument(
|
|
'--det-cat-id',
|
|
type=int,
|
|
default=0,
|
|
help='Category id for bounding box detection model')
|
|
parser.add_argument(
|
|
'--bbox-thr',
|
|
type=float,
|
|
default=0.3,
|
|
help='Bounding box score threshold')
|
|
parser.add_argument(
|
|
'--nms-thr',
|
|
type=float,
|
|
default=0.3,
|
|
help='IoU threshold for bounding box NMS')
|
|
parser.add_argument(
|
|
'--kpt-thr',
|
|
type=float,
|
|
default=0.3,
|
|
help='Visualizing keypoint thresholds')
|
|
parser.add_argument(
|
|
'--draw-heatmap',
|
|
action='store_true',
|
|
default=False,
|
|
help='Draw heatmap predicted by the model')
|
|
parser.add_argument(
|
|
'--show-kpt-idx',
|
|
action='store_true',
|
|
default=False,
|
|
help='Whether to show the index of keypoints')
|
|
parser.add_argument(
|
|
'--skeleton-style',
|
|
default='mmpose',
|
|
type=str,
|
|
choices=['mmpose', 'openpose'],
|
|
help='Skeleton style selection')
|
|
parser.add_argument(
|
|
'--radius',
|
|
type=int,
|
|
default=3,
|
|
help='Keypoint radius for visualization')
|
|
parser.add_argument(
|
|
'--thickness',
|
|
type=int,
|
|
default=1,
|
|
help='Link thickness for visualization')
|
|
parser.add_argument(
|
|
'--show-interval', type=int, default=0, help='Sleep seconds per frame')
|
|
parser.add_argument(
|
|
'--alpha', type=float, default=0.8, help='The transparency of bboxes')
|
|
parser.add_argument(
|
|
'--draw-bbox', action='store_true', help='Draw bboxes of instances')
|
|
|
|
assert has_mmdet, 'Please install mmdet to run the demo.'
|
|
|
|
args = parser.parse_args()
|
|
|
|
assert args.show or (args.output_root != '')
|
|
assert args.input != ''
|
|
assert args.det_config is not None
|
|
assert args.det_checkpoint is not None
|
|
|
|
output_file = None
|
|
if args.output_root:
|
|
mmengine.mkdir_or_exist(args.output_root)
|
|
output_file = os.path.join(args.output_root,
|
|
os.path.basename(args.input))
|
|
if args.input == 'webcam':
|
|
output_file += '.mp4'
|
|
|
|
if args.save_predictions:
|
|
assert args.output_root != ''
|
|
args.pred_save_path = f'{args.output_root}/results_' \
|
|
f'{os.path.splitext(os.path.basename(args.input))[0]}.json'
|
|
|
|
# build detector
|
|
detector = init_detector(
|
|
args.det_config, args.det_checkpoint, device=args.device)
|
|
detector.cfg = adapt_mmdet_pipeline(detector.cfg)
|
|
|
|
# build pose estimator
|
|
pose_estimator = init_pose_estimator(
|
|
args.pose_config,
|
|
args.pose_checkpoint,
|
|
device=args.device,
|
|
cfg_options=dict(
|
|
model=dict(test_cfg=dict(output_heatmaps=args.draw_heatmap))))
|
|
|
|
# build visualizer
|
|
pose_estimator.cfg.visualizer.radius = args.radius
|
|
pose_estimator.cfg.visualizer.alpha = args.alpha
|
|
pose_estimator.cfg.visualizer.line_width = args.thickness
|
|
visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
|
|
# the dataset_meta is loaded from the checkpoint and
|
|
# then pass to the model in init_pose_estimator
|
|
visualizer.set_dataset_meta(
|
|
pose_estimator.dataset_meta, skeleton_style=args.skeleton_style)
|
|
|
|
if args.input == 'webcam':
|
|
input_type = 'webcam'
|
|
else:
|
|
input_type = mimetypes.guess_type(args.input)[0].split('/')[0]
|
|
|
|
if input_type == 'image':
|
|
|
|
# inference
|
|
pred_instances = process_one_image(args, args.input, detector,
|
|
pose_estimator, visualizer)
|
|
|
|
if args.save_predictions:
|
|
pred_instances_list = split_instances(pred_instances)
|
|
|
|
if output_file:
|
|
img_vis = visualizer.get_image()
|
|
mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)
|
|
|
|
elif input_type in ['webcam', 'video']:
|
|
|
|
if args.input == 'webcam':
|
|
cap = cv2.VideoCapture(0)
|
|
else:
|
|
cap = cv2.VideoCapture(args.input)
|
|
|
|
video_writer = None
|
|
pred_instances_list = []
|
|
frame_idx = 0
|
|
|
|
while cap.isOpened():
|
|
success, frame = cap.read()
|
|
frame_idx += 1
|
|
|
|
if not success:
|
|
break
|
|
|
|
# topdown pose estimation
|
|
pred_instances = process_one_image(args, frame, detector,
|
|
pose_estimator, visualizer,
|
|
0.001)
|
|
|
|
if args.save_predictions:
|
|
# save prediction results
|
|
pred_instances_list.append(
|
|
dict(
|
|
frame_id=frame_idx,
|
|
instances=split_instances(pred_instances)))
|
|
|
|
# output videos
|
|
if output_file:
|
|
frame_vis = visualizer.get_image()
|
|
|
|
if video_writer is None:
|
|
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
|
# the size of the image with visualization may vary
|
|
# depending on the presence of heatmaps
|
|
video_writer = cv2.VideoWriter(
|
|
output_file,
|
|
fourcc,
|
|
25, # saved fps
|
|
(frame_vis.shape[1], frame_vis.shape[0]))
|
|
|
|
video_writer.write(mmcv.rgb2bgr(frame_vis))
|
|
|
|
if args.show:
|
|
# press ESC to exit
|
|
if cv2.waitKey(5) & 0xFF == 27:
|
|
break
|
|
|
|
time.sleep(args.show_interval)
|
|
|
|
if video_writer:
|
|
video_writer.release()
|
|
|
|
cap.release()
|
|
|
|
else:
|
|
args.save_predictions = False
|
|
raise ValueError(
|
|
f'file {os.path.basename(args.input)} has invalid format.')
|
|
|
|
if args.save_predictions:
|
|
with open(args.pred_save_path, 'w') as f:
|
|
json.dump(
|
|
dict(
|
|
meta_info=pose_estimator.dataset_meta,
|
|
instance_info=pred_instances_list),
|
|
f,
|
|
indent='\t')
|
|
print(f'predictions have been saved at {args.pred_save_path}')
|
|
|
|
if output_file:
|
|
input_type = input_type.replace('webcam', 'video')
|
|
print_log(
|
|
f'the output {input_type} has been saved at {output_file}',
|
|
logger='current',
|
|
level=logging.INFO)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|