ClassificationLearner

class easypl.learners.classification.ClassificationLearner(model: Optional[Union[Module, List[Module]]] = None, loss: Optional[Union[Module, List[Module]]] = None, optimizer: Optional[Union[WrapperOptimizer, List[WrapperOptimizer]]] = None, lr_scheduler: Optional[Union[WrapperScheduler, List[WrapperScheduler]]] = None, train_metrics: Optional[List[Metric]] = None, val_metrics: Optional[List[Metric]] = None, test_metrics: Optional[List[Metric]] = None, data_keys: Optional[List[str]] = None, target_keys: Optional[List[str]] = None, multilabel: bool = False)

Classification learner.

model

torch.nn.Module model.

Type:

Optional[Union[torch.nn.Module, List[torch.nn.Module]]]

loss

torch.nn.Module loss function.

Type:

Optional[Union[torch.nn.Module, List[torch.nn.Module]]]

optimizer

Optimizer wrapper object.

Type:

Optional[Union[WrapperOptimizer, List[WrapperOptimizer]]]

lr_scheduler

Scheduler wrapper object for learning-rate scheduling.

Type:

Optional[Union[WrapperScheduler, List[WrapperScheduler]]]

train_metrics

List of training metrics.

Type:

Optional[List[Metric]]

val_metrics

List of validation metrics.

Type:

Optional[List[Metric]]

test_metrics

List of test metrics.

Type:

Optional[List[Metric]]

data_keys

List of batch keys that hold the input data.

Type:

Optional[List[str]]

target_keys

List of batch keys that hold the targets.

Type:

Optional[List[str]]

multilabel

Whether the classification task is multilabel.

Type:

bool
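The multilabel flag usually decides how logits become predictions. The plain-Python sketch below shows the common convention (argmax over classes for single-label tasks, an independent sigmoid per class plus a threshold for multilabel tasks). It assumes, without confirmation from this reference, that the learner follows that convention; the helper name decode_predictions and the 0.5 threshold are hypothetical.

```python
import math
from typing import List


def decode_predictions(logits: List[float], multilabel: bool, threshold: float = 0.5) -> List[int]:
    """Hypothetical helper: turn one sample's logits into a 0/1 vector.

    Assumption: single-label tasks use argmax (one-hot result), multilabel
    tasks use an independent sigmoid per class followed by a threshold.
    """
    if multilabel:
        # Each class is an independent yes/no decision.
        probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
        return [1 if p >= threshold else 0 for p in probs]
    # Exactly one class: mark the index of the largest logit.
    best = max(range(len(logits)), key=lambda i: logits[i])
    return [1 if i == best else 0 for i in range(len(logits))]


logits = [2.0, -1.0, 0.5]
print(decode_predictions(logits, multilabel=False))  # [1, 0, 0]
print(decode_predictions(logits, multilabel=True))   # [1, 0, 1]
```

With the same logits, the single-label path keeps only the strongest class, while the multilabel path can activate several classes at once.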

forward(samples: Tensor) → Tensor

Standard method for forwarding the model.

samples

Image tensor.

Type:

torch.Tensor

Returns:

Output from model.

Return type:

torch.Tensor

get_outputs(batch: Dict, optimizer_idx: int = 0) → Dict

Abstract method for selecting and preprocessing outputs from the batch.

batch

Batch for the current step.

Type:

Dict

optimizer_idx

Index of optimizer

Type:

int

Returns:

Dict with keys: [“loss”, “metric”, “log”]

Return type:

Dict
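A rough idea of what get_outputs does can be given without torch. The sketch below is hypothetical: it assumes the batch is a dict, that the first entry of data_keys names the model input, and that the same raw outputs are reused for the "loss", "metric" and "log" entries. The function get_outputs_sketch and the toy model are illustrative stand-ins, not part of easypl.

```python
from typing import Callable, Dict, List, Sequence

Sample = List[float]


def get_outputs_sketch(
    batch: Dict[str, Sequence[Sample]],
    model: Callable[[Sample], Sample],
    data_keys: List[str],
) -> Dict[str, List[Sample]]:
    """Hypothetical stand-in for ``get_outputs``.

    Assumptions: the first entry of ``data_keys`` names the model input and
    the same raw outputs are reused for every returned entry.
    """
    inputs = batch[data_keys[0]]
    outputs = [model(sample) for sample in inputs]
    return {"loss": outputs, "metric": outputs, "log": outputs}


def toy_model(sample: Sample) -> Sample:
    # Stand-in for a torch.nn.Module: doubles every feature.
    return [2.0 * x for x in sample]


batch = {"image": [[1.0, 0.0], [0.5, 0.5]], "target": [0, 1]}
result = get_outputs_sketch(batch, toy_model, data_keys=["image"])
print(sorted(result))     # ['log', 'loss', 'metric']
print(result["loss"][0])  # [2.0, 0.0]
```

Keeping separate "loss", "metric" and "log" entries lets each consumer receive a differently preprocessed view, even when, as here, they happen to be identical.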

get_targets(batch: Dict, optimizer_idx: int = 0) → Dict

Method for selecting and preprocessing targets from the batch.

batch

Batch for the current step.

Type:

Dict

optimizer_idx

Index of optimizer

Type:

int

Returns:

Dict with keys: [“loss”, “metric”, “log”]

Return type:

Dict
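Target preprocessing can be sketched in the same way. The example below is hypothetical and covers only the single-label case: it assumes targets are integer class indices stored under the first entry of target_keys, that the loss receives the indices unchanged, and that the metric and log entries receive one-hot vectors. The name get_targets_sketch and the num_classes argument are illustrative assumptions, not part of the documented signature.

```python
from typing import Dict, List, Sequence


def get_targets_sketch(
    batch: Dict[str, Sequence[int]],
    target_keys: List[str],
    num_classes: int,
) -> Dict[str, list]:
    """Hypothetical stand-in for ``get_targets`` (single-label case).

    Assumptions: class indices live under ``target_keys[0]``; the loss gets
    the raw indices, the metric and log entries get one-hot vectors.
    """
    indices = list(batch[target_keys[0]])
    # Expand each class index into a one-hot row of length num_classes.
    one_hot = [[1 if c == idx else 0 for c in range(num_classes)] for idx in indices]
    return {"loss": indices, "metric": one_hot, "log": one_hot}


batch = {"image": [[1.0, 0.0], [0.5, 0.5]], "target": [0, 2]}
targets = get_targets_sketch(batch, target_keys=["target"], num_classes=3)
print(targets["loss"])    # [0, 2]
print(targets["metric"])  # [[1, 0, 0], [0, 0, 1]]
```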

loss_step(outputs: Tensor, targets: Tensor, optimizer_idx: int = 0) → Dict

Method for evaluating the loss.

outputs

Outputs from the model.

Type:

torch.Tensor

targets

Targets from the batch.

Type:

torch.Tensor

optimizer_idx

Index of optimizer

Type:

int

Returns:

Dict with keys: [“loss”, “log”]

Return type:

Dict
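The shape of the value returned by loss_step can be illustrated with a plain-Python cross-entropy. This is a sketch under stated assumptions: when a list of losses is configured, optimizer_idx is assumed to select the loss that belongs to the current optimizer, and the nested "log" dict layout is hypothetical. The names loss_step_sketch and cross_entropy are illustrative, not easypl functions.

```python
import math
from typing import Callable, Dict, List, Union

Logits = List[List[float]]
LossFn = Callable[[Logits, List[int]], float]


def cross_entropy(outputs: Logits, targets: List[int]) -> float:
    """Mean cross-entropy over a batch of logits and class indices."""
    total = 0.0
    for logits, target in zip(outputs, targets):
        # log-sum-exp with the max subtracted for numerical stability
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_norm - logits[target]
    return total / len(outputs)


def loss_step_sketch(
    outputs: Logits,
    targets: List[int],
    loss: Union[LossFn, List[LossFn]],
    optimizer_idx: int = 0,
) -> Dict[str, object]:
    """Hypothetical stand-in for ``loss_step``.

    Assumption: with a list of losses, ``optimizer_idx`` picks the loss
    paired with the current optimizer.
    """
    loss_fn = loss[optimizer_idx] if isinstance(loss, list) else loss
    value = loss_fn(outputs, targets)
    return {"loss": value, "log": {"loss": value}}


result = loss_step_sketch([[0.0, 0.0]], [1], cross_entropy)
print(round(result["loss"], 4))  # 0.6931
```

Two equal logits give a uniform distribution over two classes, so the loss is ln 2 regardless of which class is the target.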