diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..8236ab1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+pyquaternion==0.9.9
+nuscenes-devkit==1.1.11
+opencv-python==4.8.0.74
+pytorch-lightning==2.4.0
+fvcore==0.1.5.post20221221
+efficientnet_pytorch==0.7.1
+timm==1.0.8
+scikit_image==0.24.0
diff --git a/stp3/metrics.py b/stp3/metrics.py
index 2829f44..a8b0a38 100644
--- a/stp3/metrics.py
+++ b/stp3/metrics.py
@@ -3,9 +3,9 @@ import torch
 import torch.nn as nn
 import numpy as np
 
-from pytorch_lightning.metrics.metric import Metric
-from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
-from pytorch_lightning.metrics.functional.reduction import reduce
+from torchmetrics import Metric
+from torchmetrics.functional import stat_scores
+from torchmetrics.utilities import reduce
 
 from skimage.draw import polygon
 from stp3.utils.tools import gen_dx_bx
@@ -19,30 +19,38 @@ def __init__(
         n_classes: int,
         ignore_index: Optional[int] = None,
         absent_score: float = 0.0,
-        reduction: str = 'none',
-        compute_on_step: bool = False,
+        reduction: str = 'none'
     ):
-        super().__init__(compute_on_step=compute_on_step)
+        super().__init__()
 
         self.n_classes = n_classes
         self.ignore_index = ignore_index
         self.absent_score = absent_score
         self.reduction = reduction
 
+        # Initialize states for the metric computation
         self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
         self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
         self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum')
         self.add_state('support', default=torch.zeros(n_classes), dist_reduce_fx='sum')
 
     def update(self, prediction: torch.Tensor, target: torch.Tensor):
-        tps, fps, _, fns, sups = stat_scores_multiple_classes(prediction, target, self.n_classes)
+        # Calculate per-class statistics; with task="multiclass" and average="none",
+        # stat_scores returns a (num_classes, 5) tensor of [tp, fp, tn, fn, support]
+        stats = stat_scores(
+            preds=prediction, target=target, task="multiclass",
+            num_classes=self.n_classes, average="none"
+        )
+        tps, fps, _, fns, sups = stats.unbind(dim=-1)
+        # Update state variables
         self.true_positive += tps
         self.false_positive += fps
         self.false_negative += fns
         self.support += sups
 
     def compute(self):
+        # Initialize scores tensor
         scores = torch.zeros(self.n_classes, device=self.true_positive.device, dtype=torch.float32)
 
         for class_idx in range(self.n_classes):
@@ -54,20 +62,21 @@ def compute(self):
             fn = self.false_negative[class_idx]
             sup = self.support[class_idx]
 
-            # If this class is absent in the target (no support) AND absent in the pred (no true or false
-            # positives), then use the absent_score for this class.
+            # Assign absent_score if the class is absent in both target and prediction
             if sup + tp + fp == 0:
                 scores[class_idx] = self.absent_score
                 continue
 
+            # Calculate IoU score
             denominator = tp + fp + fn
             score = tp.to(torch.float) / denominator
             scores[class_idx] = score
 
-        # Remove the ignored class index from the scores.
+ # Exclude the ignored class index from scores if (self.ignore_index is not None) and (0 <= self.ignore_index < self.n_classes): - scores = torch.cat([scores[:self.ignore_index], scores[self.ignore_index+1:]]) + scores = torch.cat([scores[:self.ignore_index], scores[self.ignore_index + 1:]]) + # Reduce scores according to the specified reduction method return reduce(scores, reduction=self.reduction) @@ -76,22 +85,22 @@ def __init__( self, n_classes: int, temporally_consistent: bool = True, - vehicles_id: int = 1, - compute_on_step: bool = False, + vehicles_id: int = 1 ): - super().__init__(compute_on_step=compute_on_step) + super().__init__() self.n_classes = n_classes self.temporally_consistent = temporally_consistent self.vehicles_id = vehicles_id self.keys = ['iou', 'true_positive', 'false_positive', 'false_negative'] + # Initialize states for the metric computation self.add_state('iou', default=torch.zeros(n_classes), dist_reduce_fx='sum') self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum') self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum') self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum') - def update(self, pred_instance, gt_instance): + def update(self, pred_instance: torch.Tensor, gt_instance: torch.Tensor): """ Update state with predictions and targets. @@ -133,12 +142,7 @@ def compute(self): sq = self.iou / torch.maximum(self.true_positive, torch.ones_like(self.true_positive)) rq = self.true_positive / denominator - return {'pq': pq, - 'sq': sq, - 'rq': rq, - # If 0, it means there wasn't any detection. - # 'denominator': (self.true_positive + self.false_positive / 2 + self.false_negative / 2), - } + return {'pq': pq, 'sq': sq, 'rq': rq} def panoptic_metrics(self, pred_segmentation, pred_instance, gt_segmentation, gt_instance, unique_id_mapping): """ @@ -163,44 +167,31 @@ def panoptic_metrics(self, pred_segmentation, pred_instance, gt_segmentation, gt n_all_things = n_instances + n_classes # Classes + instances. n_things_and_void = n_all_things + 1 - # Now 1 is background; 0 is void (not used). 2 is vehicle semantic class but since it overlaps with - # instances, it is not present. - # and the rest are instance ids starting from 3 prediction, pred_to_cls = self.combine_mask(pred_segmentation, pred_instance, n_classes, n_all_things) target, target_to_cls = self.combine_mask(gt_segmentation, gt_instance, n_classes, n_all_things) - # Compute ious between all stuff and things - # hack for bincounting 2 arrays together x = prediction + n_things_and_void * target bincount_2d = torch.bincount(x.long(), minlength=n_things_and_void ** 2) if bincount_2d.shape[0] != n_things_and_void ** 2: raise ValueError('Incorrect bincount size.') conf = bincount_2d.reshape((n_things_and_void, n_things_and_void)) - # Drop void class conf = conf[1:, 1:] - # Confusion matrix contains intersections between all combinations of classes union = conf.sum(0).unsqueeze(0) + conf.sum(1).unsqueeze(1) - conf iou = torch.where(union > 0, (conf.float() + 1e-9) / (union.float() + 1e-9), torch.zeros_like(union).float()) - # In the iou matrix, first dimension is target idx, second dimension is pred idx. - # Mapping will contain a tuple that maps prediction idx to target idx for segments matched by iou. mapping = (iou > 0.5).nonzero(as_tuple=False) - # Check that classes match. 
is_matching = pred_to_cls[mapping[:, 1]] == target_to_cls[mapping[:, 0]] mapping = mapping[is_matching] tp_mask = torch.zeros_like(conf, dtype=torch.bool) tp_mask[mapping[:, 0], mapping[:, 1]] = True - # First ids correspond to "stuff" i.e. semantic seg. - # Instance ids are offset accordingly for target_id, pred_id in mapping: cls_id = pred_to_cls[pred_id] if self.temporally_consistent and cls_id == self.vehicles_id: if target_id.item() in unique_id_mapping and unique_id_mapping[target_id.item()] != pred_id.item(): - # Not temporally consistent result['false_negative'][target_to_cls[target_id]] += 1 result['false_positive'][pred_to_cls[pred_id]] += 1 unique_id_mapping[target_id.item()] = pred_id.item() @@ -211,18 +202,14 @@ def panoptic_metrics(self, pred_segmentation, pred_instance, gt_segmentation, gt unique_id_mapping[target_id.item()] = pred_id.item() for target_id in range(n_classes, n_all_things): - # If this is a true positive do nothing. if tp_mask[target_id, n_classes:].any(): continue - # If this target instance didn't match with any predictions and was present set it as false negative. if target_to_cls[target_id] != -1: result['false_negative'][target_to_cls[target_id]] += 1 for pred_id in range(n_classes, n_all_things): - # If this is a true positive do nothing. if tp_mask[n_classes:, pred_id].any(): continue - # If this predicted instance didn't match with any prediction, set that predictions as false positive. if pred_to_cls[pred_id] != -1 and (conf[:, pred_id] > 0).any(): result['false_positive'][pred_to_cls[pred_id]] += 1 @@ -238,9 +225,8 @@ def combine_mask(self, segmentation: torch.Tensor, instance: torch.Tensor, n_cla instance = instance - 1 + n_classes segmentation = segmentation.clone().view(-1) - segmentation_mask = segmentation < n_classes # Remove void pixels. + segmentation_mask = segmentation < n_classes - # Build an index from instance id to class id. instance_id_to_class_tuples = torch.cat( ( instance[instance_mask & segmentation_mask].unsqueeze(1), @@ -255,69 +241,89 @@ def combine_mask(self, segmentation: torch.Tensor, instance: torch.Tensor, n_cla ) segmentation[instance_mask] = instance[instance_mask] - segmentation += 1 # Shift all legit classes by 1. - segmentation[~segmentation_mask] = 0 # Shift void class to zero. 
+        segmentation += 1
+        segmentation[~segmentation_mask] = 0
 
         return segmentation, instance_id_to_class
 
+
 class PlanningMetric(Metric):
     def __init__(
         self,
         cfg,
-        n_future=4,
-        compute_on_step: bool = False,
+        n_future=4
     ):
-        super().__init__(compute_on_step=compute_on_step)
+        super().__init__()
+
+        # Generate grid dx, bx parameters
         dx, bx, _ = gen_dx_bx(cfg.LIFT.X_BOUND, cfg.LIFT.Y_BOUND, cfg.LIFT.Z_BOUND)
         dx, bx = dx[:2], bx[:2]
+
+        # Register as non-trainable nn.Parameters so they move with the module across devices
         self.dx = nn.Parameter(dx, requires_grad=False)
         self.bx = nn.Parameter(bx, requires_grad=False)
 
+        # Calculate bird's eye view dimensions
         _, _, self.bev_dimension = calculate_birds_eye_view_parameters(
             cfg.LIFT.X_BOUND, cfg.LIFT.Y_BOUND, cfg.LIFT.Z_BOUND
         )
         self.bev_dimension = self.bev_dimension.numpy()
 
+        # Ego vehicle dimensions
         self.W = cfg.EGO.WIDTH
         self.H = cfg.EGO.HEIGHT
 
+        # Number of future time steps to evaluate
         self.n_future = n_future
 
+        # Initialize metric states
         self.add_state("obj_col", default=torch.zeros(self.n_future), dist_reduce_fx="sum")
         self.add_state("obj_box_col", default=torch.zeros(self.n_future), dist_reduce_fx="sum")
-        self.add_state("L2", default=torch.zeros(self.n_future),dist_reduce_fx="sum")
+        self.add_state("L2", default=torch.zeros(self.n_future), dist_reduce_fx="sum")
         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
-
     def evaluate_single_coll(self, traj, segmentation):
         '''
-        gt_segmentation
-        traj: torch.Tensor (n_future, 2)
-        segmentation: torch.Tensor (n_future, 200, 200)
+        Evaluate collisions for a single trajectory against the segmentation.
+
+        Parameters:
+            traj: torch.Tensor (n_future, 2)
+            segmentation: torch.Tensor (n_future, 200, 200)
+
+        Returns:
+            collision: torch.Tensor (n_future,) indicating collision at each time step
        '''
+        # Define polygon representing the vehicle's bounding box
         pts = np.array([
             [-self.H / 2. + 0.5, self.W / 2.],
             [self.H / 2. + 0.5, self.W / 2.],
             [self.H / 2. + 0.5, -self.W / 2.],
             [-self.H / 2. + 0.5, -self.W / 2.],
         ])
+
+        # Transform vehicle coordinates into BEV grid coordinates
         pts = (pts - self.bx.cpu().numpy()) / (self.dx.cpu().numpy())
         pts[:, [0, 1]] = pts[:, [1, 0]]
+
+        # Generate polygon in BEV grid
         rr, cc = polygon(pts[:,1], pts[:,0])
         rc = np.concatenate([rr[:,None], cc[:,None]], axis=-1)
 
+        # Adjust trajectory to grid
         n_future, _ = traj.shape
         trajs = traj.view(n_future, 1, 2)
-        trajs[:,:,[0,1]] = trajs[:,:,[1,0]] # can also change original tensor
+        trajs[:,:,[0,1]] = trajs[:,:,[1,0]]  # Swap x, y axes (note: also modifies the original tensor in place)
         trajs = trajs / self.dx
-        trajs = trajs.cpu().numpy() + rc # (n_future, 32, 2)
+        trajs = trajs.cpu().numpy() + rc
 
+        # Clip coordinates to valid range
         r = trajs[:,:,0].astype(np.int32)
         r = np.clip(r, 0, self.bev_dimension[0] - 1)
         c = trajs[:,:,1].astype(np.int32)
         c = np.clip(c, 0, self.bev_dimension[1] - 1)
 
+        # Check collision at each future time step
         collision = np.full(n_future, False)
         for t in range(n_future):
             rr = r[t]
@@ -332,11 +338,20 @@ def evaluate_single_coll(self, traj, segmentation):
 
     def evaluate_coll(self, trajs, gt_trajs, segmentation):
         '''
-        trajs: torch.Tensor (B, n_future, 2)
-        gt_trajs: torch.Tensor (B, n_future, 2)
-        segmentation: torch.Tensor (B, n_future, 200, 200)
+        Evaluate collisions for a batch of trajectories against the segmentation.
+
+        Parameters:
+            trajs: torch.Tensor (B, n_future, 2)
+            gt_trajs: torch.Tensor (B, n_future, 2)
+            segmentation: torch.Tensor (B, n_future, 200, 200)
+
+        Returns:
+            obj_coll_sum: torch.Tensor (n_future,) total collisions with objects
+            obj_box_coll_sum: torch.Tensor (n_future,) total box collisions
         '''
         B, n_future, _ = trajs.shape
+
+        # Adjust trajectories to account for coordinate system differences
 
         trajs = trajs * torch.tensor([-1, 1], device=trajs.device)
         gt_trajs = gt_trajs * torch.tensor([-1, 1], device=gt_trajs.device)
@@ -367,17 +382,25 @@ def evaluate_coll(self, trajs, gt_trajs, segmentation):
 
     def compute_L2(self, trajs, gt_trajs):
         '''
-        trajs: torch.Tensor (B, n_future, 3)
-        gt_trajs: torch.Tensor (B, n_future, 3)
+        Compute L2 distance between predicted and ground truth trajectories.
+
+        Parameters:
+            trajs: torch.Tensor (B, n_future, 3)
+            gt_trajs: torch.Tensor (B, n_future, 3)
+
+        Returns:
+            L2: torch.Tensor (B, n_future) L2 distances at each time step
         '''
-
         return torch.sqrt(((trajs[:, :, :2] - gt_trajs[:, :, :2]) ** 2).sum(dim=-1))
 
     def update(self, trajs, gt_trajs, segmentation):
         '''
-        trajs: torch.Tensor (B, n_future, 3)
-        gt_trajs: torch.Tensor (B, n_future, 3)
-        segmentation: torch.Tensor (B, n_future, 200, 200)
+        Update metric states with a batch of predictions and ground truths.
+
+        Parameters:
+            trajs: torch.Tensor (B, n_future, 3)
+            gt_trajs: torch.Tensor (B, n_future, 3)
+            segmentation: torch.Tensor (B, n_future, 200, 200)
         '''
         assert trajs.shape == gt_trajs.shape
         L2 = self.compute_L2(trajs, gt_trajs)
@@ -386,11 +409,17 @@ def update(self, trajs, gt_trajs, segmentation):
         self.obj_col += obj_coll_sum
         self.obj_box_col += obj_box_coll_sum
         self.L2 += L2.sum(dim=0)
-        self.total +=len(trajs)
+        self.total += len(trajs)
 
     def compute(self):
+        '''
+        Compute final metric results after aggregation.
+
+        Returns:
+            dict with keys 'obj_col', 'obj_box_col', and 'L2'
+        '''
         return {
             'obj_col': self.obj_col / self.total,
             'obj_box_col': self.obj_box_col / self.total,
-            'L2' : self.L2 / self.total
-        }
\ No newline at end of file
+            'L2': self.L2 / self.total
+        }
diff --git a/stp3/models/encoder.py b/stp3/models/encoder.py
index a8c7c48..9ccb2ce 100644
--- a/stp3/models/encoder.py
+++ b/stp3/models/encoder.py
@@ -25,7 +25,7 @@ def __init__(self, cfg, D):
             raise NotImplementedError
         self.upsampling_out_channel = [0, 48, 64, 128, 512]
 
-        index = np.log2(self.downsample).astype(np.int)
+        index = np.log2(self.downsample).astype(int)
 
         if self.use_depth_distribution:
             self.depth_layer_1 = DeepLabHead(self.reduction_channel[index+1], self.reduction_channel[index+1], hidden_channel=64)
@@ -81,7 +81,7 @@ def get_features_depth(self, x):
         # Head
         endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
 
-        index = np.log2(self.downsample).astype(np.int)
+        index = np.log2(self.downsample).astype(int)
 
         input_1 = endpoints['reduction_{}'.format(index + 1)]
         input_2 = endpoints['reduction_{}'.format(index)]
diff --git a/stp3/trainer.py b/stp3/trainer.py
index 0464bbd..0ded577 100644
--- a/stp3/trainer.py
+++ b/stp3/trainer.py
@@ -16,7 +16,7 @@ def __init__(self, hparams):
         super().__init__()
 
         # see config.py for details
-        self.hparams = hparams
+        self.save_hyperparameters(hparams)  # pytorch lightning does not support saving YACS CfgNode
 
         cfg = get_cfg(cfg_dict=self.hparams)
         self.cfg = cfg
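
Note: as a quick sanity check of the torchmetrics call introduced in the update() hunk of stp3/metrics.py above, a snippet like the following can be run. This is a minimal sketch; the tensor shapes and the three-class setup are illustrative assumptions, and only the stat_scores arguments mirror the diff. With task="multiclass" and average="none", stat_scores returns a (num_classes, 5) tensor whose columns are [tp, fp, tn, fn, support], so unbinding the last dimension recovers the per-class values that pytorch_lightning's stat_scores_multiple_classes used to return as a tuple.

    import torch
    from torchmetrics.functional import stat_scores

    preds = torch.randint(0, 3, (2, 200, 200))   # dummy argmaxed BEV predictions (class ids), illustrative shape
    target = torch.randint(0, 3, (2, 200, 200))  # dummy BEV ground-truth labels, illustrative shape
    stats = stat_scores(preds=preds, target=target, task="multiclass",
                        num_classes=3, average="none")
    tps, fps, _, fns, sups = stats.unbind(dim=-1)  # per-class tp, fp, tn, fn, support
    print(tps.shape, sups)  # torch.Size([3]) and per-class pixel counts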