Passer au contenu

Mise en œuvre de 'mixup: Beyond Empirical Risk Minimization'. Actuellement testé uniquement pour les données catégorielles, où les variables à prédire sont encodées comme des entiers, et pas encodées en binaire un-contre-tous. Ce callback doit être utilisé simultanément avec `nn_mixup_loss()`.

Utilisation

luz_callback_mixup(alpha = 0.4, ..., run_valid = FALSE, auto_loss = FALSE)

Arguments

alpha

paramètre de la distribution béta utilisée pour échantillonner les coefficients de mélange

...

actuellement non utilisé. Juste pour forcer des arguments nommés.

run_valid

Doit-il s'exécuter aussi pendant la validation ?

auto_loss

Doit-on modifier automatiquement la fonction de coût ? Cela créera une fonction de perte mixup basée sur la fonction de coût. Si TRUE, assurez-vous que votre fonction de coût n'applique pas de réduction. Si run_valid=FALSE, la fonction de coût sera réduite à sa moyenne pendant la validation.

Valeur de retour

Un callback `luz_callback`

Détails

Dans l'ensemble, nous suivons l'implémentation de fastai décrite ici. À savoir,

  • Nous travaillons avec un seul chargeur de donnée, en mélangeant deux observations aléatoirement à partir du même lot.

  • Nous combinons linéairement les pertes calculées pour les deux cibles : loss(output, new_target) = weight * loss(output, target1) + (1-weight) * loss(output, target2)

  • Nous tirons des coefficients de mélange différents pour chaque paire.

  • Nous remplaçons weight par weight = max(weight, 1-weight) pour éviter les répétitions.

if (torch::torch_is_installed()) mixup_callback <- luz_callback_mixup() nn_mixup_loss(), nnf_mixup()Autres luz_callbacks: luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid(), luz_callback() luz_callbacks