Skip to content

Commit 0d349c5

Browse files
committed
Add FiftyOneMistakenness callback
1 parent 95bba3e commit 0d349c5

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

mart/callbacks/fiftyone.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
logger = logging.getLogger(__name__)
1515
try:
1616
import fiftyone as fo
17+
import fiftyone.brain as fob
1718
except ImportError:
1819
logger.debug("fiftyone module is not installed!")
1920

20-
__all__ = ["FiftyOneEvaluateDetections", "FiftyOnePredictionAdder"]
21+
__all__ = ["FiftyOneEvaluateDetections", "FiftyOneMistakenness", "FiftyOnePredictionAdder"]
2122

2223

2324
class FiftyOneEvaluateDetections(Callback):
@@ -48,6 +49,25 @@ def on_predict_end(self, trainer, pl_module):
4849
results.print_report(classes=classes_top10)
4950

5051

52+
class FiftyOneMistakenness(Callback):
53+
def __init__(self, run_id: str, gt_field: str = "ground_truth_detections") -> None:
54+
self.prediction_field = f"prediction_{run_id}"
55+
self.gt_field = gt_field
56+
57+
def on_predict_start(self, trainer, pl_module):
58+
self.predict_dataset = trainer.datamodule.predict_dataset
59+
assert isinstance(self.predict_dataset, FiftyOneDataset)
60+
61+
# reset mistakenness fields
62+
if self.predict_dataset.dataset.has_brain_run("mistakenness"):
63+
self.predict_dataset.dataset.delete_brain_run("mistakenness")
64+
65+
def on_predict_end(self, trainer, pl_module):
66+
fob.compute_mistakenness(
67+
self.predict_dataset.filtered_dataset, self.prediction_field, label_field=self.gt_field
68+
)
69+
70+
5171
class FiftyOnePredictionAdder(BasePredictionWriter):
5272
def __init__(self, output_dir: str, write_interval: List[str]) -> None:
5373
super().__init__(write_interval)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
fiftyone_mistakenness:
2+
_target_: mart.callbacks.FiftyOneMistakenness
3+
run_id: ${now:%Y-%m-%d}_${now:%H-%M-%S}

mart/configs/experiment/FiftyOne_TorchvisionFasterRCNN.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ defaults:
99
lr_monitor,
1010
fiftyone_prediction_adder,
1111
fiftyone_evaluate_detections,
12+
fiftyone_mistakenness,
1213
]
1314

1415
task_name: "FiftyOne_TorchvisionFasterRCNN"

0 commit comments

Comments
 (0)