ClassificationImageTestTimeAugmentation

class easypl.callbacks.predictors.image_classification.ClassificationImageTestTimeAugmentation(n: int, augmentations: List, augmentation_method: str = 'first', phase='val', reduce_method: Union[str, Callable] = 'mean')

Image classification callback for test-time augmentation (TTA).

n

Number of augmentations to apply.

Type

int

augmentations

List of augmentations to choose from.

Type

List

augmentation_method

Method for selecting augmentations from the list. Available: [“first”, “random”].

Type

str

phase

Phase in which this predictor callback runs. Available: [“val”, “test”, “predict”].

Type

str

reduce_method

Method for reducing the results of the augmented predictions.

Type

Union[str, Callable]
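The constructor arguments involve two decisions: which n augmentations to keep, and how to combine their predictions. Below is a minimal pure-Python sketch of one way to interpret them. It is not easypl's actual code. The helper names select_augmentations and resolve_reduce are hypothetical. "random" is assumed to sample without replacement. Only the documented default "mean" is mapped to a built-in reduction, and any other string is rejected.

```python
import random
import statistics
from typing import Callable, List, Sequence, Union


def select_augmentations(augmentations: Sequence, n: int, method: str = "first") -> List:
    """Pick n augmentations: the first n ("first") or n at random ("random")."""
    if n > len(augmentations):
        raise ValueError("n cannot exceed the number of augmentations")
    if method == "first":
        return list(augmentations[:n])
    if method == "random":
        # Assumption: sampling without replacement.
        return random.sample(list(augmentations), n)
    raise ValueError(f"Unknown augmentation_method: {method!r}")


def resolve_reduce(reduce_method: Union[str, Callable]) -> Callable:
    """Map a string name to a reduction function; pass callables through."""
    if callable(reduce_method):
        return reduce_method
    # Assumption: only the documented default "mean" is mapped here.
    builtin = {"mean": statistics.fmean}
    if reduce_method not in builtin:
        raise ValueError(f"Unknown reduce_method: {reduce_method!r}")
    return builtin[reduce_method]


augs = ["hflip", "vflip", "rotate90"]  # placeholders for transform objects
print(select_augmentations(augs, 2))               # ['hflip', 'vflip']
print(resolve_reduce("mean")([0.25, 0.5, 0.75]))   # 0.5
print(resolve_reduce(max)([0.25, 0.5, 0.75]))      # 0.75
```

Passing a callable such as max for reduce_method skips the string lookup entirely. This is what the Union[str, Callable] type permits.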

augment(sample: Dict, augmentation) → Dict

Applies an augmentation to a sample.

sample

A single sample from the batch

Type

Dict

augmentation

Transform object

Returns

Augmented sample

Return type

Dict
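Here is a minimal sketch of what augment might do. It makes two hypothetical assumptions. First, each sample keeps its image under an "image" key. Second, transforms follow the albumentations calling convention: transform(image=...) returns a dict with an "image" entry. The horizontal_flip transform is a toy stand-in that operates on nested lists, so the example runs without extra dependencies.

```python
from typing import Any, Callable, Dict


def augment(sample: Dict[str, Any], augmentation: Callable[..., Dict[str, Any]]) -> Dict[str, Any]:
    """Return a copy of sample with augmentation applied to sample["image"]."""
    augmented = dict(sample)  # shallow copy so the original sample is untouched
    augmented["image"] = augmentation(image=sample["image"])["image"]
    return augmented


def horizontal_flip(image):
    """Toy transform in the albumentations style: returns {"image": ...}."""
    return {"image": [list(reversed(row)) for row in image]}


sample = {"image": [[1, 2, 3], [4, 5, 6]], "target": 0}
flipped = augment(sample, horizontal_flip)
print(flipped["image"])  # [[3, 2, 1], [6, 5, 4]]
print(sample["image"])   # [[1, 2, 3], [4, 5, 6]]
```

The sketch copies the dict before writing the augmented image, so the original sample stays intact and every augmentation starts from the same input.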

metric_formatting(outputs: Any, targets: Any) → Tuple

Prepares outputs and targets before they are passed to the metrics.

outputs

Outputs from the model

Type

Any

targets

Targets from the batch

Type

Any

Returns

Formatted outputs and targets

Return type

Tuple
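Below is one plausible formatting step. The sketch assumes outputs holds per-class scores of shape [batch_size × num_classes], and that the downstream metric expects predicted class indices alongside integer targets. The argmax conversion is illustrative only; the real method may pass probabilities through unchanged.

```python
from typing import List, Sequence, Tuple


def metric_formatting(outputs: Sequence[Sequence[float]],
                      targets: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Convert per-class scores to predicted class indices (argmax per row)."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in outputs]
    return predictions, list(targets)


outputs = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]  # [batch_size x num_classes]
targets = [1, 2]
preds, tgts = metric_formatting(outputs, targets)
accuracy = sum(p == t for p, t in zip(preds, tgts)) / len(tgts)
print(preds, accuracy)  # [1, 0] 0.5
```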

post_init(trainer: Trainer, pl_module: LightningModule)

Initialization performed when the first batch is handled. [NOT REQUIRED]

trainer

PyTorch Lightning trainer

Type

pytorch_lightning.Trainer

pl_module

PyTorch Lightning module

Type

pytorch_lightning.LightningModule
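post_init follows a lazy-initialization pattern: setup that needs the trainer or module runs once, when the first batch arrives. The class below is a hypothetical sketch of that pattern. on_batch stands in for whichever batch hook the callback actually uses. num_classes is an illustrative attribute and is not part of the documented API.

```python
from types import SimpleNamespace
from typing import Any


class LazyInitCallback:
    """Hypothetical callback that calls post_init exactly once, on the first batch."""

    def __init__(self) -> None:
        self._initialized = False
        self.num_classes = None

    def post_init(self, trainer: Any, pl_module: Any) -> None:
        # Read state that only exists once training/evaluation has started;
        # num_classes is an illustrative attribute, not part of easypl's API.
        self.num_classes = getattr(pl_module, "num_classes", None)

    def on_batch(self, trainer: Any, pl_module: Any, batch: Any) -> None:
        if not self._initialized:
            self.post_init(trainer, pl_module)
            self._initialized = True
        # Per-batch prediction logic would follow here.


callback = LazyInitCallback()
callback.on_batch(None, SimpleNamespace(num_classes=10), batch=[])
callback.on_batch(None, SimpleNamespace(num_classes=99), batch=[])
print(callback.num_classes)  # 10: post_init ran only on the first batch
```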

postprocessing(sample: Dict, dataloader_idx: int = 0) → Dict

Postprocesses a sample.

sample

A single sample from the batch

Type

Dict

dataloader_idx

Index of the dataloader

Type

int

Returns

Postprocessed sample

Return type

Dict
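As an illustration, postprocessing could convert raw logits into class probabilities before the per-augmentation results are reduced. Averaging probabilities is a common TTA choice. The sketch assumes a hypothetical "output" key holding a list of logits. dataloader_idx is accepted for signature compatibility but ignored.

```python
import math
from typing import Any, Dict, List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def postprocessing(sample: Dict[str, Any], dataloader_idx: int = 0) -> Dict[str, Any]:
    """Hypothetical: replace the logits in sample["output"] with probabilities."""
    # dataloader_idx is unused in this sketch.
    processed = dict(sample)
    processed["output"] = softmax(sample["output"])
    return processed


probs = postprocessing({"output": [2.0, 1.0, 0.1]})["output"]
print(round(sum(probs), 6))                         # 1.0
print(postprocessing({"output": [0.0, 0.0]})["output"])  # [0.5, 0.5]
```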

preprocessing(sample: Dict, dataloader_idx: int = 0) → Dict

Preprocesses a sample.

sample

A single sample from the batch

Type

Dict

dataloader_idx

Index of the dataloader

Type

int

Returns

Preprocessed sample

Return type

Dict
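Here is a minimal preprocessing sketch. It assumes, hypothetically, that sample["image"] holds 8-bit pixel values as nested lists and that the model expects inputs scaled to [0, 1]. Real pipelines usually also normalize per channel and convert to tensors. dataloader_idx is ignored here.

```python
from typing import Any, Dict


def preprocessing(sample: Dict[str, Any], dataloader_idx: int = 0) -> Dict[str, Any]:
    """Hypothetical: scale 8-bit pixel values in sample["image"] to [0, 1]."""
    # dataloader_idx is unused in this sketch.
    processed = dict(sample)
    processed["image"] = [[px / 255.0 for px in row] for row in sample["image"]]
    return processed


sample = {"image": [[0, 255], [51, 102]], "target": 3}
print(preprocessing(sample)["image"])  # [[0.0, 1.0], [0.2, 0.4]]
```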

reduce(tensor: Tensor) → Tensor

Reduces the results.

tensor

Any tensor of shape [batch_size × …]

Type

torch.Tensor

Returns

Reduced tensor

Return type

torch.Tensor
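The reduction step can be sketched in pure Python. The sketch assumes the leading axis indexes the augmented copies whose predictions are combined, and that the reduction applies independently to every remaining element. The tta_predict helper is hypothetical. It shows where reduce sits in a test-time augmentation loop: run the model once per augmented copy, then collapse the stacked predictions.

```python
from statistics import fmean
from typing import Callable, List, Sequence


def reduce(predictions: Sequence[Sequence[float]],
           reduce_fn: Callable[[List[float]], float] = fmean) -> List[float]:
    """Collapse the leading axis (one row per augmented copy), column by column."""
    return [reduce_fn(list(column)) for column in zip(*predictions)]


def tta_predict(model: Callable, image: List[float],
                augmentations: Sequence[Callable]) -> List[float]:
    """Hypothetical TTA loop: predict on every augmented copy, then reduce."""
    predictions = [model(aug(image)) for aug in augmentations]
    return reduce(predictions)


def toy_model(x: List[float]) -> List[float]:
    # Two-class "probabilities" built from the first and last values.
    total = x[0] + x[-1]
    return [x[0] / total, x[-1] / total]


augmentations = [lambda x: x, lambda x: list(reversed(x))]  # identity, flip
print(tta_predict(toy_model, [1.0, 2.0, 3.0], augmentations))  # [0.5, 0.5]
print(reduce([[0.25, 0.75], [0.75, 0.25]], max))                # [0.75, 0.75]
```

With PyTorch, the same mean reduction would be tensor.mean(dim=0) applied to a stacked tensor. The list-based version keeps the sketch dependency-free.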