Estimator of the Mean Squared Prediction Error using Cross-Validation.
Source: R/crossval.R, crossval.Rd
Estimator of the mean squared prediction error (MSPE) of different learners using cross-validation.
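For reference, the standard K-fold cross-validation estimator of the MSPE is written out below. This page does not state the formula, so treat it as the usual textbook definition rather than a confirmed description of the implementation. Here I_k denotes the k-th fold (one element of cv_subsamples), n is the sample size, and \hat f_j^{(-k)} is learner j fitted on all observations outside I_k (using Z_i as well when Z is supplied):

```latex
\widehat{\mathrm{MSPE}}_j
  = \frac{1}{n} \sum_{k=1}^{K} \sum_{i \in I_k}
    \left( y_i - \hat f_j^{(-k)}(X_i) \right)^2
```

The terms inside the square, y_i - \hat f_j^{(-k)}(X_i), are the out-of-sample prediction errors; collected over all i, they form one column of the oos_resid matrix described under Value.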
Usage
crossval(
  y,
  X,
  Z = NULL,
  learners,
  cv_folds = 5,
  cv_subsamples = NULL,
  silent = FALSE,
  progress = NULL
)
Arguments
- y
The outcome variable.
- X
A (sparse) matrix of predictive variables.
- Z
Optional additional (sparse) matrix of predictive variables.
- learners
learners is a list of lists, each containing up to four named elements:
  - fun: The base learner function. The function must be such that it predicts a named input y using a named input X.
  - args: Optional arguments to be passed to fun.
  - assign_X: An optional vector of column indices corresponding to variables in X that are passed to the base learner.
  - assign_Z: An optional vector of column indices corresponding to variables in Z that are passed to the base learner.
Omission of the args element results in default arguments being used in fun. Omission of assign_X (and/or assign_Z) results in inclusion of all predictive variables in X (and/or Z).
- cv_folds
Number of folds used for cross-validation.
- cv_subsamples
List of vectors with sample indices for cross-validation.
- silent
Boolean to silence estimation updates.
- progress
String to print before learner and cv fold progress.
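The structure of a single learners entry can be illustrated in base R. This is a hypothetical sketch: my_ols, predict.my_ols, and fit_learner are illustrative names that are not part of the package, and it assumes the fitted object is used through predict() with a newdata argument, which this page does not confirm. It only mirrors the documented defaults for args and assign_X.

```r
# Hypothetical base learner: OLS via lm.fit(). It takes named inputs y and X,
# plus an optional argument that could be supplied through args.
my_ols <- function(y, X, intercept = TRUE) {
  Xd <- if (intercept) cbind(1, X) else X
  fit <- lm.fit(Xd, y)
  structure(list(coef = fit$coefficients, intercept = intercept),
            class = "my_ols")
}

# S3 predict method for the hypothetical learner (assumed interface).
predict.my_ols <- function(object, newdata, ...) {
  Xd <- if (object$intercept) cbind(1, newdata) else newdata
  drop(Xd %*% object$coef)
}

# One learners entry: fun plus an optional assign_X (args omitted here).
learner <- list(fun = my_ols, assign_X = c(1, 3))

# Hypothetical helper mirroring the documented defaults: a missing args
# element means default arguments; a missing assign_X means all columns.
fit_learner <- function(learner, y, X) {
  cols <- if (is.null(learner$assign_X)) seq_len(ncol(X)) else learner$assign_X
  args <- if (is.null(learner$args)) list() else learner$args
  fit <- do.call(learner$fun,
                 c(list(y = y, X = X[, cols, drop = FALSE]), args))
  list(fit = fit, cols = cols)
}

# Example: fit on columns 1 and 3 of a simulated X, then predict.
set.seed(1)
X <- matrix(rnorm(300), nrow = 100, ncol = 3)
y <- drop(X %*% c(1, 0, 2)) + rnorm(100)
res <- fit_learner(learner, y, X)
pred <- predict(res$fit, newdata = X[, res$cols, drop = FALSE])
```

Passing y and X by name through do.call() matches the requirement that fun works with named inputs y and X, and any extra entries in args are simply appended to that call.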
Value
crossval returns a list containing the following components:
- mspe
A vector of MSPE estimates, each corresponding to a base learner (in the order given in learners).
- oos_resid
A matrix of out-of-sample prediction errors, each column corresponding to a base learner (in the order given in learners).
- cv_subsamples
Pass-through of cv_subsamples. See above.
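The relationship between the three returned components can be sketched in base R. This is a simplified, hypothetical illustration, not the package's implementation: cv_mspe_sketch, mean_learner, and ols_learner are invented names, and for brevity each learner here is a plain function of (y, X) that returns a prediction function, rather than the fun/args/assign_X list described above. It assumes each element of cv_subsamples holds the indices of one held-out fold.

```r
# Hypothetical sketch of K-fold cross-validated MSPE for several learners.
cv_mspe_sketch <- function(y, X, learners, cv_folds = 5, cv_subsamples = NULL) {
  n <- length(y)
  if (is.null(cv_subsamples)) {
    # Randomly assign each observation to exactly one held-out fold.
    fold_id <- sample(rep(seq_len(cv_folds), length.out = n))
    cv_subsamples <- split(seq_len(n), fold_id)
  }
  oos_resid <- matrix(NA_real_, nrow = n, ncol = length(learners))
  for (j in seq_along(learners)) {
    for (test in cv_subsamples) {
      # Fit on the other folds, predict on the held-out fold.
      predict_fun <- learners[[j]](y[-test], X[-test, , drop = FALSE])
      oos_resid[test, j] <- y[test] - predict_fun(X[test, , drop = FALSE])
    }
  }
  list(mspe = colMeans(oos_resid^2),
       oos_resid = oos_resid,
       cv_subsamples = cv_subsamples)
}

# Two toy learners: sample mean and OLS with intercept.
mean_learner <- function(y, X) {
  m <- mean(y)
  function(newX) rep(m, nrow(newX))
}
ols_learner <- function(y, X) {
  b <- lm.fit(cbind(1, X), y)$coefficients
  function(newX) drop(cbind(1, newX) %*% b)
}

# Example on simulated data with 4 folds.
set.seed(42)
X <- matrix(rnorm(600), nrow = 200, ncol = 3)
y <- 1 + drop(X %*% c(2, -1, 0.5)) + rnorm(200)
res <- cv_mspe_sketch(y, X, list(mean_learner, ols_learner), cv_folds = 4)
```

Because every observation falls in exactly one fold, each entry of oos_resid is a genuine out-of-sample error, and mspe is just the column-wise mean of its squares.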
See also
Other utilities:
crosspred(),
shortstacking()
Examples
library(ddml)

# Construct variables from the included Angrist & Evans (1998) data
y <- AE98[, "worked"]
X <- AE98[, c("morekids", "age", "agefst", "black", "hisp", "othrace", "educ")]

# Compare ols, lasso, and ridge using 4-fold cross-validation
cv_res <- crossval(y, X,
                   learners = list(list(fun = ols),
                                   list(fun = mdl_glmnet),
                                   list(fun = mdl_glmnet,
                                        args = list(alpha = 0))),
                   cv_folds = 4,
                   silent = TRUE)
cv_res$mspe
#> [1] 0.2365091 0.2365085 0.2365032