diff --git a/nodes.py b/nodes.py index f3809e2..4510b29 100644 --- a/nodes.py +++ b/nodes.py @@ -188,22 +188,25 @@ def process(self, model, images, width, height, retarget_image=None): key_frame_index_list = list(range(0, len(pose_metas), key_frame_step)) key_points_index = [0, 1, 2, 5, 8, 11, 10, 13] - + points_dict_list = [] for key_frame_index in key_frame_index_list: keypoints_body_list = [] body_key_points = pose_metas[key_frame_index]['keypoints_body'] for each_index in key_points_index: each_keypoint = body_key_points[each_index] - if None is each_keypoint: + if each_keypoint is None or len(each_keypoint) != 3: continue - keypoints_body_list.append(each_keypoint) - - keypoints_body = np.array(keypoints_body_list)[:, :2] - wh = np.array([[pose_metas[0]['width'], pose_metas[0]['height']]]) - points = (keypoints_body * wh).astype(np.int32) - points_dict_list = [] - for point in points: - points_dict_list.append({"x": int(point[0]), "y": int(point[1])}) + if each_keypoint[2] > 0.6: # only consider keypoints with confidence > 0.6 + keypoints_body_list.append(each_keypoint) + if len(keypoints_body_list) < 2: #too few keypoints, skip this frame + continue + else: + keypoints_body = np.array(keypoints_body_list)[:, :2] + wh = np.array([[pose_metas[key_frame_index]['width'], pose_metas[key_frame_index]['height']]]) + points = (keypoints_body * wh).astype(np.int32) + for point in points: + points_dict_list.append({"x": int(point[0]), "y": int(point[1])}) + break # only use the first valid key frame pose_data = { "retarget_image": refer_img if retarget_image is not None else None,