diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py index 13dc94572..e2ff5a93b 100644 --- a/dimos/agents2/skills/navigation.py +++ b/dimos/agents2/skills/navigation.py @@ -275,6 +275,52 @@ def follow_person(self, person_description: str = "person", continuous: bool = T # Always stop tracking self._robot.person_tracker.stop_tracking() + @skill() + def follow_specific_person(self, person_id: int, continuous: bool = True) -> str: + """Follow a specific person by their ReID. + + Args: + person_id: The ReID of the person to track + continuous: If True, follow continuously without checking arrival. + If False, stop when person is reached (default: True) + + Returns: + Status message indicating success or failure + """ + if not self._started: + raise ValueError(f"{self} has not been started.") + + # Check for required modules (robots declare these attributes as None until deployed, + # so hasattr() alone is not sufficient — test for None explicitly) + if getattr(self._robot, "person_tracker", None) is None: + return "Person tracker not available on this robot" + + if getattr(self._robot, "reid_module", None) is None: + return "ReID module not available on this robot" + + logger.info(f"Starting tracking of specific person ID {person_id}") + + # Start tracking with the specific person ID + self._robot.person_tracker.start_tracking(continuous=continuous, target_person_id=person_id) + + try: + start_time = time.time() + timeout = 60.0 # 60 second timeout + + while time.time() - start_time < timeout: + # Check if tracking stopped (person reached or lost) + if not self._robot.person_tracker.is_tracking(): + logger.info(f"Tracking stopped for person ID {person_id}") + return f"Finished tracking person ID {person_id}" + + time.sleep(0.25) + + logger.warning(f"Following person ID {person_id} timed out after {timeout}s") + return f"Timeout while following person ID {person_id}" + + finally: + # Always stop tracking + self._robot.person_tracker.stop_tracking() + @skill() def stop_human_tracking(self) -> str: """Stop tracking and following a person. 
diff --git a/dimos/models/embedding/treid.py b/dimos/models/embedding/treid.py index bdd00627a..d383286df 100644 --- a/dimos/models/embedding/treid.py +++ b/dimos/models/embedding/treid.py @@ -46,10 +46,10 @@ def __init__( device: Device to run on (cuda/cpu), auto-detects if None normalize: Whether to L2 normalize embeddings """ - if not TORCHREID_AVAILABLE: - raise ImportError( - "torchreid is required for TorchReIDModel. Install it with: pip install torchreid" - ) + if not TORCHREID_AVAILABLE: + raise ImportError( + "torchreid is required for TorchReIDModel. Install it with: pip install torchreid" + ) self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.normalize = normalize diff --git a/dimos/perception/detection/person_tracker.py b/dimos/perception/detection/person_tracker.py index f4bde3b70..af27a00eb 100644 --- a/dimos/perception/detection/person_tracker.py +++ b/dimos/perception/detection/person_tracker.py @@ -13,7 +13,7 @@ # limitations under the License. 
import logging -from typing import Tuple +from typing import Optional, Tuple from reactivex import operators as ops from reactivex.observable import Observable @@ -41,13 +41,20 @@ class PersonTracker(Module): camera_info: CameraInfo - def __init__(self, cameraInfo: CameraInfo, arrival_threshold: float = 0.7, **kwargs): + def __init__( + self, + cameraInfo: CameraInfo, + arrival_threshold: float = 0.7, + target_person_id: Optional[int] = None, + **kwargs, + ): super().__init__(**kwargs) self.camera_info = cameraInfo self.tf = TF() self._sub = None self._is_tracking = False self._continuous = True + self._target_person_id = target_person_id # Specific person ReID to track self._arrival_threshold = arrival_threshold # bbox bottom must be in bottom 30% of frame def center_to_3d( @@ -119,18 +126,23 @@ def check_arrival(self, detection: Detection2DBBox) -> bool: return is_arrived @skill() - def start_tracking(self, continuous: bool = True): + def start_tracking(self, continuous: bool = True, target_person_id: Optional[int] = None): """Start person tracking. Args: continuous: If True, follow continuously without checking arrival. If False, stop when person is reached (default: True) + target_person_id: Optional ReID of specific person to track. + If None, tracks the largest detected person. """ if not self._is_tracking: self._continuous = continuous + self._target_person_id = target_person_id self._sub = self.detections_stream().subscribe(self.track) self._is_tracking = True - logger.info(f"PersonTracker: Tracking started (continuous={continuous})") + logger.info( + f"PersonTracker: Tracking started (continuous={continuous}, target_id={target_person_id})" + ) return "Person tracking started" @skill() @@ -147,6 +159,24 @@ def stop_tracking(self): logger.info("PersonTracker: Tracking stopped") return "Person tracking stopped" + @skill() + def set_target_person(self, person_id: Optional[int] = None): + """Set the target person to track by ReID. 
+ + Args: + person_id: ReID of person to track, or None to track any person. + + Returns: + Status message + """ + self._target_person_id = person_id + if person_id is not None: + logger.info(f"PersonTracker: Target set to person ID {person_id}") + return f"Now tracking person ID {person_id}" + else: + logger.info("PersonTracker: Target cleared, tracking any person") + return "Now tracking any visible person" + @rpc def stop(self): super().stop() @@ -165,11 +195,29 @@ def track(self, detections2D: ImageDetections2D): logger.warning("PersonTracker: No detections, skipping") return - target = max(detections2D.detections, key=lambda det: det.bbox_2d_volume()) - logger.info( - f"PersonTracker: Selected target person - center={target.center_bbox}, " - f"bbox_volume={target.bbox_2d_volume():.1f}px" - ) + # Filter by target_person_id if specified + if self._target_person_id is not None: + # Filter detections by ReID (now in track_id field from enriched stream) + valid_detections = [ + det for det in detections2D.detections if det.track_id == self._target_person_id + ] + + if not valid_detections: + logger.info(f"Target person {self._target_person_id} not in view") + return + + target = valid_detections[0] + logger.info( + f"PersonTracker: Tracking specific person ID {self._target_person_id} - " + f"center={target.center_bbox}, bbox_volume={target.bbox_2d_volume():.1f}px" + ) + else: + # Default behavior - track largest detection + target = max(detections2D.detections, key=lambda det: det.bbox_2d_volume()) + logger.info( + f"PersonTracker: Selected largest person - center={target.center_bbox}, " + f"bbox_volume={target.bbox_2d_volume():.1f}px, track_id={target.track_id}" + ) if not self._continuous and self.check_arrival(target): logger.info("Person reached, stopping tracker") diff --git a/dimos/perception/detection/reid/module.py b/dimos/perception/detection/reid/module.py index 64769b103..65ad61c45 100644 --- a/dimos/perception/detection/reid/module.py +++ 
b/dimos/perception/detection/reid/module.py @@ -29,6 +29,9 @@ from dimos.perception.detection.type import ImageDetections2D from dimos.types.timestamped import align_timestamped, to_ros_stamp from dimos.utils.reactive import backpressure +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(__name__) class Config(ModuleConfig): @@ -41,8 +44,14 @@ class ReidModule(Module): detections: In[Detection2DArray] = None # type: ignore image: In[Image] = None # type: ignore annotations: Out[ImageAnnotations] = None # type: ignore + enriched_detections: Out[Detection2DArray] = None # type: ignore def __init__(self, idsystem: IDSystem | None = None, **kwargs): + """Initialize ReID module. + + Args: + idsystem: ID system for tracking. Defaults to EmbeddingIDSystem with TorchReIDModel. + """ super().__init__(**kwargs) if idsystem is None: try: @@ -52,11 +61,13 @@ def __init__(self, idsystem: IDSystem | None = None, **kwargs): except Exception as e: raise RuntimeError( "TorchReIDModel not available. 
Please install with: pip install dimos[torchreid]" + f"\n\nERROR: {e}" ) from e self.idsystem = idsystem def detections_stream(self) -> Observable[ImageDetections2D]: + """Stream aligned image detections.""" return backpressure( align_timestamped( self.image.pure_observable(), @@ -79,30 +90,34 @@ def stop(self): def ingress(self, imageDetections: ImageDetections2D): text_annotations = [] - for detection in imageDetections: - # Register detection and get long-term ID + track_ids_in_frame = [det.track_id for det in imageDetections.detections] + if len(track_ids_in_frame) > 1: + self.idsystem.add_negative_constraints(track_ids_in_frame) + + for detection in imageDetections.detections: long_term_id = self.idsystem.register_detection(detection) - # Skip annotation if not ready yet (long_term_id == -1) - if long_term_id == -1: - continue - - # Create text annotation for long_term_id above the detection - x1, y1, _, _ = detection.bbox - font_size = imageDetections.image.width / 60 - - text_annotations.append( - TextAnnotation( - timestamp=to_ros_stamp(detection.ts), - position=Point2(x=x1, y=y1 - font_size * 1.5), - text=f"PERSON: {long_term_id}", - font_size=font_size, - text_color=Color(r=0.0, g=1.0, b=1.0, a=1.0), # Cyan - background_color=Color(r=0.0, g=0.0, b=0.0, a=0.8), + # Override track_id with ReID for downstream processing + if long_term_id != -1: + detection.track_id = long_term_id + + x1, y1, _, _ = detection.bbox + font_size = imageDetections.image.width / 60 + + text_annotations.append( + TextAnnotation( + timestamp=to_ros_stamp(detection.ts), + position=Point2(x=x1, y=y1 - font_size * 1.5), + text=f"PERSON: {long_term_id}", + font_size=font_size, + text_color=Color(r=0.0, g=1.0, b=1.0, a=1.0), # Cyan + background_color=Color(r=0.0, g=0.0, b=0.0, a=0.8), + ) ) - ) - # Publish annotations (even if empty to clear previous annotations) + if self.enriched_detections: + self.enriched_detections.publish(imageDetections.to_ros_detection2d_array()) + annotations = 
ImageAnnotations( texts=text_annotations, texts_length=len(text_annotations), diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 7b34dd966..404d8708c 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -67,6 +67,7 @@ from dimos.perception.object_tracker_2d import ObjectTracker2D from dimos.perception.detection.module2D import Detection2DModule from dimos.perception.detection.person_tracker import PersonTracker +from dimos.perception.detection.reid.module import ReidModule from dimos.navigation.bbox_navigation import BBoxNavigationModule from dimos_lcm.std_msgs import Bool from dimos.robot.robot import UnitreeRobot @@ -420,6 +421,7 @@ def __init__( self.object_tracker = None self.detection_module = None self.person_tracker = None + self.reid_module = None self.utilization_module = None self._setup_directories() @@ -593,6 +595,9 @@ def _deploy_perception(self): Detection2DModule, camera_info=ConnectionModule._camera_info(), max_freq=5 ) + # Deploy ReID module for person identification + self.reid_module = self._dimos.deploy(ReidModule) + # Deploy PersonTracker for person following self.person_tracker = self._dimos.deploy( PersonTracker, @@ -630,13 +635,21 @@ def _deploy_perception(self): "/detected/image/2", Image ) + # Set up transports for reid module + self.reid_module.annotations.transport = core.LCMTransport( + "/reid/annotations", ImageAnnotations + ) + self.reid_module.enriched_detections.transport = core.LCMTransport( + "/reid/detections", Detection2DArray + ) + # Set up transports for person tracker self.person_tracker.target.transport = core.LCMTransport("/person_path", Path) # Set up transports for bbox navigator # self.bbox_navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) - logger.info("Object tracker, detection module, person tracker deployed") + logger.info("Object tracker, detection module, ReID module, person 
tracker deployed") def _deploy_camera(self): """Deploy and configure the camera module.""" @@ -650,12 +663,19 @@ def _deploy_camera(self): self.detection_module.image.connect(self.connection.color_image) logger.info("Detection module connected to camera") - # Connect person tracker inputs + # Connect reid module inputs + if self.reid_module: + self.reid_module.image.connect(self.connection.color_image) + self.reid_module.detections.connect(self.detection_module.detections) + logger.info("ReID module connected to detection module") + + # Connect person tracker inputs - now using enriched detections from ReID if self.person_tracker: self.person_tracker.image.connect(self.connection.color_image) - self.person_tracker.detections.connect(self.detection_module.detections) + # Use enriched detections with ReIDs instead of raw detections + self.person_tracker.detections.connect(self.reid_module.enriched_detections) self.person_tracker.target.connect(self.local_planner.path) - logger.info("Person tracker connected to detection module and local planner") + logger.info("Person tracker connected to ReID module and local planner") # Connect bbox navigator inputs if self.bbox_navigator: diff --git a/pyproject.toml b/pyproject.toml index 7e978fe90..8b5dd3058 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,6 +168,7 @@ cuda = [ # embedding models "open_clip_torch>=3.0.0", "torchreid==0.2.5", + "gdown", # Needed for TorchReID ] dev = [