Learners

class easypl.learners.base.BaseLearner(model: Optional[Union[Module, List[Module]]] = None, loss: Optional[Union[Module, List[Module]]] = None, optimizer: Optional[Union[WrapperOptimizer, List[WrapperOptimizer]]] = None, lr_scheduler: Optional[Union[WrapperScheduler, List[WrapperScheduler]]] = None, train_metrics: Optional[List[Metric]] = None, val_metrics: Optional[List[Metric]] = None, test_metrics: Optional[List[Metric]] = None, data_keys: Optional[List[str]] = None, target_keys: Optional[List[str]] = None)

Abstract base learner

model

torch.nn.Module model.

Type:

Optional[Union[torch.nn.Module, List[torch.nn.Module]]]

loss

torch.nn.Module loss function.

Type:

Optional[Union[torch.nn.Module, List[torch.nn.Module]]]

optimizer

Optimizer wrapper object.

Type:

Optional[Union[WrapperOptimizer, List[WrapperOptimizer]]]

lr_scheduler

Scheduler wrapper object for learning-rate scheduling.

Type:

Optional[Union[WrapperScheduler, List[WrapperScheduler]]]

train_metrics

List of train metrics.

Type:

Optional[List[Metric]]

val_metrics

List of validation metrics.

Type:

Optional[List[Metric]]

test_metrics

List of test metrics.

Type:

Optional[List[Metric]]

data_keys

List of batch keys that contain input data.

Type:

Optional[List[str]]

target_keys

List of batch keys that contain targets.

Type:

Optional[List[str]]
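To make the attribute list above concrete, the following is a hypothetical, dependency-free sketch of the `BaseLearner` interface. The class name `SketchBaseLearner` and the plain-Python stand-ins for `torch`/`easypl` types are assumptions for illustration only; the real class subclasses `pytorch_lightning.LightningModule` and expects the wrapper types listed above.

```python
# Hypothetical sketch of the documented BaseLearner interface.
# Plain-Python stand-ins replace torch/easypl types so the sketch runs anywhere.
from abc import ABC, abstractmethod
from typing import Any, Dict


class SketchBaseLearner(ABC):
    """Mirrors the constructor and abstract methods documented above."""

    def __init__(self, model=None, loss=None, optimizer=None, lr_scheduler=None,
                 train_metrics=None, val_metrics=None, test_metrics=None,
                 data_keys=None, target_keys=None):
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.train_metrics = train_metrics or []
        self.val_metrics = val_metrics or []
        self.test_metrics = test_metrics or []
        self.data_keys = data_keys or []
        self.target_keys = target_keys or []

    @abstractmethod
    def get_outputs(self, batch: Dict, optimizer_idx: int = 0) -> Dict:
        """Select and preprocess model outputs from the batch."""

    @abstractmethod
    def get_targets(self, batch: Dict, optimizer_idx: int = 0) -> Dict:
        """Select and preprocess targets from the batch."""

    @abstractmethod
    def loss_step(self, outputs: Any, targets: Any, optimizer_idx: int = 0) -> Dict:
        """Evaluate the loss and return a dict with "loss" and "log" keys."""
```

A subclass supplies the three abstract methods; the constructor simply stores the components for the training loop to use.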

get_outputs(batch: Dict, optimizer_idx: int = 0) → Dict

Abstract method for selecting and preprocessing outputs from the batch.

batch

Batch in step

Type:

Dict

Returns:

Dict with keys: ["loss", "metric", "log"]

Return type:

Dict

get_targets(batch: Dict, optimizer_idx: int = 0) → Dict

Abstract method for selecting and preprocessing targets from the batch.

batch

Batch in step

Type:

Dict

Returns:

Dict with keys: ["loss", "metric", "log"]

Return type:

Dict

loss_step(outputs: Any, targets: Any, optimizer_idx: int = 0) → Dict

Abstract method for loss evaluation.

outputs

Any outputs from model

Type:

Any

targets

Any targets from batch

Type:

Any

Returns:

Dict with keys: ["loss", "log"]

Return type:

Dict
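As an illustration of the return contracts above, here is a hypothetical set of implementations. The batch keys `"image"` and `"label"`, the identity "model", and the mean-squared-error loss are all assumptions for the sake of a runnable example, not part of the easypl API.

```python
# Hypothetical implementations of the documented contracts:
# get_outputs/get_targets return {"loss", "metric", "log"},
# loss_step returns {"loss", "log"}.
from typing import Any, Dict


def get_outputs(batch: Dict, optimizer_idx: int = 0) -> Dict:
    # Stand-in "model": identity over the batch entry named by a data key.
    preds = batch["image"]
    return {"loss": preds, "metric": preds, "log": preds}


def get_targets(batch: Dict, optimizer_idx: int = 0) -> Dict:
    targets = batch["label"]
    return {"loss": targets, "metric": targets, "log": targets}


def loss_step(outputs: Any, targets: Any, optimizer_idx: int = 0) -> Dict:
    # Mean-squared error as a stand-in loss function.
    loss = sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(outputs)
    return {"loss": loss, "log": {"train/loss": loss}}
```

The `"loss"` entries of the two selection methods feed `loss_step`, while the `"metric"` and `"log"` entries feed metric computation and logging respectively.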

on_test_epoch_end(val_step_outputs)

Called in the test loop at the very end of the epoch.

on_train_epoch_end(train_step_outputs)

Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule OR

  2. Cache data across steps on the attribute(s) of the LightningModule and access them in this hook
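Option 2 above can be sketched as follows. The class `CachingLearner` and its methods are hypothetical; in practice the list attribute would live on the `LightningModule` subclass and the trainer would invoke the hooks.

```python
# Sketch of caching per-step data on an attribute and aggregating it in the
# epoch-end hook (option 2 above). Plain Python, no Lightning dependency.
class CachingLearner:
    def __init__(self):
        self._train_outputs = []  # cache, cleared at the end of each epoch

    def training_step(self, batch):
        out = {"loss": sum(batch) / len(batch)}
        self._train_outputs.append(out)  # cache across steps
        return out

    def on_train_epoch_end(self, train_step_outputs=None):
        # Aggregate everything cached during the epoch, then reset the cache.
        losses = [o["loss"] for o in self._train_outputs]
        epoch_loss = sum(losses) / len(losses)
        self._train_outputs.clear()
        return epoch_loss
```

Clearing the cache at epoch end keeps memory bounded and prevents one epoch's outputs from leaking into the next.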

on_validation_epoch_end(val_step_outputs)

Called in the validation loop at the very end of the epoch.