46 changes: 46 additions & 0 deletions dimos/agents2/skills/navigation.py
@@ -275,6 +275,52 @@ def follow_person(self, person_description: str = "person", continuous: bool = T
# Always stop tracking
self._robot.person_tracker.stop_tracking()

@skill()
def follow_specific_person(self, person_id: int, continuous: bool = True) -> str:
"""Follow a specific person by their ReID.

Args:
person_id: The ReID of the person to track
continuous: If True, follow continuously without checking arrival.
If False, stop when person is reached (default: True)

Returns:
Status message indicating success or failure
"""
if not self._started:
raise ValueError(f"{self} has not been started.")

# Check for required modules
if not hasattr(self._robot, "person_tracker"):
return "Person tracker not available on this robot"

if not hasattr(self._robot, "reid_module"):
return "ReID module not available on this robot"

logger.info(f"Starting tracking of specific person ID {person_id}")

# Start tracking with the specific person ID
self._robot.person_tracker.start_tracking(continuous=continuous, target_person_id=person_id)

try:
start_time = time.time()
timeout = 60.0 # 60 second timeout

while time.time() - start_time < timeout:
# Check if tracking stopped (person reached or lost)
if not self._robot.person_tracker.is_tracking():
logger.info(f"Tracking stopped for person ID {person_id}")
return f"Finished tracking person ID {person_id}"

time.sleep(0.25)

logger.warning(f"Following person ID {person_id} timed out after {timeout}s")
return f"Timeout while following person ID {person_id}"

finally:
# Always stop tracking
self._robot.person_tracker.stop_tracking()

@skill()
def stop_human_tracking(self) -> str:
"""Stop tracking and following a person.
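As a rough usage sketch (not part of this PR): nav below stands for the started skills instance defined in dimos/agents2/skills/navigation.py, whose exact class name is not visible in this diff.

# Sketch only; assumes nav is the skills object bound to a robot whose perception
# stack (detection, ReID, person tracker) is wired as in this PR.
result = nav.follow_specific_person(person_id=3, continuous=False)  # 3 is a ReID long-term ID
print(result)  # e.g. "Finished tracking person ID 3" or a timeout / "not available" message

# Stop following explicitly if needed.
nav.stop_human_tracking()
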
8 changes: 4 additions & 4 deletions dimos/models/embedding/treid.py
@@ -46,10 +46,10 @@ def __init__(
device: Device to run on (cuda/cpu), auto-detects if None
normalize: Whether to L2 normalize embeddings
"""
if not TORCHREID_AVAILABLE:
raise ImportError(
"torchreid is required for TorchReIDModel. Install it with: pip install torchreid"
)
# if not TORCHREID_AVAILABLE:
# raise ImportError(
# "torchreid is required for TorchReIDModel. Install it with: pip install torchreid"
# )

self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.normalize = normalize
66 changes: 57 additions & 9 deletions dimos/perception/detection/person_tracker.py
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Tuple
from typing import Optional, Tuple

from reactivex import operators as ops
from reactivex.observable import Observable
@@ -41,13 +41,20 @@ class PersonTracker(Module):

camera_info: CameraInfo

def __init__(self, cameraInfo: CameraInfo, arrival_threshold: float = 0.7, **kwargs):
def __init__(
self,
cameraInfo: CameraInfo,
arrival_threshold: float = 0.7,
target_person_id: Optional[int] = None,
**kwargs,
):
super().__init__(**kwargs)
self.camera_info = cameraInfo
self.tf = TF()
self._sub = None
self._is_tracking = False
self._continuous = True
self._target_person_id = target_person_id # Specific person ReID to track
self._arrival_threshold = arrival_threshold # bbox bottom must be in bottom 30% of frame

def center_to_3d(
@@ -119,18 +126,23 @@ def check_arrival(self, detection: Detection2DBBox) -> bool:
return is_arrived

@skill()
def start_tracking(self, continuous: bool = True):
def start_tracking(self, continuous: bool = True, target_person_id: Optional[int] = None):
"""Start person tracking.

Args:
continuous: If True, follow continuously without checking arrival.
If False, stop when person is reached (default: True)
target_person_id: Optional ReID of specific person to track.
If None, tracks the largest detected person.
"""
if not self._is_tracking:
self._continuous = continuous
self._target_person_id = target_person_id
self._sub = self.detections_stream().subscribe(self.track)
self._is_tracking = True
logger.info(f"PersonTracker: Tracking started (continuous={continuous})")
logger.info(
f"PersonTracker: Tracking started (continuous={continuous}, target_id={target_person_id})"
)
return "Person tracking started"

@skill()
@@ -147,6 +159,24 @@ def stop_tracking(self):
logger.info("PersonTracker: Tracking stopped")
return "Person tracking stopped"

@skill()
def set_target_person(self, person_id: Optional[int] = None):
"""Set the target person to track by ReID.

Args:
person_id: ReID of person to track, or None to track any person.

Returns:
Status message
"""
self._target_person_id = person_id
if person_id is not None:
logger.info(f"PersonTracker: Target set to person ID {person_id}")
return f"Now tracking person ID {person_id}"
else:
logger.info("PersonTracker: Target cleared, tracking any person")
return "Now tracking any visible person"

@rpc
def stop(self):
super().stop()
@@ -165,11 +195,29 @@ def track(self, detections2D: ImageDetections2D):
logger.warning("PersonTracker: No detections, skipping")
return

target = max(detections2D.detections, key=lambda det: det.bbox_2d_volume())
logger.info(
f"PersonTracker: Selected target person - center={target.center_bbox}, "
f"bbox_volume={target.bbox_2d_volume():.1f}px"
)
# Filter by target_person_id if specified
if self._target_person_id is not None:
# Filter detections by ReID (now in track_id field from enriched stream)
valid_detections = [
det for det in detections2D.detections if det.track_id == self._target_person_id
]

if not valid_detections:
logger.info(f"Target person {self._target_person_id} not in view")
return

target = valid_detections[0]
logger.info(
f"PersonTracker: Tracking specific person ID {self._target_person_id} - "
f"center={target.center_bbox}, bbox_volume={target.bbox_2d_volume():.1f}px"
)
else:
# Default behavior - track largest detection
target = max(detections2D.detections, key=lambda det: det.bbox_2d_volume())
logger.info(
f"PersonTracker: Selected largest person - center={target.center_bbox}, "
f"bbox_volume={target.bbox_2d_volume():.1f}px, track_id={target.track_id}"
)

if not self._continuous and self.check_arrival(target):
logger.info("Person reached, stopping tracker")
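For the tracker itself, a minimal sketch of the extended API (tracker stands in for a deployed PersonTracker instance; deployment and stream wiring are omitted):

# Sketch: tracker is a deployed PersonTracker whose detections input receives the
# ReID-enriched stream, so track_id carries long-term person IDs.
tracker.start_tracking(continuous=True, target_person_id=7)  # follow only person 7

# Retarget on the fly without tearing down the subscription.
tracker.set_target_person(person_id=2)

# Clear the target to fall back to the largest-detection behaviour.
tracker.set_target_person(person_id=None)

tracker.stop_tracking()
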
55 changes: 35 additions & 20 deletions dimos/perception/detection/reid/module.py
@@ -29,6 +29,9 @@
from dimos.perception.detection.type import ImageDetections2D
from dimos.types.timestamped import align_timestamped, to_ros_stamp
from dimos.utils.reactive import backpressure
from dimos.utils.logging_config import setup_logger

logger = setup_logger(__name__)


class Config(ModuleConfig):
@@ -41,8 +44,14 @@ class ReidModule(Module):
detections: In[Detection2DArray] = None # type: ignore
image: In[Image] = None # type: ignore
annotations: Out[ImageAnnotations] = None # type: ignore
enriched_detections: Out[Detection2DArray] = None # type: ignore

def __init__(self, idsystem: IDSystem | None = None, **kwargs):
"""Initialize ReID module.

Args:
idsystem: ID system for tracking. Defaults to EmbeddingIDSystem with TorchReIDModel.
"""
super().__init__(**kwargs)
if idsystem is None:
try:
@@ -52,11 +61,13 @@ def __init__(self, idsystem: IDSystem | None = None, **kwargs):
except Exception as e:
raise RuntimeError(
"TorchReIDModel not available. Please install with: pip install dimos[torchreid]"
f"\n\nERROR: {e}"
) from e

self.idsystem = idsystem

def detections_stream(self) -> Observable[ImageDetections2D]:
"""Stream aligned image detections."""
return backpressure(
align_timestamped(
self.image.pure_observable(),
@@ -79,30 +90,34 @@ def stop(self):
def ingress(self, imageDetections: ImageDetections2D):
text_annotations = []

for detection in imageDetections:
# Register detection and get long-term ID
track_ids_in_frame = [det.track_id for det in imageDetections.detections]
if len(track_ids_in_frame) > 1:
self.idsystem.add_negative_constraints(track_ids_in_frame)

for detection in imageDetections.detections:
long_term_id = self.idsystem.register_detection(detection)

# Skip annotation if not ready yet (long_term_id == -1)
if long_term_id == -1:
continue

# Create text annotation for long_term_id above the detection
x1, y1, _, _ = detection.bbox
font_size = imageDetections.image.width / 60

text_annotations.append(
TextAnnotation(
timestamp=to_ros_stamp(detection.ts),
position=Point2(x=x1, y=y1 - font_size * 1.5),
text=f"PERSON: {long_term_id}",
font_size=font_size,
text_color=Color(r=0.0, g=1.0, b=1.0, a=1.0), # Cyan
background_color=Color(r=0.0, g=0.0, b=0.0, a=0.8),
# Override track_id with ReID for downstream processing
if long_term_id != -1:
detection.track_id = long_term_id

x1, y1, _, _ = detection.bbox
font_size = imageDetections.image.width / 60

text_annotations.append(
TextAnnotation(
timestamp=to_ros_stamp(detection.ts),
position=Point2(x=x1, y=y1 - font_size * 1.5),
text=f"PERSON: {long_term_id}",
font_size=font_size,
text_color=Color(r=0.0, g=1.0, b=1.0, a=1.0), # Cyan
background_color=Color(r=0.0, g=0.0, b=0.0, a=0.8),
)
)
)

# Publish annotations (even if empty to clear previous annotations)
if self.enriched_detections:
self.enriched_detections.publish(imageDetections.to_ros_detection2d_array())

annotations = ImageAnnotations(
texts=text_annotations,
texts_length=len(text_annotations),
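The add_negative_constraints call encodes the assumption that tracks co-occurring in one frame belong to different people, which keeps the ID system from merging them later. A toy illustration of that idea (not the actual IDSystem/EmbeddingIDSystem implementation; names and storage are invented for this sketch):

# Toy sketch of per-frame negative constraints.
from itertools import combinations

cannot_match: set[tuple[int, int]] = set()

def add_negative_constraints(track_ids: list[int]) -> None:
    # Every pair of track IDs seen in the same frame is marked "never the same person".
    for a, b in combinations(sorted(set(track_ids)), 2):
        cannot_match.add((a, b))

add_negative_constraints([4, 7, 9])   # three people visible at once
assert (4, 7) in cannot_match         # 4 and 7 must never be merged into one identity
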
28 changes: 24 additions & 4 deletions dimos/robot/unitree_webrtc/unitree_go2.py
@@ -67,6 +67,7 @@
from dimos.perception.object_tracker_2d import ObjectTracker2D
from dimos.perception.detection.module2D import Detection2DModule
from dimos.perception.detection.person_tracker import PersonTracker
from dimos.perception.detection.reid.module import ReidModule
from dimos.navigation.bbox_navigation import BBoxNavigationModule
from dimos_lcm.std_msgs import Bool
from dimos.robot.robot import UnitreeRobot
@@ -420,6 +421,7 @@ def __init__(
self.object_tracker = None
self.detection_module = None
self.person_tracker = None
self.reid_module = None
self.utilization_module = None

self._setup_directories()
@@ -593,6 +595,9 @@ def _deploy_perception(self):
Detection2DModule, camera_info=ConnectionModule._camera_info(), max_freq=5
)

# Deploy ReID module for person identification
self.reid_module = self._dimos.deploy(ReidModule)

# Deploy PersonTracker for person following
self.person_tracker = self._dimos.deploy(
PersonTracker,
@@ -630,13 +635,21 @@ def _deploy_perception(self):
"/detected/image/2", Image
)

# Set up transports for reid module
self.reid_module.annotations.transport = core.LCMTransport(
"/reid/annotations", ImageAnnotations
)
self.reid_module.enriched_detections.transport = core.LCMTransport(
"/reid/detections", Detection2DArray
)

# Set up transports for person tracker
self.person_tracker.target.transport = core.LCMTransport("/person_path", Path)

# Set up transports for bbox navigator
# self.bbox_navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped)

logger.info("Object tracker, detection module, person tracker deployed")
logger.info("Object tracker, detection module, ReID module, person tracker deployed")

def _deploy_camera(self):
"""Deploy and configure the camera module."""
@@ -650,12 +663,19 @@ def _deploy_camera(self):
self.detection_module.image.connect(self.connection.color_image)
logger.info("Detection module connected to camera")

# Connect person tracker inputs
# Connect reid module inputs
if self.reid_module:
self.reid_module.image.connect(self.connection.color_image)
self.reid_module.detections.connect(self.detection_module.detections)
logger.info("ReID module connected to detection module")

# Connect person tracker inputs - now using enriched detections from ReID
if self.person_tracker:
self.person_tracker.image.connect(self.connection.color_image)
self.person_tracker.detections.connect(self.detection_module.detections)
# Use enriched detections with ReIDs instead of raw detections
self.person_tracker.detections.connect(self.reid_module.enriched_detections)
self.person_tracker.target.connect(self.local_planner.path)
logger.info("Person tracker connected to detection module and local planner")
logger.info("Person tracker connected to ReID module and local planner")

# Connect bbox navigator inputs
if self.bbox_navigator:
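Summarising the new wiring as a sketch (the connect calls mirror the diff above; dropping the self. prefixes and the surrounding deployment boilerplate is for illustration only):

# Data flow after this change:
#   camera color_image -> Detection2DModule -> ReidModule -> PersonTracker -> local planner path
#   (ReidModule consumes both the raw image and the raw detections.)
reid_module.image.connect(connection.color_image)
reid_module.detections.connect(detection_module.detections)
person_tracker.image.connect(connection.color_image)
person_tracker.detections.connect(reid_module.enriched_detections)  # track_id now holds the ReID
person_tracker.target.connect(local_planner.path)
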
1 change: 1 addition & 0 deletions pyproject.toml
@@ -168,6 +168,7 @@ cuda = [
# embedding models
"open_clip_torch>=3.0.0",
"torchreid==0.2.5",
"gdown" # Needed for TorchReID
]

dev = [