Passer au contenu

Ce callback activera l'entraînement du modèle torch::local_autocast() pendant la phase forward() et pendant le calcul de la fonction de coût. Il désactivera ensuite l'autocast et normalisera la fonction de coût avant la phase backward() et opt$step(). Pour en savoir plus, voir ici.

Utilisation

luz_callback_mixed_precision(...)

Arguments

...

Passé à torch::cuda_amp_grad_scaler().

Valeur de retour

Un callback de type luz_callback