Passer au contenu

Recherche du taux d'apprentissage

Utilisation

lr_finder(
  object,
  data,
  steps = 100,
  start_lr = 1e-07,
  end_lr = 0.1,
  log_spaced_intervals = TRUE,
  ...,
  verbose = NULL
)

Arguments

object

Un module nn qui a été configuré par setup().

data

(dataloader) Un chargeur de donnée créé avec torch::dataloader(), utilisé pour le recherche de taux d'apprentissage.

steps

(entier) Le nombre d'itérations pour la recherche de taux d'apprentissage. Défaut : 100.

start_lr

(réel) La limite basse du taux d'apprentissage.

end_lr

(réel) La limite haute du taux d'apprentissage.

log_spaced_intervals

(booléen) Doit-on découper logarithmiquement l'intervalle entre start_lr et end_lr (Si FALSE : intervalles uniformes). Défaut : TRUE

...

Autres arguments passés à fit.

verbose

Doit-on afficher un barre de progression pendant la recherche.

Valeur de retour

Un dataframe de deux colonnes : taux d'apprentissage et valeur de la fonction de perte

Exemples

if (torch::torch_is_installed()) {
library(torch)
ds <- torch::tensor_dataset(x = torch_randn(100, 10), y = torch_randn(100, 1))
dl <- torch::dataloader(ds, batch_size = 32)
model <- torch::nn_linear
model <- model %>% setup(
  loss = torch::nn_mse_loss(),
  optimizer = torch::optim_adam
) %>%
  set_hparams(in_features = 10, out_features = 1)
records <- lr_finder(model, dl, verbose = FALSE)
plot(records)
}
#> Error in set_hparams(., in_features = 10, out_features = 1): impossible de trouver la fonction "set_hparams"