Passer au contenu

La fonction d'initialiation, utilisée pour définir les attributs et les méthodes importantes pour que le nn_modules fonctionne avec luz.

Utilisation

setup(module, loss = NULL, optimizer = NULL, metrics = NULL, backward = NULL)

Arguments

module

(nn_module) Le nn_module à utiliser.

loss

(fonction, optionnel) Une fonction avec la signature function(input, target). Elle est uniquement requise si votre nn_module n'implémente pas une fonction de coût nomée loss.

optimizer

(optimiseur torch, optionnel) Une fonction avec la signature function(parameters, ...) qui est utilisée pour initialiser un optimiseur à partir des paramètres du modèle.

metrics

(liste, optionnel) Une liste de métriques à suivre pendant la procédure d'entraînement. Si vous voulez que des métriques soient évaluées uniquement pendant l'entraînement ou la validation, vous pouvez passer un objet luz_metric_set() pour spécifier les métriques utilisées à chaque étape.

backward

(fonction) Une fonction qui prend des valeurs retournées par la fonction de coût comme paramètres. Elle doit appeler $backward() ou torch::autograd_backward(). En général, vous n'avez pas besoin de définir ce paramètre sauf si vous devez personnaliser comment luz appelle la méthode backward(), par exemple, si vous devez ajouter des arguments supplémentaires à l'appel de la méthode. Notez que cela devient une méthode du nn_module, donc elle peut être utilisée par votre step() personnalisé si vous le redéfinissez.

Valeur de retour

Un module luz qui peut être entraîné avec fit().

Détails

Elle s'assure que le module ait tous les ingrédients nécessaires pour être entraîné.

Note

Elle ajoute également un champ de device actif qui peut être utilisé pour interroger le device courant du module dans les méthodes, comme par exemple self$device. Cela est utile lorsque ctx() n'est pas disponible, par exemple, lorsque vous appelez des méthodes en dehors de luz. Les utilisateurs peuvent personnaliser la valeur par défaut en implémentant une méthode active de device dans le module d'entrée.

Voir également