|
14 | 14 | logger = logging.getLogger(__name__)
|
15 | 15 | try:
|
16 | 16 | import fiftyone as fo
|
| 17 | + import fiftyone.brain as fob |
17 | 18 | except ImportError:
|
18 | 19 | logger.debug("fiftyone module is not installed!")
|
19 | 20 |
|
20 |
| -__all__ = ["FiftyOneEvaluateDetections", "FiftyOnePredictionAdder"] |
| 21 | +__all__ = ["FiftyOneEvaluateDetections", "FiftyOneMistakenness", "FiftyOnePredictionAdder"] |
21 | 22 |
|
22 | 23 |
|
23 | 24 | class FiftyOneEvaluateDetections(Callback):
|
@@ -48,6 +49,25 @@ def on_predict_end(self, trainer, pl_module):
|
48 | 49 | results.print_report(classes=classes_top10)
|
49 | 50 |
|
50 | 51 |
|
| 52 | +class FiftyOneMistakenness(Callback): |
| 53 | + def __init__(self, run_id: str, gt_field: str = "ground_truth_detections") -> None: |
| 54 | + self.prediction_field = f"prediction_{run_id}" |
| 55 | + self.gt_field = gt_field |
| 56 | + |
| 57 | + def on_predict_start(self, trainer, pl_module): |
| 58 | + self.predict_dataset = trainer.datamodule.predict_dataset |
| 59 | + assert isinstance(self.predict_dataset, FiftyOneDataset) |
| 60 | + |
| 61 | + # reset mistakenness fields |
| 62 | + if self.predict_dataset.dataset.has_brain_run("mistakenness"): |
| 63 | + self.predict_dataset.dataset.delete_brain_run("mistakenness") |
| 64 | + |
| 65 | + def on_predict_end(self, trainer, pl_module): |
| 66 | + fob.compute_mistakenness( |
| 67 | + self.predict_dataset.filtered_dataset, self.prediction_field, label_field=self.gt_field |
| 68 | + ) |
| 69 | + |
| 70 | + |
51 | 71 | class FiftyOnePredictionAdder(BasePredictionWriter):
|
52 | 72 | def __init__(self, output_dir: str, write_interval: List[str]) -> None:
|
53 | 73 | super().__init__(write_interval)
|
|
0 commit comments