ClassificationImageTestTimeAugmentation
- class easypl.callbacks.predictors.image_classification.ClassificationImageTestTimeAugmentation(n: int, augmentations: List, augmentation_method: str = 'first', phase='val', reduce_method: Union[str, Callable] = 'mean')
Image classification callback for test-time augmentation (TTA).
- n
Number of augmentations.
- Type
int
- augmentations
List of augmentations to select from.
- Type
List
- augmentation_method
Method for selecting augmentations from the list. Available: [“first”, “random”].
- Type
str
- phase
Phase in which this predictor callback is used. Available: [“val”, “test”, “predict”].
- Type
str
- reduce_method
Method for reducing the results across augmentations: either a string such as “mean” (the default) or a callable.
- Type
Union[str, Callable]
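To make the roles of n, augmentation_method and reduce_method concrete, here is a minimal pure-Python sketch of the test-time augmentation idea. It is not easypl's implementation. The helpers select_augmentations, reduce_predictions and tta_predict are hypothetical, augmentations and the model are modeled as plain callables, and the sketch assumes that “first” takes the first n augmentations while “random” samples n of them without replacement.

```python
import random
from statistics import mean
from typing import Any, Callable, List, Optional, Sequence, Union

Augmentation = Callable[[Any], Any]


def select_augmentations(
    augmentations: List[Augmentation],
    n: int,
    augmentation_method: str = "first",
    rng: Optional[random.Random] = None,
) -> List[Augmentation]:
    # Hypothetical selection rule: "first" = prefix, "random" = sample without replacement.
    if augmentation_method == "first":
        return augmentations[:n]
    if augmentation_method == "random":
        if rng is None:
            rng = random.Random()
        return rng.sample(augmentations, n)
    raise ValueError(f"unknown augmentation_method: {augmentation_method!r}")


def reduce_predictions(
    predictions: Sequence[Sequence[float]],
    reduce_method: Union[str, Callable[[Sequence[Sequence[float]]], Sequence[float]]] = "mean",
) -> List[float]:
    # predictions[i][c] is the score of class c under augmentation i.
    if callable(reduce_method):
        return list(reduce_method(predictions))
    if reduce_method == "mean":
        # zip(*predictions) groups each class's scores across augmentations.
        return [mean(class_scores) for class_scores in zip(*predictions)]
    raise ValueError(f"unknown reduce_method: {reduce_method!r}")


def tta_predict(
    model: Callable[[Any], Sequence[float]],
    sample: Any,
    augmentations: List[Augmentation],
    n: int,
    augmentation_method: str = "first",
    reduce_method: Union[str, Callable] = "mean",
    rng: Optional[random.Random] = None,
) -> List[float]:
    # Run the model once per selected augmentation, then reduce the results.
    selected = select_augmentations(augmentations, n, augmentation_method, rng)
    predictions = [model(aug(sample)) for aug in selected]
    return reduce_predictions(predictions, reduce_method)


def toy_model(x: float) -> List[float]:
    # Toy two-class "classifier": scores are x and 1 - x.
    return [x, 1 - x]


augs = [lambda x: x, lambda x: x * 0.5, lambda x: 1 - x]
print(tta_predict(toy_model, 0.75, augs, n=2))  # [0.5625, 0.4375]
```

Averaging (the “mean” default) smooths out predictions that depend on one particular view of the input; passing a callable lets you plug in any other rule, for example a per-class maximum.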
- augment(sample: Dict, augmentation) → Dict
Applies an augmentation to a sample.
- sample
A sample from the batch.
- Type
Dict
- augmentation
Transform object
- Returns
Augmented sample
- Return type
Dict
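A minimal sketch of what applying one augmentation to a dict-shaped sample can look like. This is an illustration, not easypl's code: the "image" key, the image_key parameter and the hflip helper are hypothetical, and the transform is modeled as a plain callable on a list-of-rows image.

```python
from typing import Any, Callable, Dict, List


def augment(
    sample: Dict[str, Any],
    augmentation: Callable[[Any], Any],
    image_key: str = "image",  # hypothetical key name
) -> Dict[str, Any]:
    # Shallow-copy the sample so the original batch entry is not modified.
    augmented = dict(sample)
    # Only the image is transformed; targets and other fields are carried over.
    augmented[image_key] = augmentation(sample[image_key])
    return augmented


def hflip(image: List[List[int]]) -> List[List[int]]:
    # Toy horizontal flip for an image stored as a list of rows.
    return [row[::-1] for row in image]


sample = {"image": [[1, 2], [3, 4]], "target": 1}
print(augment(sample, hflip))  # {'image': [[2, 1], [4, 3]], 'target': 1}
print(sample["image"])         # [[1, 2], [3, 4]]
```

Copying the dict rather than mutating it in place matters for TTA, because the same original sample is augmented several times.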
- metric_formatting(outputs: Any, targets: Any) → Tuple
Prepares outputs and targets before they are passed to metrics.
- outputs
Model outputs.
- Type
Any
- targets
Targets from the batch.
- Type
Any
- Returns
Formatted outputs and targets
- Return type
Tuple
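One common way to prepare classification outputs for metrics is to turn score vectors into predicted class indices and to normalize targets to integer labels. The sketch below shows that idea in pure Python; it is an assumption about what such formatting might do, not the library's behavior, and the argmax helper and the accepted target shapes (integer label or one-hot list) are hypothetical.

```python
from typing import List, Sequence, Tuple, Union


def argmax(values: Sequence[float]) -> int:
    # Index of the largest value; ties resolve to the first index.
    return max(range(len(values)), key=values.__getitem__)


def metric_formatting(
    outputs: Sequence[Sequence[float]],
    targets: Sequence[Union[int, Sequence[float]]],
) -> Tuple[List[int], List[int]]:
    # Predicted class = argmax of each sample's scores.
    preds = [argmax(scores) for scores in outputs]
    # Targets may be integer labels or one-hot vectors; normalize both to labels.
    labels = [t if isinstance(t, int) else argmax(t) for t in targets]
    return preds, labels


outputs = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]
targets = [1, [0, 0, 1]]
print(metric_formatting(outputs, targets))  # ([1, 0], [1, 2])
```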
- post_init(trainer: Trainer, pl_module: LightningModule)
Initialization hook called when the first batch is handled. [NOT REQUIRED]
- trainer
PyTorch Lightning trainer.
- Type
pytorch_lightning.Trainer
- pl_module
PyTorch Lightning module.
- Type
pytorch_lightning.LightningModule
- postprocessing(sample: Dict, dataloader_idx: int = 0) → Dict
Postprocesses a sample.
- sample
A sample from the batch.
- Type
Dict
- dataloader_idx
Index of the dataloader.
- Type
int
- Returns
Postprocessed sample
- Return type
Dict
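As an illustration of a per-sample postprocessing step for classification, the sketch below converts raw scores into probabilities and a predicted label. Everything here is hypothetical: the "output", "probs" and "label" keys, the softmax helper, and the choice to ignore dataloader_idx are assumptions made for the example, not easypl's implementation.

```python
import math
from typing import Any, Dict, List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def postprocessing(sample: Dict[str, Any], dataloader_idx: int = 0) -> Dict[str, Any]:
    # dataloader_idx is accepted for interface compatibility; this sketch ignores it.
    processed = dict(sample)
    probs = softmax(sample["output"])  # hypothetical key holding raw scores
    processed["probs"] = probs
    processed["label"] = max(range(len(probs)), key=probs.__getitem__)
    return processed


print(postprocessing({"output": [0.0, 0.0]}))
# {'output': [0.0, 0.0], 'probs': [0.5, 0.5], 'label': 0}
```

The dataloader_idx argument would let a real implementation apply different postprocessing per dataloader, for example when validating on several datasets at once.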