Change datamodule input to [0,1] scale #203

Open
wants to merge 103 commits into base: main
103 commits
5b46ddf
Initial implementation of adversarial training callback
dxoigmn Jun 28, 2023
484876c
Merge branch 'main' into adversarial_training_as_callback
dxoigmn Jun 29, 2023
849f071
Make callback work in CPU
dxoigmn Jun 29, 2023
feffbe3
style
dxoigmn Jun 29, 2023
5e6c272
Add and use on_after_batch_transfer hook
dxoigmn Jun 29, 2023
59003c2
Override on_after_batch_transfer in callback setup
dxoigmn Jun 30, 2023
8218f65
Restore on_after_batch_transfer hook in teardown
dxoigmn Jun 30, 2023
226664d
cleanup
dxoigmn Jun 30, 2023
d1187c9
Call original on_after_batch_transfer hook
dxoigmn Jun 30, 2023
9850a78
Move Adversary.attack into callback
dxoigmn Jun 30, 2023
6406485
Merge branch 'main' into generalized_adversary
mzweilin Jul 14, 2023
d71cb43
Use training sequence by default.
mzweilin Jul 14, 2023
83f017e
Assume a simple model interface in adversary.
mzweilin Jul 14, 2023
88a807b
Update configs.
mzweilin Jul 14, 2023
9dfd7f6
Log original gain on progress bar.
mzweilin Jul 14, 2023
2cd7ca8
Hide model in Adversary so that it's not tampered.
mzweilin Jul 14, 2023
e642474
Skip adversary if not defined.
mzweilin Jul 14, 2023
f222aef
Update adversarial experiment on COCO.
mzweilin Jul 14, 2023
1d9b182
Fix test.
mzweilin Jul 14, 2023
45b5bf0
Hide model in adversary's batch.
mzweilin Jul 14, 2023
9a2531f
Fix test.
mzweilin Jul 14, 2023
c5bf847
Revert "Make input a dictionary for multi-modal object detection (#95)"
mzweilin Jul 14, 2023
51720da
[0,255] -> [0,1]
mzweilin Jul 14, 2023
756e0f8
Add transform in untransform in Adversary.
mzweilin Jul 14, 2023
33cbe72
Make transform/untransform optional arguments.
mzweilin Jul 14, 2023
d49f411
Add tuple transform/untransform.
mzweilin Jul 14, 2023
9bd4c1f
Comment.
mzweilin Jul 15, 2023
6d03271
Fix image visualizer for object detection.
mzweilin Jul 15, 2023
b30a34f
Switch to [0,1] data in tests.
mzweilin Jul 15, 2023
6feab79
Merge branch 'main' into adversarial_training_as_callback
mzweilin Jul 15, 2023
6250195
Merge branch 'adversarial_training_as_callback' into generalized_adve…
mzweilin Jul 15, 2023
ca23e1a
Merge branch 'generalized_adversary' into range01
mzweilin Jul 15, 2023
0356993
Merge branch 'main' into adversarial_training_as_callback
mzweilin Jul 15, 2023
39ee397
Merge branch 'adversarial_training_as_callback' into generalized_adve…
mzweilin Jul 15, 2023
9042195
Merge branch 'generalized_adversary' into range01
mzweilin Jul 15, 2023
704fc01
Hide transform in untransform in this PR.
mzweilin Jul 15, 2023
7803c9b
Merge branch 'generalized_adversary' into range01
mzweilin Jul 15, 2023
5bb36e7
Revert changes in gain logging.
mzweilin Jul 15, 2023
a530ee3
Merge branch 'generalized_adversary' into range01
mzweilin Jul 15, 2023
7f56140
Fix normalization params for timm.
mzweilin Jul 17, 2023
c85182d
Add LitModular.attack_step() for adversary to run model.forward() to …
mzweilin Jul 18, 2023
c3aeb8d
Decouple untransform() from get_input_adv().
mzweilin Jul 18, 2023
7aa91ed
Merge branch 'generalized_adversary' into range01
mzweilin Jul 18, 2023
ad74a35
Move BatchConverter to Adversary; move transform/untransform to Batch…
mzweilin Jul 18, 2023
558fc7b
Add batch_converter in Adversary that supports transform/untransform.
mzweilin Jul 19, 2023
7ed32ac
Clean up.
mzweilin Jul 19, 2023
170fb85
Merge branch 'generalized_adversary' into range01
mzweilin Jul 19, 2023
744ebfd
Fix adv. visualizer.
mzweilin Jul 19, 2023
6e43496
Fix adv. visualizer.
mzweilin Jul 19, 2023
254e168
Fix visualizer test.
mzweilin Jul 19, 2023
33a977f
Merge branch 'generalized_adversary' into range01
mzweilin Jul 19, 2023
f4f7b16
Clean up.
mzweilin Jul 19, 2023
265886f
Fix test.
mzweilin Jul 19, 2023
46b5cb7
Make adversary optional in some phases.
mzweilin Jul 19, 2023
197016d
Merge branch 'main' into generalized_adversary
mzweilin Jul 21, 2023
18268f4
Merge branch 'generalized_adversary' into range01
mzweilin Jul 21, 2023
342fef2
Add target_transform in target_untransform.
mzweilin Jul 24, 2023
ea7cfce
Rename variable: target -> target_transformed
mzweilin Jul 27, 2023
49c7da7
Merge branch 'generalized_adversary' into range01
mzweilin Jul 27, 2023
bae2345
Fix monkey patch in adv_training callback.
mzweilin Jul 28, 2023
9f1d35d
Fix DictBatchConverter.
mzweilin Jul 28, 2023
2a8505b
Remove recursive adversarial training callback in Adversary.
mzweilin Jul 28, 2023
5d81324
Copy batch before transform().
mzweilin Jul 28, 2023
a562277
Merge branch 'generalized_adversary' into range01
mzweilin Jul 28, 2023
77de175
Copy original batch in batch_converter(), because tuple batch cannot …
mzweilin Jul 28, 2023
fb79aa8
Copy original batch in batch_converter(), because tuple batch cannot …
mzweilin Jul 28, 2023
d9d8474
Merge branch 'generalized_adversary' into range01
mzweilin Jul 28, 2023
0f51d80
Upgrade dependency to torchmetrics == 1.0.1 (#205)
mzweilin Jul 20, 2023
d068902
Move adversary out of the model sequence.
mzweilin Aug 1, 2023
e81a55a
Make an adversarial training/evaluation callback.
mzweilin Aug 1, 2023
1c716f1
Merge branch 'main' into adversarial_training_as_callback
mzweilin Aug 1, 2023
67544da
Merge branch 'adversarial_training_as_callback' into finish_adv_train…
mzweilin Aug 1, 2023
579d653
Remove stuff that is related to callback entry points.
mzweilin Aug 3, 2023
92f9fe2
Replace model wrapper with a configurable model_transform.
mzweilin Aug 7, 2023
44c9c4b
Add Adversary.batch_converter().
mzweilin Aug 7, 2023
87d41d6
Comment.
mzweilin Aug 7, 2023
878b6f0
Merge branch 'finish_adv_training_callback' into add_adv_batch_converter
mzweilin Aug 7, 2023
fe78536
Rename as train/val/test.
mzweilin Aug 24, 2023
c9a9900
Rename the callback to AdversaryConnector because we may not perform …
mzweilin Aug 24, 2023
79c127a
Rename config: adversarial_training -> adversary_connector.
mzweilin Aug 24, 2023
cabc413
Merge branch 'main' into finish_adv_training_callback
mzweilin Aug 24, 2023
90808d5
Merge branch 'finish_adv_training_callback' into add_adv_batch_converter
mzweilin Aug 24, 2023
59c3d66
Update comments.
mzweilin Aug 24, 2023
3e3bc2f
Merge branch 'main' into add_adv_batch_converter
mzweilin Aug 25, 2023
08e3df1
Remove model_transform, again.
mzweilin Aug 25, 2023
597a08e
Comment.
mzweilin Aug 25, 2023
68ec00f
Comments.
mzweilin Aug 25, 2023
8b1f953
Rename batch_converter as batch_c15n.
mzweilin Aug 28, 2023
df1734f
Replace the _transformed suffix with the _orig suffix, because _orig …
mzweilin Aug 28, 2023
35d3652
Merge branch 'add_adv_batch_converter' into range01
mzweilin Aug 28, 2023
2559665
Rename transforms in batch_c15n/transform.
mzweilin Aug 28, 2023
fb09575
Comment.
mzweilin Aug 28, 2023
0d34797
Merge branch 'main' into range01
mzweilin Aug 31, 2023
ad4b2d5
Fix configs.
mzweilin Aug 31, 2023
79608ff
Remove preprocessor in torchvision_object_detection.
mzweilin Aug 31, 2023
9ce1408
Merge branch 'main' into range01
mzweilin Sep 1, 2023
d09524a
Move configs/attack/batch_c15n to configs/batch_c15n
mzweilin Sep 1, 2023
2b0ef2b
Rename config files.
mzweilin Sep 1, 2023
d37a414
Update fiftyone configs.
mzweilin Sep 1, 2023
09f58f7
Merge branch 'main' into range01
mzweilin Jan 8, 2024
b67b071
Mock Adversary.forward().
mzweilin Jan 8, 2024
6a1eca0
Upgrade pre-commit flake8 to 6.0.0 to work with newer importlib_metad…
mzweilin Jan 8, 2024
9cfd739
Merge branch 'main' into range01
mzweilin Jan 11, 2024
5 changes: 5 additions & 0 deletions mart/configs/batch_c15n/input_tensor_float01.yaml
@@ -0,0 +1,5 @@
# Scale pixel input in the range of [0,1] to [0,255] for Adversary.
defaults:
- tuple
- transform: times_255_and_round
- transform@untransform: divided_by_255
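
This config pairs a transform (scale to [0, 255] and round) with an untransform (divide by 255) around the attack. A minimal sketch of that round trip, using hypothetical names and a simplified call signature rather than MART's actual batch_c15n API:

import torch

# Sketch only: a canonicalizer that lifts [0, 1] input into the attack's
# [0, 255] domain and converts the adversarial result back to [0, 1].
class TensorBatchC15n:
    def __init__(self, transform, untransform):
        self.transform = transform      # e.g. x -> (x * 255).round()
        self.untransform = untransform  # e.g. x -> x / 255

    def attack_in_255(self, attack, input_01: torch.Tensor) -> torch.Tensor:
        input_255 = self.transform(input_01)   # attack perturbs integer-scale pixels
        adv_255 = attack(input_255)
        return self.untransform(adv_255)       # model keeps seeing [0, 1] input

# Hypothetical usage with lambdas standing in for the configured transforms.
c15n = TensorBatchC15n(lambda x: (x * 255).round(), lambda x: x / 255)
adv = c15n.attack_in_255(lambda x: (x + 8).clamp(0, 255), torch.rand(3, 32, 32))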
11 changes: 11 additions & 0 deletions mart/configs/batch_c15n/input_tuple_float01.yaml
@@ -0,0 +1,11 @@
# Scale pixel input in the range of [0,1] to [0,255] for Adversary.
defaults:
- tuple
- transform@transform.transforms: times_255_and_round
- transform@untransform.transforms: divided_by_255

transform:
_target_: mart.transforms.TupleTransforms

untransform:
_target_: mart.transforms.TupleTransforms
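
The tuple variant wraps the same element-wise transforms in mart.transforms.TupleTransforms so they apply to every image of a tuple batch (as used for object detection). A rough sketch of what such a wrapper does; the real class may differ in details:

from typing import Callable, Tuple
import torch

class TupleTransforms:
    """Apply one tensor transform to every element of a tuple batch (sketch)."""

    def __init__(self, transforms: Callable[[torch.Tensor], torch.Tensor]):
        self.transforms = transforms

    def __call__(self, inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        return tuple(self.transforms(x) for x in inputs)

# Example: scale a tuple of [0, 1] detection images to [0, 255].
to_255 = TupleTransforms(lambda x: (x * 255).round())
batch_255 = to_255((torch.rand(3, 480, 640), torch.rand(3, 600, 800)))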
3 changes: 3 additions & 0 deletions mart/configs/batch_c15n/transform/divided_by_255.yaml
@@ -0,0 +1,3 @@
_target_: torchvision.transforms.Normalize
mean: 0
std: 255
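
Normalize(mean=0, std=255) computes (x - mean) / std, so this config is simply a division by 255. A quick check of that behavior:

import torch
from torchvision import transforms

divide_by_255 = transforms.Normalize(mean=0, std=255)
x = torch.full((3, 4, 4), 255.0)
assert torch.allclose(divide_by_255(x), torch.ones(3, 4, 4))  # (x - 0) / 255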
13 changes: 13 additions & 0 deletions mart/configs/batch_c15n/transform/times_255_and_round.yaml
@@ -0,0 +1,13 @@
_target_: torchvision.transforms.Compose
transforms:
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
# Fix potential numeric error.
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255
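
With scale=1, zero_point=0, quant_min=0, quant_max=255, fake_quantize_per_tensor_affine reduces to the round-and-clamp noted in the comment, snapping the rescaled input onto the integer pixel grid while remaining usable inside an autograd graph. A small numeric check:

import torch

x = torch.tensor([0.2, 254.7, 300.0, -1.5])
# Arguments: input, scale, zero_point, quant_min, quant_max.
q = torch.fake_quantize_per_tensor_affine(x, 1.0, 0, 0, 255)
# Equivalent to (x / 1 + 0).round().clamp(0, 255) * 1.
assert torch.equal(q, x.round().clamp(0, 255))  # both give tensor([0., 255., 255., 0.])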
20 changes: 0 additions & 20 deletions mart/configs/datamodule/cifar10.yaml
@@ -17,16 +17,6 @@ train_dataset:
padding_mode: reflect
- _target_: torchvision.transforms.RandomHorizontalFlip
- _target_: torchvision.transforms.ToTensor
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255
target_transform: null
download: true

@@ -38,16 +28,6 @@ val_dataset:
_target_: torchvision.transforms.Compose
transforms:
- _target_: torchvision.transforms.ToTensor
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255
target_transform: null
download: true

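The Denormalize/fake-quantize tail is dropped from every dataset transform because torchvision's ToTensor already yields float tensors in [0, 1]; the datamodule now passes those through unchanged. For example:

from PIL import Image
from torchvision import transforms

# ToTensor converts a uint8 PIL image to float32 in [0, 1], which is now the
# range the model and the batch_c15n configs expect from the datamodule.
img = Image.new("RGB", (32, 32), color=(255, 128, 0))
x = transforms.ToTensor()(img)
print(x.dtype, float(x.min()), float(x.max()))  # torch.float32 0.0 1.0
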
30 changes: 0 additions & 30 deletions mart/configs/datamodule/coco.yaml
@@ -12,16 +12,6 @@ train_dataset:
- _target_: mart.transforms.ConvertCocoPolysToMask
- _target_: mart.transforms.RandomHorizontalFlip
p: 0.5
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

val_dataset:
_target_: mart.datamodules.coco.CocoDetection
@@ -32,16 +22,6 @@ val_dataset:
transforms:
- _target_: torchvision.transforms.ToTensor
- _target_: mart.transforms.ConvertCocoPolysToMask
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

test_dataset:
_target_: mart.datamodules.coco.CocoDetection
@@ -52,16 +32,6 @@ test_dataset:
transforms:
- _target_: torchvision.transforms.ToTensor
- _target_: mart.transforms.ConvertCocoPolysToMask
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

num_workers: 2
collate_fn:
30 changes: 0 additions & 30 deletions mart/configs/datamodule/coco_perturbable_mask.yaml
@@ -10,16 +10,6 @@ train_dataset:
- _target_: mart.transforms.RandomHorizontalFlip
p: 0.5
- _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

val_dataset:
transforms:
@@ -28,16 +18,6 @@ val_dataset:
# ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable.
- _target_: mart.transforms.ConvertCocoPolysToMask
- _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

test_dataset:
transforms:
@@ -46,13 +26,3 @@ test_dataset:
# ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable.
- _target_: mart.transforms.ConvertCocoPolysToMask
- _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255
30 changes: 0 additions & 30 deletions mart/configs/datamodule/fiftyone.yaml
@@ -14,16 +14,6 @@ train_dataset:
- _target_: mart.transforms.ConvertCocoPolysToMask
- _target_: mart.transforms.RandomHorizontalFlip
p: 0.5
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

val_dataset:
_target_: mart.datamodules.fiftyone.FiftyOneDataset
@@ -36,16 +26,6 @@ val_dataset:
transforms:
- _target_: torchvision.transforms.ToTensor
- _target_: mart.transforms.ConvertCocoPolysToMask
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

test_dataset:
_target_: mart.datamodules.fiftyone.FiftyOneDataset
@@ -58,16 +38,6 @@ test_dataset:
transforms:
- _target_: torchvision.transforms.ToTensor
- _target_: mart.transforms.ConvertCocoPolysToMask
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

num_workers: 2
collate_fn:
30 changes: 0 additions & 30 deletions mart/configs/datamodule/fiftyone_perturbable_mask.yaml
@@ -16,16 +16,6 @@ train_dataset:
- _target_: mart.transforms.RandomHorizontalFlip
p: 0.5
- _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

val_dataset:
_target_: mart.datamodules.fiftyone.FiftyOneDataset
@@ -40,16 +30,6 @@ val_dataset:
# ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable.
- _target_: mart.transforms.ConvertCocoPolysToMask
- _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

test_dataset:
_target_: mart.datamodules.fiftyone.FiftyOneDataset
@@ -64,13 +44,3 @@ test_dataset:
# ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable.
- _target_: mart.transforms.ConvertCocoPolysToMask
- _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255
20 changes: 0 additions & 20 deletions mart/configs/datamodule/imagenet.yaml
@@ -13,16 +13,6 @@ train_dataset:
size: 224
- _target_: torchvision.transforms.RandomHorizontalFlip
- _target_: torchvision.transforms.ToTensor
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

val_dataset:
_target_: torchvision.datasets.ImageNet
@@ -34,16 +24,6 @@ val_dataset:
- _target_: torchvision.transforms.CenterCrop
size: 224
- _target_: torchvision.transforms.ToTensor
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

test_dataset: ${.val_dataset}

1 change: 1 addition & 0 deletions mart/configs/experiment/CIFAR10_CNN_Adv.yaml
@@ -3,6 +3,7 @@
defaults:
- /attack@callbacks.adversary_connector.train_adversary: classification_fgsm_linf
- /attack@callbacks.adversary_connector.test_adversary: classification_pgd_linf
- override /batch_c15n@callbacks.adversary_connector.batch_c15n: input_tensor_float01
- override /datamodule: cifar10
- override /model: classifier_cifar10_cnn
- override /metric: accuracy
@@ -3,6 +3,7 @@
defaults:
- COCO_TorchvisionFasterRCNN
- /attack@callbacks.adversary_connector.test_adversary: object_detection_mask_adversary
- override /batch_c15n@callbacks.adversary_connector.batch_c15n: input_tuple_float01
- override /datamodule: coco_perturbable_mask
- override /callbacks: [model_checkpoint, lr_monitor, adversary_connector]

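Both experiments now select a batch_c15n for the adversary_connector callback alongside the attacks. A rough sketch of how such a connector could hook an attack into Lightning via on_after_batch_transfer, inferred from the commit history rather than MART's actual implementation:

import pytorch_lightning as pl

class AdversaryConnector(pl.Callback):
    """Sketch: run an attack on each transferred batch, converting ranges via batch_c15n."""

    def __init__(self, test_adversary=None, batch_c15n=None):
        self.test_adversary = test_adversary
        self.batch_c15n = batch_c15n

    def setup(self, trainer, pl_module, stage=None):
        original_hook = pl_module.on_after_batch_transfer

        def on_after_batch_transfer(batch, dataloader_idx):
            batch = original_hook(batch, dataloader_idx)
            if stage == "test" and self.test_adversary is not None:
                # Convert the [0, 1] batch to the attack's domain and back.
                batch = self.batch_c15n.attack_in_255(self.test_adversary, batch)
            return batch

        pl_module.on_after_batch_transfer = on_after_batch_transfer
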
6 changes: 3 additions & 3 deletions mart/configs/model/classifier_cifar10_cnn.yaml
@@ -4,7 +4,7 @@ defaults:

modules:
preprocessor:
# Normalize [0, 255] input.
# Normalize [0, 1] input.
_target_: torchvision.transforms.Normalize
mean: [125.307, 122.961, 113.8575]
std: [51.5865, 50.847, 51.255]
mean: [0.4914, 0.4822, 0.4465]
std: [0.2023, 0.1994, 0.2010]
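
The new CNN preprocessor statistics are the previous [0, 255]-scale CIFAR-10 statistics rescaled by 1/255, so the normalized features the model sees are unchanged:

old_mean = [125.307, 122.961, 113.8575]
old_std = [51.5865, 50.847, 51.255]
print([round(m / 255, 4) for m in old_mean])  # [0.4914, 0.4822, 0.4465]
print([round(s / 255, 4) for s in old_std])   # [0.2023, 0.1994, 0.201]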
7 changes: 4 additions & 3 deletions mart/configs/model/classifier_timm.yaml
@@ -3,10 +3,11 @@ defaults:

modules:
preprocessor:
# Convert [0, 255] input to [0, 1]
# Use normalization parameters from timm
# https://github.com/huggingface/pytorch-image-models/blob/3d05c0e86f2f4c57bf495468aa3f8a7d3487c986/timm/data/constants.py#L3
_target_: torchvision.transforms.Normalize
mean: 0
std: 255
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

logits:
_target_: timm.models.convnext.convnext_tiny
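The hard-coded mean/std match timm's ImageNet defaults referenced in the comment, so a [0, 1] input is normalized exactly as timm models expect (assuming the timm package is available):

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(IMAGENET_DEFAULT_MEAN)  # (0.485, 0.456, 0.406)
print(IMAGENET_DEFAULT_STD)   # (0.229, 0.224, 0.225)
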
15 changes: 3 additions & 12 deletions mart/configs/model/torchvision_faster_rcnn.yaml
@@ -11,12 +11,9 @@ training_step_log:

training_sequence:
seq010:
preprocessor: ["input"]
losses_and_detections: ["input", "target"]

seq020:
losses_and_detections: ["preprocessor", "target"]

seq030:
loss:
# Sum up the losses.
[
@@ -28,17 +25,11 @@

validation_sequence:
seq010:
preprocessor: ["input"]

seq020:
losses_and_detections: ["preprocessor", "target"]
losses_and_detections: ["input", "target"]

test_sequence:
seq010:
preprocessor: ["input"]

seq020:
losses_and_detections: ["preprocessor", "target"]
losses_and_detections: ["input", "target"]

modules:
losses_and_detections:
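Dropping the explicit preprocessor step works because torchvision detection models take a list of [0, 1] float images and handle resizing and normalization internally (GeneralizedRCNNTransform). A minimal illustration, assuming a recent torchvision:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(weights=None).eval()
images = [torch.rand(3, 480, 640), torch.rand(3, 600, 800)]  # already in [0, 1]
with torch.no_grad():
    detections = model(images)  # list of dicts with boxes, labels, scores
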
7 changes: 0 additions & 7 deletions mart/configs/model/torchvision_object_detection.yaml
@@ -12,13 +12,6 @@ test_sequence: ???
output_preds_key: "losses_and_detections.eval"

modules:
preprocessor:
_target_: mart.transforms.TupleTransforms
transforms:
_target_: torchvision.transforms.Normalize
mean: 0
std: 255

losses_and_detections:
# Return losses in the training mode and predictions in the eval mode in one pass.
_target_: mart.models.DualMode