Predict using TabICL2

# S3 method for class 'tab_icl_v2'
predict(object, new_data, ...)

Arguments

object, x

A tab_icl_v2 object.

new_data

A data frame or matrix of new predictors.

...

Not used, but required for extensibility.

Value

predict() returns a tibble of predictions, and augment() returns new_data with the prediction columns appended. In either case, the tibble has the same number of rows as new_data.

For regression, the prediction is in the column .pred. For classification, the class predictions are in .pred_class and the probability estimates are in columns named .pred_{level}, one for each level of the outcome factor.
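
The naming and row-count rules above can be sketched in base R without fitting a model. This is only an illustration: the outcome factor, its levels, the new_data frame, and the mock_pred values are hand-written stand-ins, not output from a tab_icl_v2 object.

```r
# Hypothetical outcome factor; the level names are illustrative only.
outcome <- factor(c("low", "high", "low"), levels = c("low", "high"))

# Probability columns follow the pattern .pred_{level}, one per factor level.
prob_cols <- paste0(".pred_", levels(outcome))

# Columns expected in a classification prediction.
expected_cols <- c(".pred_class", prob_cols)
print(expected_cols)

# Hand-written stand-ins for new_data and for a prediction result
# (assumed shape only; these are not real model predictions).
new_data <- data.frame(x = c(1.2, 3.4, 5.6))
mock_pred <- data.frame(
  .pred_class = outcome,
  .pred_low   = c(0.7, 0.2, 0.9),
  .pred_high  = c(0.3, 0.8, 0.1)
)

# One prediction row per row of new_data, with the expected column names.
stopifnot(
  nrow(mock_pred) == nrow(new_data),
  identical(names(mock_pred), expected_cols)
)
```

The same two checks could be run on the result of a real predict() call to confirm that it has the documented shape.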

Examples

# Minimal example for quick execution
car_split <- rsample::initial_split(mtcars[1:6, ])

if (FALSE) { # \dontrun{
# Fit
if (torch_is_installed() && interactive()) {
 mod <- tab_icl2(mpg ~ cyl + log(drat), rsample::training(car_split))

 # Predict
 predict(mod, rsample::testing(car_split))
 augment(mod, rsample::testing(car_split))
}
} # }