Passer au contenu

Luz est une interface (API) de haut niveau pour torch, conçue pour être très flexible, et ainsi d’avoir un contrôle complêt sur la boucle d’entraînement.

Dans la vignette("get-started"), nous avons vu les bases de luz et comment modifier facilement des parties de la boucle d’entraînement à l’aide de rappels (callbacks) et de métriques personnalisées. Dans ce document, nous allons décrire en détail comment luz permet à l’utilisateur d’avoir un contrôle très fin sur la boucle d’entraînement.

Outre l’utilisation de callbacks, il existe trois autres façons d’utiliser luz (en fonction du niveau de contrôle nécessaire) :

  • Optimiseurs ou fonctions de perte multiples : Vous pouvez optimiser deux fonctions de perte, chacune avec son propre optimiseur, tout en maintenant la possibilité de ne pas modifier les appels backward() - zero_grad() et step(). C’est commun dans les modèles comme les GANs (Réseaux Adversaires Génératifs) lorsqu’on a des réseaux neuronaux concurrents entraînés avec différentes pertes et optimiseurs.

  • Étapes complètement flexibles : Vous pouvez contrôler complètement comment appeler backward(), zero_grad() et step(). Vous pouvez aussi avoir un contrôle plus grand sur la calcul de gradients. Par exemple, vous pouvez utiliser des ‘tailles de lot virtuelles’, où vous accumulez les gradients pendant quelques étapes avant d’actualiser les poids.

  • Boucles complètement flexibles : Votre boucle d’entraînement peut être entièrement redessinée, et dans ce cas vous utiliser luz pour gérer uniquement le déplacement des chargeurs de données, optimiseurs et modèles vers un périphérique. Voir la vignette("accelerator").

Considérons une version simplifiée du réseau net que nous avons implémenté dans la vignette("get-started") :

net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>% 
      self$fc1() %>% 
      nnf_relu() %>% 
      self$fc2()
  }
)

En utilisant l’API de haut niveau de luz, on peut l’entraîner comme suit :

fitted <- net %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_accuracy
    )
  ) %>%
  fit(train_dl, epochs = 10, valid_data = test_dl)

Optimiseurs multiples

Supposons que notre expérience consiste à ajuster la première couche de connexion complète avec un taux d’apprentissage de 0.1 et la deuxième avec un taux d’apprentissage de 0.01. Nous allons minimiser la même nn_cross_entropy_loss() pour les deux couches, mais pour la première, disons que nous voulons ajouter une régularisation L1.

Pour y arriver avec luz, on va ajouter deux méthodes à notre module net :

  • set_optimizers: qui renvoie une liste d’optimiseurs en fonction du contexte ctx.

  • loss: qui calcule une fonction de coût différente suivant l’optimiseur.

Voyons le résultat dans le code :

net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc1 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>% 
      self$fc1() %>% 
      nnf_relu() %>% 
      self$fc2()
  },
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  },
  loss = function(input, target) {
    pred <- ctx$model(input)
  
    if (ctx$opt_name == "opt_fc1") 
      # ajout d'une régularisation L1 sur la couche 1
      nnf_cross_entropy(pred, target) + torch_norm(self$fc1$weight, p = 1)
    else if (ctx$opt_name == "opt_fc2")
      nnf_cross_entropy(pred, target)
  }
)

Il faut noter que l’initialisation des optimiseurs se fera avec le résultat de la méthode set_optimizers() qui est une liste. Et donc ici, nous aurons bien deux optimiseurs différents, appliqués chacuns à des paramêtres du modèle différents, et avec un taux d’apprentissage spécifique.

La méthode loss() a en charge le calcul de la fonction de coût, pour faire la retro-propagation des gradients et mettre à jour les poids du modèle. Cette méthode loss() accède à l’objet ctx qui contient une variable opt_name décrivant l’optimiseur en cours d’utilisation. On note que cette fonction sera appellée une fois par optimiseur et par étape de la boucke d’entrainement et de la boucle de validation. Référez vous à help("ctx") pour une information complète sur l’objet contexte.

On peut maintenant utiliser setup et fit avec ce module, mais sans préciser ni les paramêtres d’optimiseurs ni de fonction de coût.

fitted <- net %>% 
  setup(metrics = list(luz_metric_accuracy)) %>% 
  fit(train_dl, epochs = 10, valid_data = test_dl)

Maintenant, nous allons re-implementer ce même modèle en utilisant une approche légèrement plus flexible qui consiste à surcharger l’étape d’apprentissage et de validation.

Étape complètement flexible

Au lieu d’implémenter la méthode loss(), nous pouvons implémenter la méthode step(). Cela permet de modifier en toute flexibilité ce qui se passe lors de l’apprentissage et de la validation pour chaque lot du jeu de données. C’est votre code qui est maintenant responsable de mettre à jour les poids du modèle à chaque appel des optimiseurs par retro-propagation des gradients.

net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc1 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>% 
      self$fc1() %>% 
      nnf_relu() %>% 
      self$fc2()
  },
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  },
  step = function() {
    ctx$loss <- list()
    for (opt_name in names(ctx$optimizers)) {
    
      ctx$pred <- ctx$model(ctx$input)
      opt <- ctx$optimizers[[opt_name]]
      loss <- nnf_cross_entropy(pred, target)
      
      if (opt_name == "opt_fc1") {
        # ajout d'une régularisation L1 sur la couche 1
        loss <- nnf_cross_entropy(pred, target) + 
          torch_norm(self$fc1$weight, p = 1)
      }
        
      if (ctx$training) {
        opt$zero_grad()
        loss$backward()
        opt$step()  
      }
      
      ctx$loss[[opt_name]] <- loss$detach()
    }
  }
)

Les choses importantes à noter ici sont :

  • La méthode step() est utilisée pour l’apprentissage et la validation. Vous devez être prudent car vous ne devez modifier les poids que pendant la phase d’apprentissage. Encore une fois, vous pouvez obtenir des informations complètes sur l’objet de contexte avec help("ctx").

  • ctx$optimizers est une liste nommée contenant chaque optimiseur qui a été créé lorsqu’on a appellé la méthode set_optimizers().

  • Vous devez assurer le suivi les pertes en cours de modification en les enregistrant dans une liste nommée dans ctx$loss. Conformément à la convention, nous réutilisons le nom que l’optimiseur auquel elle se réfère. Il est recommandé de les déconnecter avec $detach() avant leur sauvegarde pour réduire la consommation de mémoire.

  • Les rappels qui seraient déclenchés dans la méthode step() par défaut comme on_train_batch_after_pred, on_train_batch_after_loss, etc., ne seront pas automatiquement appelés. Vous pouvez toujours les déclencher manuellement en ajoutant ctx$call_callbacks("<nom du rappel>)à l'intérieur de votre étape d'apprentissage. Étudiez le code defit_one_batch()etvalid_one_batch` pour identifier tous les rappels qui ne seront pas déclenchés.

  • Si vous voulez que les métriques luz fonctionnent avec votre méthode step() personnalisée, vous devez assigner ctx$pred aux prédictions du modèle car les métriques seront toujours appelées avec metric$update(ctx$pred, ctx$target).

Étapes suivantes

Dans cet article, vous avez appris à personnaliser l’étape de formation du boucle d’apprentissage en utilisant les fonctionnalités en couche de luz.

Luz permet également des modifications plus flexibles de la boucle d’apprentissage décrites dans le vignette “Accélérateur” (vignette("accelerator")).

Vous devriez maintenant pouvoir suivre les exemples marqués avec les catégories ‘intermédiaire’ et ‘avancé’ dans la galerie des exemples de luz