Entraîne un nn_module
🌐
Fit a nn_module
🌐
Fit ann_module
fit.luz_module_generator.Rd
Entraîne un nn_module
Utilisation
# Méthode S3 pour la classe luz_module_generator
fit(
object,
data,
epochs = 10,
callbacks = NULL,
valid_data = NULL,
accelerator = NULL,
verbose = NULL,
...,
dataloader_options = NULL
)
Arguments
- object
Un
nn_module
qui est passé par la commandesetup()
.- data
(dataloader, dataset ou list) Un chargeur de donnée créé avec
torch::dataloader()
utilisé pour l'entraînement du modèle, ou un jeu de données créé avectorch::dataset()
ou une liste. Les chargeurs de donnée et les jeux de données doivent renvoyer une liste contenant au plus 2 éléments. Le premier élément sera utilisé comme entrée pour le module et le second comme variable à prédire pour la fonction de perte.- epochs
(entier) Le nombre maximal d'époques d'entraînement du modèle. Si une valeur unique est fournie, elle sera prise comme
max_epochs
etmin_epochs
sera à 0. Si un vecteur de deux nombres est fourni, la première valeur seramin_epochs
et la seconde valeur seramax_epochs
. Le nombre minimum et maximum d'époques sont inclus dans l'objet context sous forme dectx$min_epochs
etctx$max_epochs
, respectivement.- callbacks
(optionnel) Une liste de callbacks définis avec
luz_callback()
qui seront appelés pendant la procédure d'entraînement. Les callbacksluz_callback_metrics()
,luz_callback_progress()
etluz_callback_train_valid()
sont toujours ajoutés par défaut.- valid_data
(dataloader, dataset, liste ou réel; optionnel) Un chargeur de donnée créé avec
torch::dataloader()
ou un jeu de données créé avectorch::dataset()
qui sera utilisé pendant la procédure de validation. Ils doivent retourner une liste contenant (input, target). Sidata
est untorch::dataset()
ou une liste, vous pouvez également fournir une valeur numérique entre 0 et 1 - et dans ce cas, une échantillonnage aléatoire avec taille correspondante à celle proportionnelle à partir dedata
sera utilisée pour la validation.- accelerator
(accelerator, optional) Un objet
accelerator()
optionnel utilisé pour configurer le device cible du calcul pour des composants tels que les modules nn, les optimiseurs et batches de données.- verbose
(booléen, optionnel) La procédure d'entraînement doit-elle produire ses messages vers la console pendant l'entraînement. Par défaut, elle produira des message si
interactive()
estTRUE
, sinon elle ne publiera pas vers la console.- ...
Inutilisé
- dataloader_options
Options utilisées lors de la création d'un chargeur de donnée. Voir
torch::dataloader()
.shuffle=TRUE
par défaut pour les données d'entraînement etbatch_size=32
par défaut. Il y aura erreur si ce n'est pasNULL
quanddata
est déjà un chargeur de donnée.
Valeur de retour
Un objet entraîné qui peut être enregistré avec luz_save()
,
affiché avec print()
et visualisé avec plot()
.
Voir également
predict.luz_module_fitted()
pour savoir comment créer des prédictions. setup()
pour trouver comment
créer des modules qui peuvent être entraînés avec fit
.
Autres entraînement
evaluate()
, predict.luz_module_fitted()
, setup()