TabICL2 — Predict using TabICL2 (predict.tab_icl_v2.Rd)
# S3 method for class 'tab_icl_v2'
predict(object, new_data, ...)

predict() returns a tibble of predictions, and augment() returns new_data
with the prediction columns appended. In either case, the number of rows in
the result is guaranteed to be the same as the number of rows in new_data.
For regression, the prediction is in the column .pred. For
classification, the class predictions are in .pred_class and the
probability estimates are in columns named .pred_{level}, where
level is each level of the outcome factor.
# Minimal example for quick execution
car_split <- rsample::initial_split(mtcars[1:6, ])
if (FALSE) { # \dontrun{
# Fit
# Use the scalar, short-circuiting `&&` in an `if` condition (not the
# elementwise `&`).
if (torch_is_installed() && interactive()) {
  mod <- tab_icl2(mpg ~ cyl + log(drat), car_split)
  # Predict
  # Qualify testing() like initial_split() above, so the example does not
  # depend on rsample being attached; compute the test set once and reuse it.
  car_test <- rsample::testing(car_split)
  predict(mod, car_test)
  augment(mod, car_test)
}
} # }