Constructs DeepLabV3 semantic segmentation models with a ResNet backbone, as described in "Rethinking Atrous Convolution for Semantic Image Segmentation" (Chen et al., 2017). These models employ atrous spatial pyramid pooling (ASPP) to capture multi-scale context.

model_deeplabv3_resnet50(
  pretrained = FALSE,
  progress = TRUE,
  num_classes = 21,
  aux_loss = NULL,
  pretrained_backbone = FALSE,
  ...
)

model_deeplabv3_resnet101(
  pretrained = FALSE,
  progress = TRUE,
  num_classes = 21,
  aux_loss = NULL,
  pretrained_backbone = FALSE,
  ...
)

Arguments

pretrained

(bool): If TRUE, returns a model pre-trained on a subset of COCO train2017, using the same classes as Pascal VOC.

progress

(bool): If TRUE, displays a progress bar of the download to stderr.

num_classes

(int): Number of output classes (default: 21).

aux_loss

Logical or NULL. If TRUE, includes an auxiliary classifier branch. If NULL (default), the presence of the auxiliary classifier is inferred from the pretrained weights.

pretrained_backbone

(bool): If TRUE and pretrained = FALSE, loads ImageNet-pretrained weights for the ResNet backbone.

...

Other parameters passed to the model implementation.
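
For example, a fine-tuning setup can combine a randomly initialised segmentation head with an ImageNet-initialised backbone. The sketch below uses only the arguments documented above; the 3-class task is hypothetical.

model <- model_deeplabv3_resnet50(
  pretrained = FALSE,          # no pre-trained segmentation head
  pretrained_backbone = TRUE,  # ImageNet weights for the ResNet-50 backbone
  num_classes = 3              # hypothetical custom class count
)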

Functions

  • model_deeplabv3_resnet50(): DeepLabV3 with ResNet-50 backbone

  • model_deeplabv3_resnet101(): DeepLabV3 with ResNet-101 backbone

Task

Semantic image segmentation with 21 output classes by default: the background class plus the 20 Pascal VOC categories used by the COCO-trained weights.
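
The class labels that accompany the pretrained weights are exposed as model$classes (also used in the Examples below). A minimal sketch, assuming the pretrained weights have been downloaded:

model <- model_deeplabv3_resnet50(pretrained = TRUE)
model$classes  # labels for the 21 classes predicted by the pretrained model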

Input Format

The models expect input tensors of shape (batch_size, 3, H, W). Typical training uses 520x520 images.
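
As a quick shape check, the sketch below runs an untrained model on a random tensor (assuming the torch package is attached). The model returns a named list whose out element holds per-pixel class scores of shape (batch_size, num_classes, H, W):

library(torch)
x <- torch_randn(1, 3, 520, 520)                     # (batch_size, 3, H, W)
model <- model_deeplabv3_resnet50(num_classes = 21)
model$eval()
output <- model(x)
output$out$shape                                     # (1, 21, 520, 520)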

See also

Other semantic_segmentation_model: model_fcn_resnet

Examples

if (FALSE) { # \dontrun{
library(magrittr)
norm_mean <- c(0.485, 0.456, 0.406) # ImageNet normalization constants, see
# https://pytorch.org/vision/stable/models.html
norm_std  <- c(0.229, 0.224, 0.225)
# Use a publicly available image of an animal
wmc <- "https://upload.wikimedia.org/wikipedia/commons/thumb/"
url <- "e/ea/Morsan_Normande_vache.jpg/120px-Morsan_Normande_vache.jpg"
img <- base_loader(paste0(wmc, url))

input <- img %>%
  transform_to_tensor() %>%
  transform_resize(c(520, 520)) %>%
  transform_normalize(norm_mean, norm_std)
batch <- input$unsqueeze(1)    # Add batch dimension (1, 3, H, W)

# DeepLabV3 with ResNet-50
model <- model_deeplabv3_resnet50(pretrained = TRUE)
model$eval()
output <- model(batch)

# visualize the result
# `draw_segmentation_masks()` converts the torch_float output into a boolean mask internally:
segmented <- draw_segmentation_masks(input, output$out$squeeze(1))
tensor_image_display(segmented)

# Show most frequent class
mask_id <- output$out$argmax(dim = 2)  # (1, H, W)
class_contingency_with_background <- mask_id$view(-1)$bincount()
class_contingency_with_background[1] <- 0L # zero out the count for the background class (id 1)
top_class_index <- class_contingency_with_background$argmax()$item()
cli::cli_inform("Majority class {.pkg ResNet-50}: {.emph {model$classes[top_class_index]}}")

# DeepLabV3 with ResNet-101 (same steps)
model <- model_deeplabv3_resnet101(pretrained = TRUE)
model$eval()
output <- model(batch)

segmented <- draw_segmentation_masks(input, output$out$squeeze(1))
tensor_image_display(segmented)

mask_id <- output$out$argmax(dim = 2)
class_contingency_with_background <- mask_id$view(-1)$bincount()
class_contingency_with_background[1] <- 0L # zero out the count for the background class (id 1)
top_class_index <- class_contingency_with_background$argmax()$item()
cli::cli_inform("Majority class {.pkg ResNet-101}: {.emph {model$classes[top_class_index]}}")
} # }