Crée une nouvelle métrique luz 🌐
Creates a new luz metric
🌐
Creates a new luz metriczz_metric.Rd
Crée une nouvelle métrique luz
Utilisation
luz_metric(
name = NULL,
...,
private = NULL,
active = NULL,
parent_env = parent.frame(),
inherit = NULL
)
Arguments
- name
string naming the new metric.
- ...
named list of public methods. You should implement at least
initialize
,update
andcompute
. See the details section for more information.- private
An optional list of private members, which can be functions and non-functions.
- active
An optional list of active binding functions.
- parent_env
An environment to use as the parent of newly-created objects.
- inherit
A R6ClassGenerator object to inherit from; in other words, a superclass. This is captured as an unevaluated expression which is evaluated in
parent_env
each time an object is instantiated.
Détails
Pour implémenter un nouveau luz_metric
, il faut implémenter 3 méthodes:
initialize
: définit l'état initial de la métrique. Cette fonction est appelée pour chaque époque dans les boucles d'entraînement et de validation.update
: met à jour l'état interne de la métrique. Cette fonction est appelée à chaque étape d'entraînement et de validation avec les prédictions obtenues par le modèle et les valeurs cibles obtenues à partir du chargeur de donnée.compute
: utilise l'état interne pour calculer les valeurs du métrique. Cette fonction est appelée chaque fois que nous devons obtenir la valeur actuelle du métrique. Eg, elle est appelée chaque étape d'entraînement pour les métriques affichées dans la bare de progression, mais uniquement appelée une fois par époque pour enregistrer sa valeur lorsque la bare de progression n'est pas affichée.
Optionnellement, vous pouvez implémenter un champ abbrev
qui donne à la métrique un abrégé que l'on utilisera lors de l'
affichage d'informations sur les métriques dans le console ou enregistrer.
Si aucun abbrev
n'est passé, le nom de classe sera utilisé.
Voyons comment implémenter luz_metric_accuracy
pour voir comment implémenter une nouvelle métrique:
luz_metric_accuracy <- luz_metric(
# Un résumé à afficher dans les barres de progression, ou n# lorsque l'on affiche la progression
abbrev = "Acc",
# Configuration initiale pour le métrique. Les métriques sont initialisées
# à chaque époque, pour les boucles d'entraînement et de validation
initialize = function() {
self$correct <- 0
self$total <- 0
},
# Exécuter à chaque étape d'entraînement ou de validation et mettre à jour
# l'état interne. La fonction update prend `preds` et `target` en paramètres.
update = function(preds, target) {
pred <- torch::torch_argmax(preds, dim = 2)
self$correct <- self$correct + (pred == target)$
to(dtype = torch::torch_float())$
sum()$
item()
self$total <- self$total + pred$numel()
},
# Utiliser l'état interne pour interroger la valeur du métrique
compute = function() {
self$correct/self$total
}
)
strongNote : Il est recommandé que le métrique compute
renvoie des valeurs régulières R au lieu de tenseurs torch car c'est ce qui est attendu par les autres parties de luz.
Voir également
Autres métriques de luz:
luz_metric_accuracy()
,
luz_metric_binary_accuracy_with_logits()
,
luz_metric_binary_accuracy()
,
luz_metric_binary_auroc()
,
luz_metric_mae()
,
luz_metric_mse()
,
luz_metric_multiclass_auroc()
,
luz_metric_rmse()
Exemples
luz_metric_accuracy <- luz_metric(
# An abbreviation to be shown in progress bars, or
# when printing progress
abbrev = "Acc",
# Initial setup for the metric. Metrics are initialized
# every epoch, for both training and validation
initialize = function() {
self$correct <- 0
self$total <- 0
},
# Run at every training or validation step and updates
# the internal state. The update function takes `preds`
# and `target` as parameters.
update = function(preds, target) {
pred <- torch::torch_argmax(preds, dim = 2)
self$correct <- self$correct + (pred == target)$
to(dtype = torch::torch_float())$
sum()$
item()
self$total <- self$total + pred$numel()
},
# Use the internal state to query the metric value
compute = function() {
self$correct/self$total
}
)
#> Error in luz_metric(abbrev = "Acc", initialize = function() { self$correct <- 0 self$total <- 0}, update = function(preds, target) { pred <- torch::torch_argmax(preds, dim = 2) self$correct <- self$correct + (pred == target)$to(dtype = torch::torch_float())$sum()$item() self$total <- self$total + pred$numel()}, compute = function() { self$correct/self$total}): impossible de trouver la fonction "luz_metric"