Passer au contenu

Crée une nouvelle métrique luz

Utilisation

luz_metric(
  name = NULL,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = NULL
)

Arguments

name

string naming the new metric.

...

named list of public methods. You should implement at least initialize, update and compute. See the details section for more information.

private

An optional list of private members, which can be functions and non-functions.

active

An optional list of active binding functions.

parent_env

An environment to use as the parent of newly-created objects.

inherit

A R6ClassGenerator object to inherit from; in other words, a superclass. This is captured as an unevaluated expression which is evaluated in parent_env each time an object is instantiated.

Valeur de retour

Renvoie la nouvelle métrique luz.

Détails

Pour implémenter un nouveau luz_metric, il faut implémenter 3 méthodes:

  • initialize: définit l'état initial de la métrique. Cette fonction est appelée pour chaque époque dans les boucles d'entraînement et de validation.

  • update: met à jour l'état interne de la métrique. Cette fonction est appelée à chaque étape d'entraînement et de validation avec les prédictions obtenues par le modèle et les valeurs cibles obtenues à partir du chargeur de donnée.

  • compute: utilise l'état interne pour calculer les valeurs du métrique. Cette fonction est appelée chaque fois que nous devons obtenir la valeur actuelle du métrique. Eg, elle est appelée chaque étape d'entraînement pour les métriques affichées dans la bare de progression, mais uniquement appelée une fois par époque pour enregistrer sa valeur lorsque la bare de progression n'est pas affichée.

Optionnellement, vous pouvez implémenter un champ abbrev qui donne à la métrique un abrégé que l'on utilisera lors de l' affichage d'informations sur les métriques dans le console ou enregistrer. Si aucun abbrev n'est passé, le nom de classe sera utilisé.

Voyons comment implémenter luz_metric_accuracy pour voir comment implémenter une nouvelle métrique:

luz_metric_accuracy <- luz_metric(
 # Un résumé à afficher dans les barres de progression, ou n# lorsque l'on affiche la progression
  abbrev = "Acc",
  # Configuration initiale pour le métrique. Les métriques sont initialisées
  # à chaque époque, pour les boucles d'entraînement et de validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
 # Exécuter à chaque étape d'entraînement ou de validation et mettre à jour
 # l'état interne. La fonction update prend `preds` et `target` en paramètres.
 update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
 self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
 sum()$
      item()
       self$total <- self$total + pred$numel()
  },
 # Utiliser l'état interne pour interroger la valeur du métrique
  compute = function() {
 self$correct/self$total
  }
)

strongNote : Il est recommandé que le métrique compute renvoie des valeurs régulières R au lieu de tenseurs torch car c'est ce qui est attendu par les autres parties de luz.

Exemples

luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or
  # when printing progress
  abbrev = "Acc",
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Run at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct/self$total
  }
)
#> Error in luz_metric(abbrev = "Acc", initialize = function() {    self$correct <- 0    self$total <- 0}, update = function(preds, target) {    pred <- torch::torch_argmax(preds, dim = 2)    self$correct <- self$correct + (pred == target)$to(dtype = torch::torch_float())$sum()$item()    self$total <- self$total + pred$numel()}, compute = function() {    self$correct/self$total}): impossible de trouver la fonction "luz_metric"