Passer au contenu

Entraîne un nn_module

Utilisation

# Méthode S3 pour la classe luz_module_generator
fit(
  object,
  data,
  epochs = 10,
  callbacks = NULL,
  valid_data = NULL,
  accelerator = NULL,
  verbose = NULL,
  ...,
  dataloader_options = NULL
)

Arguments

object

Un nn_module qui est passé par la commande setup().

data

(dataloader, dataset ou list) Un chargeur de donnée créé avec torch::dataloader() utilisé pour l'entraînement du modèle, ou un jeu de données créé avec torch::dataset() ou une liste. Les chargeurs de donnée et les jeux de données doivent renvoyer une liste contenant au plus 2 éléments. Le premier élément sera utilisé comme entrée pour le module et le second comme variable à prédire pour la fonction de perte.

epochs

(entier) Le nombre maximal d'époques d'entraînement du modèle. Si une valeur unique est fournie, elle sera prise comme max_epochs et min_epochs sera à 0. Si un vecteur de deux nombres est fourni, la première valeur sera min_epochs et la seconde valeur sera max_epochs. Le nombre minimum et maximum d'époques sont inclus dans l'objet context sous forme de ctx$min_epochs et ctx$max_epochs, respectivement.

callbacks

(optionnel) Une liste de callbacks définis avec luz_callback() qui seront appelés pendant la procédure d'entraînement. Les callbacks luz_callback_metrics(), luz_callback_progress() et luz_callback_train_valid() sont toujours ajoutés par défaut.

valid_data

(dataloader, dataset, liste ou réel; optionnel) Un chargeur de donnée créé avec torch::dataloader() ou un jeu de données créé avec torch::dataset() qui sera utilisé pendant la procédure de validation. Ils doivent retourner une liste contenant (input, target). Si data est un torch::dataset() ou une liste, vous pouvez également fournir une valeur numérique entre 0 et 1 - et dans ce cas, une échantillonnage aléatoire avec taille correspondante à celle proportionnelle à partir de data sera utilisée pour la validation.

accelerator

(accelerator, optional) Un objet accelerator() optionnel utilisé pour configurer le device cible du calcul pour des composants tels que les modules nn, les optimiseurs et batches de données.

verbose

(booléen, optionnel) La procédure d'entraînement doit-elle produire ses messages vers la console pendant l'entraînement. Par défaut, elle produira des message si interactive() est TRUE, sinon elle ne publiera pas vers la console.

...

Inutilisé

dataloader_options

Options utilisées lors de la création d'un chargeur de donnée. Voir torch::dataloader(). shuffle=TRUE par défaut pour les données d'entraînement et batch_size=32 par défaut. Il y aura erreur si ce n'est pas NULL quand data est déjà un chargeur de donnée.

Valeur de retour

Un objet entraîné qui peut être enregistré avec luz_save(), affiché avec print() et visualisé avec plot().

Voir également

predict.luz_module_fitted() pour savoir comment créer des prédictions. setup() pour trouver comment créer des modules qui peuvent être entraînés avec fit. Autres entraînement evaluate(), predict.luz_module_fitted(), setup()