Passer au contenu

Évalue un modèle entraîné sur un jeu de données

Utilisation

evaluate(
  object,
  data,
  ...,
  metrics = NULL,
  callbacks = list(),
  accelerator = NULL,
  verbose = NULL,
  dataloader_options = NULL
)

Arguments

object

Un modèle entraîné.

data

(dataloader, dataset ou liste) Un chargeur de données créé avec torch::dataloader() sur lequel évaluer le modèle, ou un dataset créé avec torch::dataset() ou une liste. Les chargeur de donnéess et les datasets doivent retourner une liste avec au plus 2 items. Le premier item sera utilisé comme prédicteurs pour le module et le second sera utilisé comme variable à prédire pour la fonction de perte.

...

Inutilisé

metrics

Une liste de métriques de luz à appliquer pendant l'évaluation. Si NULL (par défaut) alors les mêmes métriques que les métriques d'entraînement sont évaluées.

callbacks

(facultatif) Une liste de callbacks définis avec luz_callback() qui seront appelés pendant la procédure d'entraînement. Les callbacks luz_callback_metrics(), luz_callback_progress() et luz_callback_train_valid() sont toujours ajoutés par défaut.

accelerator

Un accélérateur accelerator() à utiliser pour le calcul des objets tels que les modules nn, les optimiseurs et les batch de données.

verbose

(booléen) La procédure d'entraînement doit-elle produire des messages dans la console. Par défaut, elle produira des messages s'il y a une interface graphique (c'est-à-dire si interactive() est vrai), sinon elle ne produira pas de messages.

dataloader_options

Des options utilisées lors de la création d'un chargeur de données. Voir torch::dataloader(). Par défautshuffle=TRUE pour les données d'entraînement et batch_size=32. Il y aura une erreur si non NULL et si data est déja un chargeur de données.

Détails

Une fois que vous avez entraîné un modèle, vous pouvez évaluer sa performance sur un autre jeu de données. Pour cela, luz fournit la fonction evaluate. Cette dernière prend en argument un modèle entraîné et un jeu de données, puis calcule les métriques liées au modèle.

La fonction evaluate retourne un objet luz_module_evaluation, que vous pouvez consulter grâce à la fonction get_metrics() ou simplement print pour voir les résultats.

Par exemple:

evaluation <- fitted %>% evaluate(data = valid_dl)
metrics <- get_metrics(evaluation)
print(evaluation)

## A `luz_module_evaluation`
## -- Results ---------------------------------------------------------------------
## loss: 1.5146
## mae: 1.0251
## mse: 1.5159
## rmse: 1.2312

Voir également