Predictors

class easypl.callbacks.predictors.base.BaseTestTimeAugmentation(n: int, augmentations: List, augmentation_method: str = 'first', phase='val')

Base callback for test-time augmentation (TTA).

n

Number of augmentations.

Type

int

augmentations

List of augmentations to use.

Type

List

augmentation_method

Method for selecting augmentations from the list. Available: [“first”, “random”].

Type

str

phase

Phase in which this predictor callback runs. Available: [“val”, “test”, “predict”].

Type

str
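The two `augmentation_method` options can be illustrated with a small sketch. The helper `select_augmentations` is hypothetical and not part of EasyPL; it assumes that “first” takes the first `n` entries of `augmentations` and “random” draws `n` distinct entries at random, which may differ from the library's exact behavior.

```python
import random
from typing import Any, List, Optional


def select_augmentations(augmentations: List[Any], n: int, method: str = "first",
                         rng: Optional[random.Random] = None) -> List[Any]:
    """Hypothetical helper: pick n augmentations using the documented method names."""
    if method == "first":
        # Take the first n augmentations in list order (fewer if the list is shorter).
        return augmentations[:n]
    if method == "random":
        # Assumption: draw n augmentations without replacement, capped at the list length.
        rng = rng or random.Random()
        return rng.sample(augmentations, min(n, len(augmentations)))
    raise ValueError(f"Unknown augmentation_method: {method!r}")


augs = ["identity", "hflip", "vflip", "rot90"]
print(select_augmentations(augs, n=2, method="first"))  # ['identity', 'hflip']
print(select_augmentations(augs, n=2, method="random", rng=random.Random(0)))
```

Passing a seeded `random.Random` keeps the “random” selection reproducible between runs.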

augment(sample: Dict, augmentation) → Dict

Abstract method that applies an augmentation to a sample.

sample

A single sample from the batch.

Type

Dict

augmentation

Transform object to apply to the sample.

Returns

Augmented sample

Return type

Dict
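A subclass supplies the concrete transform logic in `augment`. The sketch below assumes a sample dict with a hypothetical `"image"` key holding a 2-D image as nested lists, and a transform that is a plain callable on that image; real pipelines would typically operate on tensors.

```python
from typing import Callable, Dict, List

Image = List[List[float]]


def hflip(image: Image) -> Image:
    """Mirror each row of a 2-D image (horizontal flip)."""
    return [list(reversed(row)) for row in image]


def augment(sample: Dict, augmentation: Callable[[Image], Image]) -> Dict:
    """Sketch of an augment() override: apply the transform to the 'image' key only.

    The 'image' key is an assumption; other keys (e.g. targets) are copied unchanged.
    """
    augmented = dict(sample)  # shallow copy so the original sample is not modified
    augmented["image"] = augmentation(sample["image"])
    return augmented


sample = {"image": [[1, 2, 3], [4, 5, 6]], "target": 1}
print(augment(sample, hflip))  # {'image': [[3, 2, 1], [6, 5, 4]], 'target': 1}
```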

metric_formatting(outputs: Any, targets: Any) → Tuple

Prepares outputs and targets before they are passed to the metrics. By default, returns the passed values unchanged.

outputs

Output from the model.

Type

Any

targets

Targets from the batch.

Type

Any

Returns

Formatted outputs and targets

Return type

Tuple
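The default behavior is equivalent to `return outputs, targets`. A subclass can override the method when the metrics expect a different format. The sketch below assumes a classification setup in which `outputs` holds one score vector per sample and `targets` holds class indices, and converts scores into predicted class indices.

```python
from typing import List, Sequence, Tuple


def metric_formatting(outputs: Sequence[Sequence[float]],
                      targets: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Sketch of a metric_formatting() override for classification.

    Assumes `outputs` holds one score vector per sample and `targets` holds class
    indices; converts scores to predicted class indices (argmax) so both sides match.
    """
    predictions = [max(range(len(scores)), key=scores.__getitem__) for scores in outputs]
    return predictions, list(targets)


outputs = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]]
targets = [1, 0]
print(metric_formatting(outputs, targets))  # ([1, 0], [1, 0])
```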

post_init(trainer: Trainer, pl_module: LightningModule)

Abstract method for initialization, called when the first batch is handled. Overriding this method is optional.

trainer

PyTorch Lightning trainer.

Type

pytorch_lightning.Trainer

pl_module

PyTorch Lightning module.

Type

pytorch_lightning.LightningModule
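The first-batch behavior can be pictured as a one-time initialization guard. In the sketch below, `LazyInitPredictor` and its `on_batch` method are hypothetical stand-ins rather than EasyPL internals, and `trainer`/`pl_module` are replaced by simple objects (with an assumed `num_classes` attribute) so the example runs without PyTorch Lightning.

```python
from types import SimpleNamespace
from typing import Any, Dict


class LazyInitPredictor:
    """Hypothetical stand-in showing a 'run post_init once, on the first batch' pattern.

    This is not EasyPL's implementation; trainer and pl_module are plain stand-ins here.
    """

    def __init__(self) -> None:
        self._initialized = False
        self.num_classes = None

    def post_init(self, trainer: Any, pl_module: Any) -> None:
        # Example setup that needs the module: read an assumed num_classes attribute.
        self.num_classes = pl_module.num_classes

    def on_batch(self, trainer: Any, pl_module: Any, batch: Dict) -> None:
        # Call post_init only the first time a batch is handled.
        if not self._initialized:
            self.post_init(trainer, pl_module)
            self._initialized = True
        # Per-batch test-time augmentation would follow here.


predictor = LazyInitPredictor()
module = SimpleNamespace(num_classes=10)
predictor.on_batch(trainer=None, pl_module=module, batch={})
print(predictor.num_classes)  # 10
```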

postprocessing(sample: Dict, dataloader_idx: int = 0) → Dict

Abstract method for postprocessing a sample.

sample

A single sample from the batch.

Type

Dict

dataloader_idx

Index of the dataloader.

Type

int

Returns

Postprocessed sample

Return type

Dict
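As one possible use of `postprocessing`, the sketch below assumes the model's raw logits are stored under a hypothetical `"output"` key and converts them to probabilities with a sigmoid. The key name and the choice of sigmoid are illustrative assumptions; `dataloader_idx` is accepted but unused.

```python
import math
from typing import Dict


def postprocessing(sample: Dict, dataloader_idx: int = 0) -> Dict:
    """Sketch of a postprocessing() override.

    Assumes raw logits under a hypothetical 'output' key and maps each one to a
    probability with the logistic sigmoid; dataloader_idx is unused here.
    """
    processed = dict(sample)  # leave the incoming sample untouched
    processed["output"] = [1.0 / (1.0 + math.exp(-x)) for x in sample["output"]]
    return processed


print(postprocessing({"output": [0.0, 2.0]}))
```

A subclass serving several dataloaders could branch on `dataloader_idx` to apply different postprocessing per loader.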

preprocessing(sample: Dict, dataloader_idx: int = 0) → Dict

Abstract method for preprocessing a sample.

sample

A single sample from the batch.

Type

Dict

dataloader_idx

Index of the dataloader.

Type

int

Returns

Preprocessed sample

Return type

Dict
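A typical `preprocessing` step is input normalization. The sketch below assumes a hypothetical `"image"` key holding 8-bit pixel values as nested lists and rescales them to [0, 1]; both the key and the scaling are illustrative assumptions, and `dataloader_idx` is unused.

```python
from typing import Dict


def preprocessing(sample: Dict, dataloader_idx: int = 0) -> Dict:
    """Sketch of a preprocessing() override.

    Assumes a hypothetical 'image' key holding 8-bit pixel values and rescales
    them to [0, 1]; dataloader_idx is unused here.
    """
    processed = dict(sample)  # leave the incoming sample untouched
    processed["image"] = [[pixel / 255.0 for pixel in row] for row in sample["image"]]
    return processed


print(preprocessing({"image": [[0, 255], [51, 102]]}))  # {'image': [[0.0, 1.0], [0.2, 0.4]]}
```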

reduce(tensor: Tensor) → Tensor

Abstract method for reducing the results.

tensor

Any tensor of shape [batch_size × …].

Type

torch.Tensor

Returns

Reduced tensor

Return type

torch.Tensor
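A common choice for `reduce` in test-time augmentation is to average the predictions of the augmented copies. The sketch below assumes that the first dimension indexes those copies and uses nested Python lists in place of a `torch.Tensor`; with PyTorch, the same reduction would be `tensor.mean(dim=0)`. The name `reduce_mean` is hypothetical.

```python
from typing import List, Sequence


def reduce_mean(predictions: Sequence[Sequence[float]]) -> List[float]:
    """Average per-augmentation predictions element-wise (mean over the first axis).

    Pure-Python stand-in for `tensor.mean(dim=0)`, assuming the first dimension
    indexes the augmented copies of a sample.
    """
    if not predictions:
        raise ValueError("reduce_mean needs at least one prediction")
    count = len(predictions)
    return [sum(values) / count for values in zip(*predictions)]


# Three augmented copies of one sample, each with two class scores.
preds = [[0.2, 0.8], [0.4, 0.6], [0.3, 0.7]]
print(reduce_mean(preds))
```

Other reductions, such as an element-wise maximum or a majority vote over predicted classes, fit the same interface.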