diff --git a/spot_tools/src/spot_skills/detection_utils.py b/spot_tools/src/spot_skills/detection_utils.py index 59e6259..5e84744 100644 --- a/spot_tools/src/spot_skills/detection_utils.py +++ b/spot_tools/src/spot_skills/detection_utils.py @@ -1,10 +1,22 @@ from copy import copy +from dataclasses import dataclass +from typing import Optional import cv2 import numpy as np +from bosdyn.api import image_pb2 from ultralytics import YOLOE +@dataclass +class DetectionCandidate: + """A candidate image from a single camera source for object detection.""" + + bosdyn_image: image_pb2.ImageResponse + cv2_image: np.ndarray + detection_xy: Optional[tuple[int, int]] + + class Detector: def __init__(self, spot): self.spot = spot @@ -39,19 +51,35 @@ def set_up_detector(self, semantic_class): print(f"Updated recognized classes: {updated_classes}") def return_centroid(self, img_source, semantic_class, debug): - image, img = self.spot.get_image_RGB(view=img_source) - - xy = self._get_centroid(img, semantic_class, rotate=0, debug=debug) - if xy is None: - print("Object not found in first image. Looking around!") - xy, image, img, image_source = self._look_for_object( + detection_candidates = [] # List of DetectionCandidates -- one per camera source on Spot + detection_index = None + + # Get the primary image source & run the detector; Add to detection_candidates + primary_bosdyn_image, primary_cv2_image = self.spot.get_image_RGB( + view=img_source + ) + primary_xy = self._get_centroid( + primary_cv2_image, semantic_class, rotate=0, debug=debug + ) + detection_candidates.append( + DetectionCandidate(primary_bosdyn_image, primary_cv2_image, primary_xy) + ) + + # If primary image has a detection, just get the images from the other image sources; Otherwise, look for detections + if primary_xy is not None: + detection_index = 0 + secondary_candidates = self._get_candidate_detection_images() + else: + secondary_candidates, secondary_detection_index = self._look_for_object( semantic_class, debug=debug ) - - if xy is None: - print("Object not found near robot") - - return xy, image, img + if secondary_detection_index is not None: + detection_index = ( + secondary_detection_index + 1 + ) # offset of the primary candidate + # Collect all of the detection candidates + detection_candidates.extend(secondary_candidates) + return detection_index, detection_candidates def _get_centroid(self, img, semantic_class, rotate, debug): if rotate == 0: @@ -115,9 +143,25 @@ def _get_centroid(self, img, semantic_class, rotate, debug): else: return None + def _get_candidate_detection_images(self): + detection_candidates = [] + sources = self.spot.image_client.list_image_sources() + sources = [ + s for s in sources if "depth" not in s.name and "hand_image" != s.name + ] + for source in sources: + bosdyn_image, cv2_image = self.spot.get_image_RGB(view=source.name) + detection_candidates.append( + DetectionCandidate(bosdyn_image, cv2_image, None) + ) + return detection_candidates + def _look_for_object(self, semantic_class, debug): sources = self.spot.image_client.list_image_sources() + candidates = [] + detection_index = None + for source in sources: if ( "depth" in source.name or source.name == "hand_image" @@ -143,11 +187,12 @@ def _look_for_object(self, semantic_class, debug): print("Found object centroid:", xy) if xy is None: print(f"Object not found in {image_source}.") - continue - else: - return xy, image, img, image_source - return None, None, None, None + candidates.append(DetectionCandidate(image, img, xy)) + if xy is not None and detection_index is None: + detection_index = len(candidates) - 1 + + return candidates, detection_index class SemanticDetector(Detector): diff --git a/spot_tools/src/spot_skills/grasp_utils.py b/spot_tools/src/spot_skills/grasp_utils.py index 3376f51..ba26a60 100644 --- a/spot_tools/src/spot_skills/grasp_utils.py +++ b/spot_tools/src/spot_skills/grasp_utils.py @@ -30,7 +30,7 @@ open_gripper, stow_arm, ) -from spot_skills.detection_utils import Detector +from spot_skills.detection_utils import DetectionCandidate, Detector from spot_skills.primitives import execute_recovery_action g_image_click = None @@ -187,17 +187,18 @@ def object_grasp( # Set up the detector (e.g., for YOLOWorld, this may mean updating recognized classes) detector.set_up_detector(semantic_class) + candidates = None while attempts < 2 and not success: attempts += 1 if not user_input: # Try to get the centroid using the detector passed into the function. - xy, image, img = detector.return_centroid( + detection_index, candidates = detector.return_centroid( image_source, semantic_class, debug=debug ) # If the detector fails to return the centroid, then try again until max_attempts - if xy is None: + if detection_index is None: continue else: @@ -206,13 +207,15 @@ def object_grasp( else: image, img = spot.get_image_RGB(view=image_source) xy = get_user_grasp_input(spot, img) + candidates = [DetectionCandidate(image, img, xy)] + detection_index = 0 print("Found object centroid:", xy) - if xy is None: + if candidates is None: if feedback is not None: feedback.print( "INFO", - "Failed to find an object in any cameras after 2 attempts. Please check the detector or user input.", + "Failed to capture any camera images. Please check the detector or user input.", ) execute_recovery_action( spot, @@ -223,27 +226,47 @@ def object_grasp( ) time.sleep(1) return False - # execute_recovery_action(spot, recover_arm=True) - # spot.sit() - # raise Exception( - # "Failed to find an object in any cameras after 2 attempts. Please check the detector or user input." - # ) - # If xy is not None, then display the annotated image - else: - if feedback is not None: - annotated_img = copy(img) + # Display all candidate images in the approval panel + if feedback is not None: + cv2_images = [copy(c.cv2_image) for c in candidates] - response = feedback.bounding_box_detection_feedback( - annotated_img, - xy[0], - xy[1], - semantic_class, - ) + if detection_index is not None: + xy = candidates[detection_index].detection_xy + det_x = xy[0] if xy else None + det_y = xy[1] if xy else None + else: + det_x = None + det_y = None + + approved, updated_xy, selected_index = feedback.bounding_box_detection_feedback( + cv2_images, + detection_index, + det_x, + det_y, + semantic_class, + ) + + if approved is not None and not approved: + feedback.print("INFO", "User requested abort.") + return False - if response is not None and not response: - feedback.print("INFO", "User requested abort.") - return False + # Use selected camera image and pixel (panel always sends the correct selection) + image = candidates[selected_index].bosdyn_image + xy = updated_xy + else: + if detection_index is None: + execute_recovery_action( + spot, + recover_arm=False, + relative_pose=math_helpers.SE2Pose( + x=0.0, y=0.0, angle=np.random.choice([-0.5, 0.5]) + ), + ) + time.sleep(1) + return False + image = candidates[detection_index].bosdyn_image + xy = candidates[detection_index].detection_xy pick_vec = geometry_pb2.Vec2(x=xy[0], y=xy[1]) stow_arm(spot) diff --git a/spot_tools_ros/src/spot_tools_ros/spot_executor_ros.py b/spot_tools_ros/src/spot_tools_ros/spot_executor_ros.py index d8b3c9a..36df3d8 100755 --- a/spot_tools_ros/src/spot_tools_ros/spot_executor_ros.py +++ b/spot_tools_ros/src/spot_tools_ros/spot_executor_ros.py @@ -2,7 +2,6 @@ import threading import time -import cv2 import numpy as np import rclpy import rclpy.time @@ -12,6 +11,10 @@ from cv_bridge import CvBridge from heracles_ros_interfaces.srv import UpdateHoldingState from nav_msgs.msg import Path +from nlu_interface_rviz.msg import ( + ManipulationApprovalRequest, + ManipulationApprovalResponse, +) from rclpy.callback_groups import MutuallyExclusiveCallbackGroup from rclpy.executors import MultiThreadedExecutor from rclpy.node import Node @@ -24,7 +27,7 @@ from spot_executor.fake_spot import FakeSpot from spot_executor.spot import Spot from spot_skills.detection_utils import YOLODetector -from std_msgs.msg import Bool, String +from std_msgs.msg import String from visualization_msgs.msg import Marker, MarkerArray from robot_executor_interface.mid_level_planner import ( @@ -77,7 +80,11 @@ def build_markers(pts, namespaces, frames, colors): class RosFeedbackCollector: def __init__(self, odom_frame: str, output_dir: str): self.pick_confirmation_event = threading.Event() - self.pick_confirmation_response = False + # self.pick_confirmation_response = False + + self.pick_confirmation_approved = False + self.pick_confirmation_xy = [0, 0] + self.pick_confirmation_image_index = 0 self.break_out_of_waiting_loop = False self.odom_frame = odom_frame @@ -85,34 +92,21 @@ def __init__(self, odom_frame: str, output_dir: str): self.output_dir = output_dir def bounding_box_detection_feedback( - self, annotated_img, centroid_x, centroid_y, semantic_class + self, detection_imgs, detection_index, centroid_x, centroid_y, semantic_class ): bridge = CvBridge() - if centroid_x is not None and centroid_y is not None: - label = f"{semantic_class}" - cv2.putText( - annotated_img, - label, - (centroid_x - 100, centroid_y - 200), - cv2.FONT_HERSHEY_SIMPLEX, - 5, - (0, 0, 225), - 20, - ) - - # Label the centroid - cv2.drawMarker( - annotated_img, - (centroid_x, centroid_y), - (0, 0, 255), - markerType=cv2.MARKER_TILTED_CROSS, - markerSize=200, - thickness=30, - ) - - annotated_img_msg = bridge.cv2_to_imgmsg(annotated_img, encoding="passthrough") - self.annotated_img_pub.publish(annotated_img_msg) + request_msg = ManipulationApprovalRequest() + request_msg.images = [ + bridge.cv2_to_imgmsg(img, encoding="passthrough") for img in detection_imgs + ] + request_msg.has_detection = detection_index is not None + request_msg.detection_image_index = ( + detection_index if detection_index is not None else 0 + ) + request_msg.image_x = centroid_x if centroid_x is not None else 0 + request_msg.image_y = centroid_y if centroid_y is not None else 0 + self.detection_img_pub.publish(request_msg) self.pick_confirmation_event.clear() @@ -126,10 +120,17 @@ def bounding_box_detection_feedback( if self.break_out_of_waiting_loop: self.logger.info("ROBOT WAS PREEMPTED") - self.pick_confirmation_response = False + self.pick_confirmation_approved = False + else: + self.logger.info( + f"Pick Confirmation Response Received: approved ({self.pick_confirmation_approved}), xy ({self.pick_confirmation_xy}), image_index ({self.pick_confirmation_image_index})" + ) - # This boolean determines whether the executor keeps going - return self.pick_confirmation_response + return ( + self.pick_confirmation_approved, + self.pick_confirmation_xy, + self.pick_confirmation_image_index, + ) def pick_image_feedback(self, semantic_image, mask_image): bridge = CvBridge() @@ -217,14 +218,16 @@ def register_publishers(self, node): MarkerArray, "~/mlp_target_publisher", qos_profile=latching_qos ) - self.annotated_img_pub = node.create_publisher( - Image, "~/annotated_image", qos_profile=latching_qos + self.detection_img_pub = node.create_publisher( + ManipulationApprovalRequest, + "~/manipulation_request", + qos_profile=latching_qos, ) self.lease_takeover_publisher = node.create_publisher(String, "~/takeover", 10) node.create_subscription( - Bool, + ManipulationApprovalResponse, "~/pick_confirmation", self.pick_confirmation_callback, 10, @@ -259,13 +262,26 @@ def set_robot_holding_state(self, is_holding: bool, object_id: str, timeout=5): return future.result().success def pick_confirmation_callback(self, msg): - if msg.data: - self.logger.info("Detection is valid. Continuing pick action!") - self.pick_confirmation_response = True - else: + # if msg.data: + # self.logger.info("Detection is valid. Continuing pick action!") + # self.pick_confirmation_response = True + # else: + # self.logger.warn("Detection is invalid. Discontinuing pick action.") + # self.pick_confirmation_response = False + + # self.pick_confirmation_event.set() + + # If not approved, discontinue + # If approved, check whether the detection is overwritten + if not msg.approve: self.logger.warn("Detection is invalid. Discontinuing pick action.") - self.pick_confirmation_response = False - + self.pick_confirmation_approved = False + else: + self.pick_confirmation_approved = True + self.pick_confirmation_image_index = msg.image_index + self.pick_confirmation_xy[0] = msg.image_x + self.pick_confirmation_xy[1] = msg.image_y + self.logger.warn("Detection is valid. Continuing pick action!") self.pick_confirmation_event.set() def log_lease_takeover(self, event: str):