Passer au contenu

# Ensembles et chargeurs ---------------------------------------------------

dir <- "./mnist" # répertoire de mise en cache

Modifiez le dataset MNIST afin que la cible soit identique à l’entrée.

category: 'basic'

# Packages ----------------------------------------------------------------
library(torch)
library(torchvision)
library(luz)

# Datasets and data loaders ------------------------------------------------

dir <- "./mnist" # directory in which the MNIST files are cached

# MNIST variant for autoencoding: every item keeps its input `x`, but its
# target `y` is overwritten with that same input, so target == input.
mnist_dataset2 <- torch::dataset(
  inherit = mnist_dataset,
  .getitem = function(i) {
    # Fetch the regular item from the parent dataset first.
    item <- super$.getitem(i)
    # Replace the target with the input.
    item$y <- item$x
    item
  }
)

# Training split, requested with `download = TRUE` and stored under `dir`.
# Items are converted with `transform_to_tensor`.
train_ds <- mnist_dataset2(
  dir,
  transform = transform_to_tensor,
  download = TRUE
)
# Shuffled mini-batches of 128 items for training.
train_dl <- dataloader(train_ds, shuffle = TRUE, batch_size = 128)

# Test split (`train = FALSE`), read from the same directory.
test_ds <- mnist_dataset2(
  dir,
  transform = transform_to_tensor,
  train = FALSE
)
# Mini-batches of 128 items; no shuffling is requested for evaluation.
test_dl <- dataloader(test_ds, batch_size = 128)

# Network definition -------------------------------------------------------

# Convolutional autoencoder exposing two entry points:
#   * forward(): encoder followed by decoder (the reconstruction).
#   * predict(): encoder only, with the encoded tensor flattened.
net <- nn_module(
  "Net",
  initialize = function() {
    # Encoder: two 5x5 convolutions, channels 1 -> 6 -> 16, each followed by
    # a ReLU.
    self$encoder <- nn_sequential(
      nn_conv2d(1, 6, kernel_size = 5),
      nn_relu(),
      nn_conv2d(6, 16, kernel_size = 5),
      nn_relu()
    )
    # Decoder: mirrors the encoder with 5x5 transposed convolutions, channels
    # 16 -> 6 -> 1. The final sigmoid squashes the output into (0, 1).
    self$decoder <- nn_sequential(
      nn_conv_transpose2d(16, 6, kernel_size = 5),
      nn_relu(),
      nn_conv_transpose2d(6, 1, kernel_size = 5),
      nn_sigmoid()
    )
  },
  forward = function(x) {
    # Encode, then decode.
    encoded <- self$encoder(x)
    self$decoder(encoded)
  },
  predict = function(x) {
    # Flatten every dimension from the second one onward; presumably the first
    # dimension is the batch, giving one feature vector per sample (verify).
    encoded <- self$encoder(x)
    torch_flatten(encoded, start_dim = 2)
  }
)

# Model training -----------------------------------------------------------

# Configure the module with an MSE loss and the Adam optimizer, then fit it
# for one epoch on `train_dl`, using `test_dl` as validation data.
fitted <- fit(
  setup(net, loss = nn_mse_loss(), optimizer = optim_adam),
  train_dl,
  epochs = 1,
  valid_data = test_dl
)

# Inference ----------------------------------------------------------------

# Run the fitted model over the test loader.
preds <- predict(object = fitted, newdata = test_dl)

# Serialization ------------------------------------------------------------

# Persist the fitted model to disk.
luz_save(obj = fitted, path = "mnist-autoencoder.pt")