Create a new callback

Usage

luz_callback(
  name = NULL,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = NULL
)

Arguments

name

name of the callback

...

Public methods of the callback. The name of each method determines when it is called. See the Details section.

private

An optional list of private members, which can be functions and non-functions.

active

An optional list of active binding functions.

parent_env

An environment to use as the parent of newly created objects.

inherit

An R6ClassGenerator object to inherit from; in other words, a superclass. This is captured as an unevaluated expression which is evaluated in parent_env each time an object is instantiated.

Value

A luz_callback that can be passed to fit.luz_module_generator().

Details

Let’s implement a callback that prints ‘Iteration n’ (where n is the iteration number) for every batch in the training set and ‘Done’ when an epoch is finished. For that task we use the luz_callback function:

print_callback <- luz_callback(
  name = "print_callback",
  initialize = function(message) {
    self$message <- message
  },
  on_train_batch_end = function() {
    cat("Iteration", ctx$iter, "\n")
  },
  on_epoch_end = function() {
    cat(self$message, "\n")
  }
)

luz_callback() takes named functions as ... arguments, where the name indicates the moment at which the callback should be called. For instance, on_train_batch_end() is called at the end of every training batch, and on_epoch_end() is called at the end of every epoch.

The returned value of luz_callback() is a function that initializes an instance of the callback. Callbacks can have initialization parameters, like the name of a file where you want to log the results. In that case, you can pass an initialize method when creating the callback definition, and save these parameters to the self object. In the above example, the callback has a message parameter that is printed at the end of each epoch.
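As a minimal sketch of the file-logging case mentioned above, the following callback stores a file path in initialize and appends one line to that file at the end of each epoch. The name log_callback and its path parameter are hypothetical and chosen only for illustration; the sketch assumes ctx$epoch holds the current epoch number (see help("ctx")).

```r
library(luz)

# Hypothetical callback: appends one line per epoch to a user-chosen file.
log_callback <- luz_callback(
  name = "log_callback",
  initialize = function(path) {
    # Save the initialization parameter on self so other methods can use it.
    self$path <- path
  },
  on_epoch_end = function() {
    # Assumes ctx$epoch is the current epoch number.
    cat("Finished epoch", ctx$epoch, "\n", file = self$path, append = TRUE)
  }
)

# Calling the generator creates a callback instance with its parameter set.
cb <- log_callback(path = tempfile(fileext = ".log"))
```

Keeping the path on self rather than in a global variable lets several instances of the same callback write to different files.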

Once a callback is defined it can be passed to the fit function via the callbacks parameter:

fitted <- net %>%
  setup(...) %>%
  fit(..., callbacks = list(
    print_callback(message = "Done!")
  ))

Callbacks can be called in many different positions of the training loop, including combinations of them. Here’s an overview of possible callback breakpoints:

Start Fit
   - on_fit_begin
  Start Epoch Loop
     - on_epoch_begin
    Start Train
       - on_train_begin
      Start Batch Loop
         - on_train_batch_begin
          Start Default Training Step
            - on_train_batch_after_pred
            - on_train_batch_after_loss
            - on_train_batch_before_backward
            - on_train_batch_before_step
            - on_train_batch_after_step
          End Default Training Step
         - on_train_batch_end
      End Batch Loop
       - on_train_end
    End Train
    Start Valid
       - on_valid_begin
      Start Batch Loop
         - on_valid_batch_begin
          Start Default Validation Step
            - on_valid_batch_after_pred
            - on_valid_batch_after_loss
          End Default Validation Step
         - on_valid_batch_end
      End Batch Loop
       - on_valid_end
    End Valid
     - on_epoch_end
  End Epoch Loop
   - on_fit_end
End Fit

Every step marked with on_* is a point in the training procedure at which callbacks can be called.

The other important part of callbacks is the ctx (context) object. See help("ctx") for details.

By default, callbacks are called in the order in which they were passed to fit (or predict or evaluate), but you can provide a weight attribute to control that order: callbacks with a lower weight are called first. For example, if one callback has weight = 10 and another has weight = 1, the first is called after the second. Callbacks that don’t specify a weight attribute are treated as having weight = 0. A few built-in callbacks in luz already provide a weight value. For example, luz_callback_early_stopping() has a weight of Inf, since in general we want it to run last in the loop.
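The ordering rule can be sketched with two callbacks. This example assumes that weight is declared as a plain field next to the methods in the luz_callback() call; the names early_cb and late_cb are hypothetical.

```r
library(luz)

# Hypothetical callbacks used only to illustrate ordering by weight.
# Assumption: `weight` is supplied as a field alongside the methods.
early_cb <- luz_callback(
  name = "early_cb",
  weight = 1,
  on_epoch_end = function() {
    cat("lower weight: expected to run first\n")
  }
)

late_cb <- luz_callback(
  name = "late_cb",
  weight = 10,
  on_epoch_end = function() {
    cat("higher weight: expected to run afterwards\n")
  }
)

# Under the rule described above, late_cb would run after early_cb at each
# epoch end, even though it is listed first:
# fit(..., callbacks = list(late_cb(), early_cb()))
```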

Prediction callbacks

You can also use callbacks when using predict(). In this case, the supported callback methods are listed below.

Start predict
 - on_predict_begin
 Start prediction loop
  - on_predict_batch_begin
  - on_predict_batch_end
 End prediction loop
 - on_predict_end
End predict
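A prediction callback is defined in the same way as a training callback, using the breakpoints listed above. The sketch below only prints progress messages; predict_progress is a hypothetical name, and the commented usage line assumes that predict() accepts a callbacks argument in the same way as fit(), with fitted and new_data as placeholders.

```r
library(luz)

# Hypothetical callback that reports progress during prediction.
predict_progress <- luz_callback(
  name = "predict_progress",
  on_predict_begin = function() {
    cat("Starting prediction\n")
  },
  on_predict_batch_end = function() {
    cat("Finished a prediction batch\n")
  },
  on_predict_end = function() {
    cat("Prediction finished\n")
  }
)

# Assumed usage; `fitted` and `new_data` are placeholders:
# preds <- predict(fitted, new_data, callbacks = list(predict_progress()))
```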

Evaluate callbacks

Callbacks can also be used with evaluate(). In this case, the callbacks used are equivalent to those of the validation loop when using fit():

Start Valid
 - on_valid_begin
 Start Batch Loop
  - on_valid_batch_begin
  Start Default Validation Step
   - on_valid_batch_after_pred
   - on_valid_batch_after_loss
  End Default Validation Step
  - on_valid_batch_end
 End Batch Loop
 - on_valid_end
End Valid

Examples

library(luz)

# Prints the iteration number after each training batch
# and "Done!" at the end of each epoch.
print_callback <- luz_callback(
 name = "print_callback",
 on_train_batch_end = function() {
   cat("Iteration", ctx$iter, "\n")
 },
 on_epoch_end = function() {
   cat("Done!\n")
 }
)