GANLearner
- class easypl.learners.gan.GANLearner(model: Optional[List[Module]] = None, loss: Optional[List[Module]] = None, optimizer: Optional[List[WrapperOptimizer]] = None, lr_scheduler: Optional[Union[WrapperScheduler, List[WrapperScheduler]]] = None, train_metrics: Optional[List[Metric]] = None, val_metrics: Optional[List[Metric]] = None, test_metrics: Optional[List[Metric]] = None, data_keys: Optional[List[str]] = None, target_keys: Optional[List[str]] = None)
Simple example of a learner for generative adversarial networks (GANs).
- model
Generator and discriminator modules.
- Type:
Optional[List[torch.nn.Module]]
- loss
Loss functions, given as torch.nn.Module instances.
- Type:
Optional[List[torch.nn.Module]]
- optimizer
Optimizer wrapper objects.
- Type:
Optional[List[WrapperOptimizer]]
- lr_scheduler
Scheduler wrapper object (or list of objects) for learning-rate scheduling.
- Type:
Optional[Union[WrapperScheduler, List[WrapperScheduler]]]
- train_metrics
List of training metrics.
- Type:
Optional[List[Metric]]
- val_metrics
List of validation metrics.
- Type:
Optional[List[Metric]]
- test_metrics
List of test metrics.
- Type:
Optional[List[Metric]]
- data_keys
List of data keys in the batch.
- Type:
Optional[List[str]]
- target_keys
List of target keys in the batch.
- Type:
Optional[List[str]]
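The constructor takes parallel lists for the networks, losses, and optimizers. The sketch below shows one way such lists might be assembled with plain PyTorch. It assumes that index 0 holds the generator and index 1 the discriminator; that ordering, the toy architectures, the constants, and the helper name `build_gan_components` are hypothetical, and wrapping the optimizers in `WrapperOptimizer` objects is not shown.

```python
import torch
from torch import nn

LATENT_DIM = 16       # hypothetical size of the generator's noise input
IMAGE_DIM = 28 * 28   # hypothetical flattened image size


def build_gan_components():
    """Build hypothetical model/loss/optimizer lists (index 0: generator, 1: discriminator)."""
    generator = nn.Sequential(
        nn.Linear(LATENT_DIM, 64),
        nn.ReLU(),
        nn.Linear(64, IMAGE_DIM),
        nn.Tanh(),  # outputs in [-1, 1], like normalized images
    )
    discriminator = nn.Sequential(
        nn.Linear(IMAGE_DIM, 64),
        nn.LeakyReLU(0.2),
        nn.Linear(64, 1),  # raw logit: real vs. fake
    )
    models = [generator, discriminator]
    # One loss per network; BCEWithLogitsLoss works directly on raw logits.
    losses = [nn.BCEWithLogitsLoss(), nn.BCEWithLogitsLoss()]
    # Plain torch optimizers; with easypl these would be wrapped in WrapperOptimizer.
    optimizers = [
        torch.optim.Adam(generator.parameters(), lr=2e-4),
        torch.optim.Adam(discriminator.parameters(), lr=2e-4),
    ]
    return models, losses, optimizers


models, losses, optimizers = build_gan_components()
```

Keeping the three lists index-aligned means a single optimizer index is enough to pick the matching network, loss, and optimizer in each step.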
- forward(samples: Tensor) → Tensor
Standard method for running the model's forward pass.
- samples
Image tensor.
- Type:
torch.Tensor
- Returns:
Output from model.
- Return type:
torch.Tensor
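The description above does not state which network `forward` runs. A minimal sketch, assuming it delegates to the generator stored at index 0 of `model` and that the generator maps images to images (as in image-to-image GANs); the class name `GeneratorForward` and the toy convolutional layers are hypothetical.

```python
import torch
from torch import nn


class GeneratorForward(nn.Module):
    """Hypothetical sketch: forward() delegates to the generator at index 0."""

    def __init__(self, model):
        super().__init__()
        # ModuleList registers both networks so their parameters are tracked.
        self.model = nn.ModuleList(model)

    def forward(self, samples: torch.Tensor) -> torch.Tensor:
        # Only the generator is used here; the discriminator matters during training.
        return self.model[0](samples)


# Toy image-to-image networks: 3-channel images in, 3-channel images out.
generator = nn.Conv2d(3, 3, kernel_size=3, padding=1)
discriminator = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)
learner = GeneratorForward([generator, discriminator])

output = learner(torch.randn(2, 3, 32, 32))
```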
- get_outputs(batch: Dict, optimizer_idx: int = 0) → Dict
Abstract method for selecting and preprocessing outputs from the batch.
- batch
Batch for the current step.
- Type:
Dict
- optimizer_idx
Index of the optimizer used in the current step.
- Type:
int
- Returns:
Dict with keys: [“loss”, “metric”, “log”]
- Return type:
Dict
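In a GAN, training alternates between the generator and the discriminator, so `optimizer_idx` tells the learner which loss to compute. The following sketch illustrates that pattern under stated assumptions: index 0 is the generator step, index 1 the discriminator step, real samples sit under a hypothetical `"image"` batch key, and the contents of the `"metric"` and `"log"` entries are invented for illustration. The class `TinyGAN` is hypothetical and is not easypl's implementation.

```python
import torch
from torch import nn

LATENT_DIM = 8   # hypothetical noise size
IMAGE_DIM = 16   # hypothetical flattened sample size


class TinyGAN:
    """Hypothetical sketch of a GAN-style get_outputs; not easypl's code."""

    def __init__(self, generator, discriminator, data_key="image"):
        self.generator = generator
        self.discriminator = discriminator
        self.criterion = nn.BCEWithLogitsLoss()
        self.data_key = data_key  # assumed batch key holding real samples

    def get_outputs(self, batch, optimizer_idx=0):
        real = batch[self.data_key]
        n = real.size(0)
        fake = self.generator(torch.randn(n, LATENT_DIM))
        ones, zeros = torch.ones(n, 1), torch.zeros(n, 1)
        if optimizer_idx == 0:
            # Generator step: push the discriminator to label fakes as real.
            logits = self.discriminator(fake)
            targets = ones
        else:
            # Discriminator step: detach fakes so no gradient reaches the generator.
            logits = torch.cat([self.discriminator(real),
                                self.discriminator(fake.detach())])
            targets = torch.cat([ones, zeros])
        loss = self.criterion(logits, targets)
        return {
            "loss": loss,
            # Hypothetical layout: probabilities and targets for metric updates.
            "metric": {"output": torch.sigmoid(logits).detach(), "target": targets},
            "log": {f"loss_{optimizer_idx}": loss.detach()},
        }


gan = TinyGAN(nn.Linear(LATENT_DIM, IMAGE_DIM), nn.Linear(IMAGE_DIM, 1))
batch = {"image": torch.randn(4, IMAGE_DIM)}
out_g = gan.get_outputs(batch, optimizer_idx=0)
out_d = gan.get_outputs(batch, optimizer_idx=1)
```

Detaching the fake samples in the discriminator step stops that loss from producing gradients for the generator's parameters.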