pymor.algorithms.ml.nn.train
Module Contents
- pymor.algorithms.ml.nn.train.multiple_restarts_training(training_data, validation_data, neural_network, target_loss=None, max_restarts=10, log_loss_frequency=0, training_parameters={})
Algorithm that performs multiple restarts of neural network training.
This method either performs a predefined number of restarts and returns the best trained network or tries to reach a given target loss and stops training when the target loss is reached.
See train_neural_network for more information on the parameters.
- Parameters:
training_data – Data to use during the training phase.
validation_data – Data to use during the validation phase.
neural_network – The neural network to train (parameters will be reset after each restart).
target_loss – Loss to reach during training (if None, the network with the smallest loss is returned).
max_restarts – Maximum number of restarts to perform.
log_loss_frequency – Interval, in epochs, at which the current validation and training losses are logged. If 0, no intermediate logging of losses is done.
training_parameters – Additional parameters for the training algorithm; see train_neural_network for more information.
- Returns:
best_neural_network – The best trained neural network.
losses – The corresponding losses.
- Raises:
NeuralNetworkTrainingError – Raised if the prescribed target loss cannot be reached within the given number of restarts.
- pymor.algorithms.ml.nn.train.train_neural_network(training_data, validation_data, neural_network, training_parameters={}, log_loss_frequency=0)
Training algorithm for artificial neural networks.
Trains a single neural network using the given training and validation data.
- Parameters:
training_data – Data to use during the training phase. Has to be a list of tuples, where each tuple consists of two elements that are either PyTorch tensors (torch.DoubleTensor), NumPy arrays, or pyMOR data structures that have to_numpy() implemented. The first element contains the input data, the second element contains the target values.
validation_data – Data to use during the validation phase. Has to be a list of tuples, where each tuple consists of two elements that are either PyTorch tensors (torch.DoubleTensor), NumPy arrays, or pyMOR data structures that have to_numpy() implemented. The first element contains the input data, the second element contains the target values.
neural_network – The neural network to train (can also be a pre-trained model). Has to be a PyTorch Module.
training_parameters – Dictionary with additional parameters for the training routine, such as the type of the optimizer, the (maximum) number of epochs, the batch size, the learning rate, or the loss function to use. Possible keys are:
  'optimizer': an optimizer from the PyTorch optim package; if not provided, the LBFGS optimizer is taken as default.
  'epochs': an integer that determines the number of epochs to use for training the neural network (if training is not interrupted prematurely due to early stopping); if not provided, 1000 is taken as default value.
  'batch_size': an integer that determines the number of samples to pass to the optimizer at once; if not provided, 20 is taken as default value. Not used in the case of the LBFGS optimizer, since LBFGS does not support mini-batching.
  'learning_rate': a positive real number used as the (initial) step size of the optimizer; if not provided, 1 is taken as default value.
  'loss_function': a loss function from PyTorch; if not provided, the MSE loss is taken as default.
  'lr_scheduler_config': a dictionary containing the keys 'scheduler' (a learning rate scheduler from the PyTorch optim.lr_scheduler package), 'interval' ('epoch' or 'batch', clarifying whether the scheduler steps epochwise or batchwise), and 'params' (a dictionary of additional parameters for the learning rate scheduler); if not provided or None, no learning rate scheduler is used.
  'es_scheduler_params': a dictionary of additional parameters for the early stopping scheduler.
  'weight_decay': a non-negative real number that determines the strength of the l2-regularization; if not provided or 0., no regularization is applied.
log_loss_frequency – Frequency of epochs in which to log the current validation and training loss. If 0, no intermediate logging of losses is done.
- Returns:
best_neural_network – The best trained neural network with respect to validation loss.
losses – The corresponding losses as a dictionary with keys 'full' (for the full loss containing the training and the validation average loss), 'train' (for the average loss on the training parameters), and 'val' (for the average loss on the validation parameters).
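As a rough illustration of the training_parameters dictionary, the following sketch sets only plain-valued keys and leaves the optimizer and loss function at their documented defaults (LBFGS and MSE). The numeric values are arbitrary assumptions, `DOCUMENTED_KEYS` is a hypothetical helper set built from the key list above, and the call to train_neural_network is left as a comment because it requires pyMOR, PyTorch, and prepared data.

```python
# Keys listed in the parameter description above (helper set for this sketch,
# not something pyMOR provides).
DOCUMENTED_KEYS = {
    'optimizer', 'epochs', 'batch_size', 'learning_rate', 'loss_function',
    'lr_scheduler_config', 'es_scheduler_params', 'weight_decay',
}

# Illustrative values only; 'optimizer' and 'loss_function' are omitted so the
# documented defaults (LBFGS, MSE loss) apply. 'batch_size' is omitted as well,
# since it is documented as unused with LBFGS.
training_parameters = {
    'epochs': 500,          # documented default: 1000
    'learning_rate': 1.0,   # documented default: 1
    'weight_decay': 1e-4,   # documented default: no regularization
}

# Catch misspelled keys before starting a long training run.
unknown_keys = set(training_parameters) - DOCUMENTED_KEYS
if unknown_keys:
    raise KeyError(f'undocumented training parameters: {sorted(unknown_keys)}')

# Intended call, assuming `training_data`, `validation_data`, and
# `neural_network` have been prepared as described above:
#
# best_neural_network, losses = train_neural_network(
#     training_data, validation_data, neural_network,
#     training_parameters=training_parameters, log_loss_frequency=10)
# print(losses['train'], losses['val'], losses['full'])
```

Checking the keys up front is a design choice of this sketch; whether the library itself rejects unknown keys is not stated in the text above.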