Initialise l'usage d'un nn_module
avec luz 🌐
Set's up a nn_module
to use with luz
🌐
Set's up ann_module
to use with luzsetup.Rd
La fonction d'initialiation, utilisée pour définir les attributs et les méthodes
importantes pour que le nn_modules
fonctionne avec luz.
Arguments
- module
(
nn_module
) Lenn_module
à utiliser.- loss
(fonction, optionnel) Une fonction avec la signature
function(input, target)
. Elle est uniquement requise si votrenn_module
n'implémente pas une fonction de coût noméeloss
.- optimizer
(optimiseur torch, optionnel) Une fonction avec la signature
function(parameters, ...)
qui est utilisée pour initialiser un optimiseur à partir des paramètres du modèle.- metrics
(liste, optionnel) Une liste de métriques à suivre pendant la procédure d'entraînement. Si vous voulez que des métriques soient évaluées uniquement pendant l'entraînement ou la validation, vous pouvez passer un objet luz_metric_set() pour spécifier les métriques utilisées à chaque étape.
- backward
(fonction) Une fonction qui prend des valeurs retournées par la fonction de coût comme paramètres. Elle doit appeler
$backward()
outorch::autograd_backward()
. En général, vous n'avez pas besoin de définir ce paramètre sauf si vous devez personnaliser comment luz appelle la méthodebackward()
, par exemple, si vous devez ajouter des arguments supplémentaires à l'appel de la méthode. Notez que cela devient une méthode dunn_module
, donc elle peut être utilisée par votrestep()
personnalisé si vous le redéfinissez.
Note
Elle ajoute également un champ de device
actif qui peut être utilisé pour
interroger le device
courant du module dans les méthodes, comme par exemple
self$device
. Cela est utile lorsque ctx()
n'est pas disponible,
par exemple, lorsque vous appelez des méthodes en dehors de luz
. Les utilisateurs
peuvent personnaliser la valeur par défaut en implémentant une méthode active de device
dans le module d'entrée.
Voir également
Autres entraînements:
evaluate()
,
fit.luz_module_generator()
,
predict.luz_module_fitted()