{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d999c38-2087-4250-b6a4-a30cf8b44ec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import os.path as osp\n",
    "import torch\n",
    "import numpy as np\n",
    "import mmcv\n",
    "import cv2\n",
    "from mmengine.utils import track_iter_progress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfa9bf9b-dc2c-4803-a034-8ae8778113e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# download example videos\n",
    "from mmengine.utils import mkdir_or_exist\n",
    "mkdir_or_exist('resources')\n",
    "! wget -O resources/student_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tom.mp4\n",
    "! wget -O resources/teacher_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/idol_producer.mp4\n",
    "# ! wget -O resources/student_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tsinghua_30fps.mp4\n",
    "\n",
    "student_video = 'resources/student_video.mp4'\n",
    "teacher_video = 'resources/teacher_video.mp4'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "652b6b91-e1c0-461b-90e5-653bc35ec380",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert the fps of both videos to 30\n",
    "from mmcv import VideoReader\n",
    "\n",
    "if VideoReader(student_video).fps != 30:\n",
    "    # ffmpeg is required to convert the video fps;\n",
    "    # it can be installed via `sudo apt install ffmpeg` on Ubuntu\n",
    "    student_video_30fps = student_video.replace(\n",
    "        f\".{student_video.rsplit('.', 1)[1]}\",\n",
    "        f\"_30fps.{student_video.rsplit('.', 1)[1]}\"\n",
    "    )\n",
    "    !ffmpeg -i {student_video} -vf \"minterpolate='fps=30'\" {student_video_30fps}\n",
    "    student_video = student_video_30fps\n",
    "\n",
    "if VideoReader(teacher_video).fps != 30:\n",
    "    teacher_video_30fps = teacher_video.replace(\n",
    "        f\".{teacher_video.rsplit('.', 1)[1]}\",\n",
    "        f\"_30fps.{teacher_video.rsplit('.', 1)[1]}\"\n",
    "    )\n",
    "    !ffmpeg -i {teacher_video} -vf \"minterpolate='fps=30'\" {teacher_video_30fps}\n",
    "    teacher_video = teacher_video_30fps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a4e141d-ee4a-4e06-a380-230418c9b936",
   "metadata": {},
   "outputs": [],
   "source": [
    "# init pose estimator\n",
    "from mmpose.apis.inferencers import Pose2DInferencer\n",
    "pose_estimator = Pose2DInferencer(\n",
    "    'rtmpose-t_8xb256-420e_aic-coco-256x192',\n",
    "    det_model='configs/rtmdet-nano_one-person.py',\n",
    "    det_weights='https://download.openmmlab.com/mmpose/v1/projects/'\n",
    "    'rtmpose/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth'\n",
    ")\n",
    "# disable flip-test augmentation to speed up inference\n",
    "pose_estimator.model.test_cfg['flip_test'] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "879ba5c0-4d2d-4cca-92d7-d4f94e04a821",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def get_keypoints_from_frame(image, pose_estimator):\n",
    "    \"\"\"Extract keypoints from a single video frame.\"\"\"\n",
    "\n",
    "    det_results = pose_estimator.detector(\n",
    "        image, return_datasample=True)['predictions']\n",
    "    pred_instance = det_results[0].pred_instances.numpy()\n",
    "\n",
    "    if len(pred_instance) == 0 or pred_instance.scores[0] < 0.2:\n",
    "        return np.zeros((1, 17, 3), dtype=np.float32)\n",
    "\n",
    "    data_info = dict(\n",
    "        img=image,\n",
    "        bbox=pred_instance.bboxes[:1],\n",
    "        bbox_score=pred_instance.scores[:1])\n",
    "\n",
    "    data_info.update(pose_estimator.model.dataset_meta)\n",
    "    data = pose_estimator.collate_fn(\n",
    "        [pose_estimator.pipeline(data_info)])\n",
    "\n",
    "    # custom forward\n",
    "    data = pose_estimator.model.data_preprocessor(data, False)\n",
    "    feats = pose_estimator.model.extract_feat(data['inputs'])\n",
    "    pred_instances = pose_estimator.model.head.predict(\n",
    "        feats,\n",
    "        data['data_samples'],\n",
    "        test_cfg=pose_estimator.model.test_cfg)[0]\n",
    "    keypoints = np.concatenate(\n",
    "        (pred_instances.keypoints, pred_instances.keypoint_scores[..., None]),\n",
    "        axis=-1)\n",
    "\n",
    "    return keypoints"
   ]
  },
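  {
   "cell_type": "markdown",
   "id": "kpt-shape-check-md",
   "metadata": {},
   "source": [
    "As a quick sanity check (an illustrative sketch, not part of the original pipeline; `first_frame` and `example_kpts` are names introduced here), the extractor should return an array of shape `(1, 17, 3)` for any frame: one person, 17 COCO keypoints, `(x, y, score)` per keypoint, and all zeros when no confident detection is found."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kpt-shape-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check (sketch): extract keypoints from the first student frame\n",
    "first_frame = VideoReader(student_video)[0]\n",
    "example_kpts = get_keypoints_from_frame(first_frame, pose_estimator)\n",
    "print(example_kpts.shape)  # expected: (1, 17, 3)"
   ]
  },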
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31e5bd4c-4c2b-4fe0-b64c-1afed67b7688",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run pose estimation over both videos\n",
    "student_poses, teacher_poses = [], []\n",
    "for frame in VideoReader(student_video):\n",
    "    student_poses.append(get_keypoints_from_frame(frame, pose_estimator))\n",
    "for frame in VideoReader(teacher_video):\n",
    "    teacher_poses.append(get_keypoints_from_frame(frame, pose_estimator))\n",
    "\n",
    "student_poses = np.concatenate(student_poses)\n",
    "teacher_poses = np.concatenate(teacher_poses)"
   ]
  },
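  {
   "cell_type": "markdown",
   "id": "pose-shape-note-md",
   "metadata": {},
   "source": [
    "After concatenation, `student_poses` and `teacher_poses` are arrays of shape `(num_frames, 17, 3)`, one `(x, y, score)` triple per COCO keypoint per frame. The optional check below just confirms both videos were processed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pose-shape-note-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional check (sketch): one (17, 3) keypoint record per frame\n",
    "print(student_poses.shape, teacher_poses.shape)"
   ]
  },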
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38a8d7a5-17ed-4ce2-bb8b-d1637cb49578",
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep the nose and the 12 body keypoints; drop eyes and ears\n",
    "valid_indices = np.array([0] + list(range(5, 17)))\n",
    "\n",
    "@torch.no_grad()\n",
    "def _calculate_similarity(tch_kpts: np.ndarray, stu_kpts: np.ndarray):\n",
    "\n",
    "    stu_kpts = torch.from_numpy(stu_kpts[:, None, valid_indices])\n",
    "    tch_kpts = torch.from_numpy(tch_kpts[None, :, valid_indices])\n",
    "    stu_kpts = stu_kpts.expand(stu_kpts.shape[0], tch_kpts.shape[1],\n",
    "                               stu_kpts.shape[2], 3)\n",
    "    tch_kpts = tch_kpts.expand(stu_kpts.shape[0], tch_kpts.shape[1],\n",
    "                               stu_kpts.shape[2], 3)\n",
    "\n",
    "    matrix = torch.stack((stu_kpts, tch_kpts), dim=4)\n",
    "    if torch.cuda.is_available():\n",
    "        matrix = matrix.cuda()\n",
    "    # only consider visible keypoints\n",
    "    mask = torch.logical_and(matrix[:, :, :, 2, 0] > 0.3,\n",
    "                             matrix[:, :, :, 2, 1] > 0.3)\n",
    "    matrix[~mask] = 0.0\n",
    "\n",
    "    matrix_ = matrix.clone()\n",
    "    matrix_[matrix == 0] = 256\n",
    "    x_min = matrix_.narrow(3, 0, 1).min(dim=2).values\n",
    "    y_min = matrix_.narrow(3, 1, 1).min(dim=2).values\n",
    "    matrix_ = matrix.clone()\n",
    "    x_max = matrix_.narrow(3, 0, 1).max(dim=2).values\n",
    "    y_max = matrix_.narrow(3, 1, 1).max(dim=2).values\n",
    "\n",
    "    matrix_ = matrix.clone()\n",
    "    matrix_[:, :, :, 0] = (matrix_[:, :, :, 0] - x_min) / (\n",
    "        x_max - x_min + 1e-4)\n",
    "    matrix_[:, :, :, 1] = (matrix_[:, :, :, 1] - y_min) / (\n",
    "        y_max - y_min + 1e-4)\n",
    "    matrix_[:, :, :, 2] = (matrix_[:, :, :, 2] > 0.3).float()\n",
    "    xy_dist = matrix_[..., :2, 0] - matrix_[..., :2, 1]\n",
    "    score = matrix_[..., 2, 0] * matrix_[..., 2, 1]\n",
    "\n",
    "    similarity = (torch.exp(-50 * xy_dist.pow(2).sum(dim=-1)) *\n",
    "                  score).sum(dim=-1) / (\n",
    "                      score.sum(dim=-1) + 1e-6)\n",
    "    num_visible_kpts = score.sum(dim=-1)\n",
    "    similarity = similarity * torch.log(\n",
    "        (1 + (num_visible_kpts - 1) * 10).clamp(min=1)) / np.log(161)\n",
    "\n",
    "    similarity[similarity.isnan()] = 0\n",
    "\n",
    "    return similarity"
   ]
  },
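  {
   "cell_type": "markdown",
   "id": "similarity-formula-md",
   "metadata": {},
   "source": [
    "The function above computes an OKS-style similarity for every (student frame, teacher frame) pair. Writing $d_k$ for the distance between the coordinates of keypoint $k$ after each pose is normalized to its own keypoint bounding box, $s_k \\in \\{0, 1\\}$ for the product of the two visibility indicators (score $> 0.3$ on both sides), and $V = \\sum_k s_k$, the score is roughly\n",
    "\n",
    "$$\\mathrm{sim} = \\frac{\\sum_k e^{-50 d_k^2}\\, s_k}{\\sum_k s_k + 10^{-6}} \\cdot \\frac{\\log\\bigl(\\max(1,\\, 1 + 10(V - 1))\\bigr)}{\\log 161}.$$\n",
    "\n",
    "The second factor down-weights pairs with few mutually visible keypoints; $161 = 1 + 10 \\cdot 16$ corresponds to all 17 COCO keypoints, so with the 13 keypoints kept by `valid_indices` the factor tops out at $\\log 121 / \\log 161 \\approx 0.94$. The toy check below (an illustrative sketch; `toy_kpts` is a name introduced here) compares a random pose sequence with itself: the diagonal should hold the largest value in each row, near 0.94 rather than 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "similarity-formula-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# toy check (sketch): a random sequence compared with itself should\n",
    "# peak on the diagonal (about 0.94, per the visibility factor above)\n",
    "rng = np.random.default_rng(0)\n",
    "toy_kpts = rng.uniform(0, 192, size=(4, 17, 3)).astype(np.float32)\n",
    "toy_kpts[..., 2] = 1.0  # mark every keypoint as confidently visible\n",
    "print(_calculate_similarity(toy_kpts, toy_kpts))"
   ]
  },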
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "658bcf89-df06-4c73-9323-8973a49c14c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute similarity without flip\n",
    "similarity1 = _calculate_similarity(teacher_poses, student_poses)\n",
    "\n",
    "# compute similarity with flip\n",
    "flip_indices = np.array(\n",
    "    [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15])\n",
    "student_poses_flip = student_poses[:, flip_indices]\n",
    "student_poses_flip[..., 0] = 191.5 - student_poses_flip[..., 0]\n",
    "similarity2 = _calculate_similarity(teacher_poses, student_poses_flip)\n",
    "\n",
    "# select the larger similarity\n",
    "similarity = torch.stack((similarity1, similarity2)).max(dim=0).values"
   ]
  },
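  {
   "cell_type": "markdown",
   "id": "flip-constant-note-md",
   "metadata": {},
   "source": [
    "A note on the constant `191.5`: the custom forward in `get_keypoints_from_frame` appears to return keypoints in the pose model's input frame (192 wide, 256 high) rather than in original-image coordinates, so a horizontal flip maps $x \\mapsto 191.5 - x$ (mirroring about the center of a 192-pixel-wide canvas) while `flip_indices` swaps the left/right COCO keypoints. Taking the elementwise maximum of the two similarity matrices makes the matching robust to mirrored choreography."
   ]
  },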
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f981410d-4585-47c1-98c0-6946f948487d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize the similarity matrix\n",
    "plt.imshow(similarity.cpu().numpy())\n",
    "\n",
    "# a clear diagonal band appears in the figure;\n",
    "# we can select the matched video snippets along this diagonal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13c189e5-fc53-46a2-9057-f0f2ffc1f46d",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def select_piece_from_similarity(similarity):\n",
    "    m, n = similarity.size()\n",
    "    row_indices = torch.arange(m).view(-1, 1).expand(m, n).to(similarity)\n",
    "    col_indices = torch.arange(n).view(1, -1).expand(m, n).to(similarity)\n",
    "    diagonal_indices = similarity.size(0) - 1 - row_indices + col_indices\n",
    "    unique_diagonal_indices, inverse_indices = torch.unique(\n",
    "        diagonal_indices, return_inverse=True)\n",
    "\n",
    "    diagonal_sums_list = torch.zeros(\n",
    "        unique_diagonal_indices.size(0),\n",
    "        dtype=similarity.dtype,\n",
    "        device=similarity.device)\n",
    "    diagonal_sums_list.scatter_add_(0, inverse_indices.view(-1),\n",
    "                                    similarity.view(-1))\n",
    "    diagonal_sums_list[:min(m, n) // 4] = 0\n",
    "    diagonal_sums_list[-min(m, n) // 4:] = 0\n",
    "    index = diagonal_sums_list.argmax().item()\n",
    "\n",
    "    similarity_smooth = torch.nn.functional.max_pool2d(\n",
    "        similarity[None], (1, 11), stride=(1, 1), padding=(0, 5))[0]\n",
    "    similarity_vec = similarity_smooth.diagonal(\n",
    "        offset=index - m + 1).cpu().numpy()\n",
    "\n",
    "    stu_start = max(0, m - 1 - index)\n",
    "    tch_start = max(0, index - m + 1)\n",
    "\n",
    "    return dict(\n",
    "        stu_start=stu_start,\n",
    "        tch_start=tch_start,\n",
    "        length=len(similarity_vec),\n",
    "        similarity=similarity_vec)"
   ]
  },
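  {
   "cell_type": "markdown",
   "id": "diag-select-demo-md",
   "metadata": {},
   "source": [
    "A small synthetic example (an illustrative sketch; `toy_sim` and `piece` are names introduced here) shows what the function does: it sums the similarity along every diagonal offset, picks the strongest one after zeroing out the extreme corner diagonals, and reports where that diagonal enters the student and teacher timelines. A 6×6 matrix whose only strong response lies on the band $j = i + 2$ should therefore yield `stu_start=0`, `tch_start=2`, and a length of 4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "diag-select-demo-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# toy check (sketch): ones at (0, 2), (1, 3), (2, 4), (3, 5)\n",
    "toy_sim = torch.diag(torch.ones(4), 2)\n",
    "piece = select_piece_from_similarity(toy_sim)\n",
    "print(piece['stu_start'], piece['tch_start'], piece['length'])\n",
    "# expected: 0 2 4"
   ]
  },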
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c0e19df-949d-471d-804d-409b3b9ddf7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "matched_piece_info = select_piece_from_similarity(similarity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51b0a2bd-253c-4a8f-a82a-263e18a4703e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize the matched piece on the similarity matrix\n",
    "plt.imshow(similarity.cpu().numpy())\n",
    "plt.plot((matched_piece_info['tch_start'],\n",
    "          matched_piece_info['tch_start'] + matched_piece_info['length'] - 1),\n",
    "         (matched_piece_info['stu_start'],\n",
    "          matched_piece_info['stu_start'] + matched_piece_info['length'] - 1),\n",
    "         'r')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffcde4e7-ff50-483a-b515-604c1d8f121a",
   "metadata": {},
   "source": [
    "# Generate Output Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72171a0c-ab33-45bb-b84c-b15f0816ed3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "\n",
    "def resize_image_to_fixed_height(image: np.ndarray,\n",
    "                                 fixed_height: int) -> np.ndarray:\n",
    "    \"\"\"Resize an input image to a fixed height while keeping its aspect\n",
    "    ratio.\n",
    "\n",
    "    Args:\n",
    "        image (np.ndarray): Input image as a numpy array [H, W, C].\n",
    "        fixed_height (int): Desired height of the output image.\n",
    "\n",
    "    Returns:\n",
    "        Resized image as a numpy array (fixed_height, new_width, channels).\n",
    "    \"\"\"\n",
    "    original_height, original_width = image.shape[:2]\n",
    "\n",
    "    scale_ratio = fixed_height / original_height\n",
    "    new_width = int(original_width * scale_ratio)\n",
    "    resized_image = cv2.resize(image, (new_width, fixed_height))\n",
    "\n",
    "    return resized_image\n",
    "\n",
    "def blend_images(img1: np.ndarray,\n",
    "                 img2: np.ndarray,\n",
    "                 blend_ratios: Tuple[float, float] = (1, 1)) -> np.ndarray:\n",
    "    \"\"\"Blend two input images with the specified blend ratios.\n",
    "\n",
    "    Args:\n",
    "        img1 (np.ndarray): First input image as a numpy array [H, W, C].\n",
    "        img2 (np.ndarray): Second input image as a numpy array [H, W, C].\n",
    "        blend_ratios (tuple): A tuple of two floats representing the blend\n",
    "            ratios for the two input images.\n",
    "\n",
    "    Returns:\n",
    "        Blended image as a numpy array [H, W, C].\n",
    "    \"\"\"\n",
    "\n",
    "    def normalize_image(image: np.ndarray) -> np.ndarray:\n",
    "        if image.dtype == np.uint8:\n",
    "            return image.astype(np.float32) / 255.0\n",
    "        return image\n",
    "\n",
    "    img1 = normalize_image(img1)\n",
    "    img2 = normalize_image(img2)\n",
    "\n",
    "    blended_image = img1 * blend_ratios[0] + img2 * blend_ratios[1]\n",
    "    blended_image = blended_image.clip(min=0, max=1)\n",
    "    blended_image = (blended_image * 255).astype(np.uint8)\n",
    "\n",
    "    return blended_image\n",
    "\n",
    "def get_smoothed_kpt(kpts, index, sigma=5):\n",
    "    \"\"\"Smooth the keypoints at the given frame index with a Gaussian-\n",
    "    weighted average over a window of neighboring frames.\"\"\"\n",
    "    assert kpts.shape[1] == 17\n",
    "    assert kpts.shape[2] == 3\n",
    "    assert sigma % 2 == 1\n",
    "\n",
    "    num_kpts = len(kpts)\n",
    "\n",
    "    start_idx = max(0, index - sigma // 2)\n",
    "    end_idx = min(num_kpts, index + sigma // 2 + 1)\n",
    "\n",
    "    # Extract a piece of the keypoints array to apply the filter\n",
    "    piece = kpts[start_idx:end_idx].copy()\n",
    "    # Copy so that the smoothed (and later shifted) result does not\n",
    "    # overwrite the source keypoint array\n",
    "    original_kpt = kpts[index].copy()\n",
    "\n",
    "    # Split the piece into coordinates and scores\n",
    "    coords, scores = piece[..., :2], piece[..., 2]\n",
    "\n",
    "    # Calculate the Gaussian ratio for each keypoint\n",
    "    gaussian_ratio = np.arange(len(scores)) + start_idx - index\n",
    "    gaussian_ratio = np.exp(-gaussian_ratio**2 / 2)\n",
    "\n",
    "    # Update scores using the Gaussian ratio\n",
    "    scores *= gaussian_ratio[:, None]\n",
    "\n",
    "    # Compute the smoothed coordinates\n",
    "    smoothed_coords = (coords * scores[..., None]).sum(axis=0) / (\n",
    "        scores[..., None].sum(axis=0) + 1e-4)\n",
    "\n",
    "    original_kpt[..., :2] = smoothed_coords\n",
    "\n",
    "    return original_kpt"
   ]
  },
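  {
   "cell_type": "markdown",
   "id": "blend-check-md",
   "metadata": {},
   "source": [
    "A quick check of `blend_images` (an illustrative sketch, not part of the pipeline; `black`, `white`, and `gray` are names introduced here): blending a black and a white patch with equal ratios should yield a uniform mid-gray."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "blend-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# quick check (sketch): equal-ratio blend of black and white -> gray\n",
    "black = np.zeros((4, 4, 3), dtype=np.uint8)\n",
    "white = np.full((4, 4, 3), 255, dtype=np.uint8)\n",
    "gray = blend_images(black, white, blend_ratios=(0.5, 0.5))\n",
    "print(gray[0, 0])  # expected: [127 127 127]"
   ]
  },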
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "609b5adc-e176-4bf9-b9a4-506f72440017",
   "metadata": {},
   "outputs": [],
   "source": [
    "score, last_vis_score = 0, 0\n",
    "video_writer = None\n",
    "output_file = 'output.mp4'\n",
    "stu_kpts = student_poses\n",
    "tch_kpts = teacher_poses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a264405a-5d50-49de-8637-2d1f67cb0a70",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mmengine.structures import InstanceData\n",
    "\n",
    "tch_video_reader = VideoReader(teacher_video)\n",
    "stu_video_reader = VideoReader(student_video)\n",
    "for _ in range(matched_piece_info['tch_start']):\n",
    "    _ = next(tch_video_reader)\n",
    "for _ in range(matched_piece_info['stu_start']):\n",
    "    _ = next(stu_video_reader)\n",
    "\n",
    "for i in track_iter_progress(range(matched_piece_info['length'])):\n",
    "    tch_frame = mmcv.bgr2rgb(next(tch_video_reader))\n",
    "    stu_frame = mmcv.bgr2rgb(next(stu_video_reader))\n",
    "    tch_frame = resize_image_to_fixed_height(tch_frame, 300)\n",
    "    stu_frame = resize_image_to_fixed_height(stu_frame, 300)\n",
    "\n",
    "    stu_kpt = get_smoothed_kpt(stu_kpts, matched_piece_info['stu_start'] + i,\n",
    "                               5)\n",
    "    tch_kpt = get_smoothed_kpt(tch_kpts, matched_piece_info['tch_start'] + i,\n",
    "                               5)\n",
    "\n",
    "    # draw pose\n",
    "    stu_kpt[..., 1] += (300 - 256)\n",
    "    tch_kpt[..., 0] += (256 - 192)\n",
    "    tch_kpt[..., 1] += (300 - 256)\n",
    "    stu_inst = InstanceData(\n",
    "        keypoints=stu_kpt[None, :, :2],\n",
    "        keypoint_scores=stu_kpt[None, :, 2])\n",
    "    tch_inst = InstanceData(\n",
    "        keypoints=tch_kpt[None, :, :2],\n",
    "        keypoint_scores=tch_kpt[None, :, 2])\n",
    "\n",
    "    stu_out_img = pose_estimator.visualizer._draw_instances_kpts(\n",
    "        np.zeros((300, 256, 3)), stu_inst)\n",
    "    tch_out_img = pose_estimator.visualizer._draw_instances_kpts(\n",
    "        np.zeros((300, 256, 3)), tch_inst)\n",
    "    out_img = blend_images(\n",
    "        stu_out_img, tch_out_img, blend_ratios=(1, 0.3))\n",
    "\n",
    "    # draw score\n",
    "    score_frame = matched_piece_info['similarity'][i]\n",
    "    score += score_frame * 1000\n",
    "    if score - last_vis_score > 1500:\n",
    "        last_vis_score = score\n",
    "    pose_estimator.visualizer.set_image(out_img)\n",
    "    pose_estimator.visualizer.draw_texts(\n",
    "        'score: ', (60, 30),\n",
    "        font_sizes=15,\n",
    "        colors=(255, 255, 255),\n",
    "        vertical_alignments='bottom')\n",
    "    pose_estimator.visualizer.draw_texts(\n",
    "        f'{int(last_vis_score)}', (115, 30),\n",
    "        font_sizes=30 * max(0.4, score_frame),\n",
    "        colors=(255, 255, 255),\n",
    "        vertical_alignments='bottom')\n",
    "    out_img = pose_estimator.visualizer.get_image()\n",
    "\n",
    "    # concatenate\n",
    "    concatenated_image = np.hstack((stu_frame, out_img, tch_frame))\n",
    "    if video_writer is None:\n",
    "        video_writer = cv2.VideoWriter(output_file,\n",
    "                                       cv2.VideoWriter_fourcc(*'mp4v'),\n",
    "                                       30,\n",
    "                                       (concatenated_image.shape[1],\n",
    "                                        concatenated_image.shape[0]))\n",
    "    video_writer.write(mmcv.rgb2bgr(concatenated_image))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "745fdd75-6ed4-4cae-9f21-c2cd486ee918",
   "metadata": {},
   "outputs": [],
   "source": [
    "if video_writer is not None:\n",
    "    video_writer.release()"
   ]
  },
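  {
   "cell_type": "markdown",
   "id": "output-playback-note-md",
   "metadata": {},
   "source": [
    "The `mp4v` codec used by `cv2.VideoWriter` often does not play inline in browsers or notebook viewers. If that matters, a common workaround (an optional sketch; it requires ffmpeg, and the output name `output_h264.mp4` is just an example) is to re-encode the result with H.264:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "output-playback-note-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional (sketch): re-encode for browser-friendly playback\n",
    "# !ffmpeg -y -i output.mp4 -c:v libx264 -pix_fmt yuv420p output_h264.mp4"
   ]
  }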
 ],
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.15"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|