Tree-Based Machine Learning Methods for Prediction, Variable Selection and Causal Inference

Part III

Hemant Ishwaran and Min Lu
University of Miami

Variable selection

We will discuss four methods:

  1. Permutation Variable Importance (VIMP)
  2. Subsampling for Confidence Intervals
  3. Minimal Depth
  4. VarPro

Permutation importance

Idea

In the OOB cases for a tree, randomly permute all values of the \(j\)th variable. Put these new covariate values down the tree and compute a new internal error rate. The amount by which this new error exceeds the original OOB error is defined as the importance of the \(j\)th variable for the tree. Averaging over the forest yields VIMP.

Source: Measure 1, Manual On Setting Up, Using, And Understanding Random Forests V3.1

OOB explanation

The tree OOB error rate is \[ \frac{1}{\# O_{ib}}\sum_{i\in O_{ib}} L(Y_i,h^*_b(\mathbf{x}_i)) \]

Permute (perturb) the \(j\)th covariate for case \(i\) to obtain \(\mathbf{x}^{*(j)}_i\)

The tree VIMP is therefore \[ \frac{1}{\# O_{ib}}\sum_{i\in O_{ib}} L(Y_i,h^*_b(\mathbf{x}^{*(j)}_i)) - \frac{1}{\# O_{ib}}\sum_{i\in O_{ib}} L(Y_i,h^*_b(\mathbf{x}_i)) \]

Averaging over trees gives VIMP. Large values identify important variables; negative values can occur, for example when permuting a noise variable lowers the error by chance.
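
The tree-level calculation above can be sketched in base R. This is a minimal illustration under stated assumptions, not the randomForestSRC implementation: lm() stands in for the tree predictor \(h^*_b\), the loss \(L\) is squared error, and mtcars with three predictors is an arbitrary choice of data.

```r
## Sketch of permutation importance for ONE bootstrap learner.
## Assumptions: lm() replaces the tree; loss is squared error.
set.seed(1)
dat <- mtcars[, c("mpg", "wt", "hp", "qsec")]
n   <- nrow(dat)
inb <- sample(n, n, replace = TRUE)       # bootstrap (in-bag) cases
oob <- setdiff(seq_len(n), inb)           # out-of-bag (OOB) cases
fit <- lm(mpg ~ ., data = dat[inb, ])

## mean squared error of the fitted learner on a data frame d
mse <- function(d) mean((d$mpg - predict(fit, newdata = d))^2)
err.oob <- mse(dat[oob, ])                # original OOB error

xvars  <- setdiff(names(dat), "mpg")
vimp.b <- sapply(xvars, function(j) {
  d <- dat[oob, ]
  d[[j]] <- d[[j]][sample.int(nrow(d))]   # permute variable j among OOB cases
  mse(d) - err.oob                        # permuted error minus original error
})
vimp.b
```

Repeating this over many bootstrap learners and averaging the resulting values would give the forest-level quantity described above.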

Different VIMP in the package

importance = c("anti", "permute", "random")



importance = TRUE       -->   anti-VIMP
importance = "permute"  -->   Breiman-Cutler
importance = "random"   -->   random-VIMP

Obtaining VIMP using the package

During training:

rfsrc(mpg~., mtcars, importance="permute")$importance 
rfsrc(mpg~., mtcars, importance="permute", block.size=10)$importance 

Using restore:

predict(obj, importance="permute")$importance 
predict(obj, importance="permute", block.size=10)$importance  

Using vimp:

vimp(obj, importance="permute")$importance 
vimp(obj, importance="permute", block.size=10)$importance  

vimp also permits joint VIMP

## paired permutation vimp
obj <- rfsrc(Species~., data=iris)
vimp(obj, obj$xvar.names[1:2], importance="permute", joint=TRUE)$importance
vimp(obj, obj$xvar.names[3:4], importance="permute", joint=TRUE)$importance

General call to vimp

## VIMP for all variables 
iris.obj <- rfsrc(Species ~ ., data = iris)
print(vimp(iris.obj)$importance)
>                     all     setosa versicolor   virginica
> Sepal.Length 0.06857263 0.09785815 0.42079003 0.034794007
> Sepal.Width  0.02520082 0.11525515 0.08481039 0.003261938
> Petal.Length 0.55673198 1.44395131 1.80602645 1.241711139
> Petal.Width  0.60811332 1.75601006 2.00391736 1.144940306

## joint VIMP 
print(vimp(iris.obj, c("Petal.Length", "Petal.Width"), joint = TRUE)$importance)
>             all   setosa versicolor virginica
> joint 0.9336724 2.593241   2.447541  2.489946

VIMP illustration using peakVO2

The data comprise \(n=2231\) patients with systolic heart failure, all of whom underwent cardiopulmonary stress testing. The outcome is all-cause mortality (mean follow-up of 5 years; 742 patients died). Baseline characteristics and exercise stress test results were recorded (\(p=39\)).

Confidence intervals for VIMP

Standard errors and confidence intervals (CI) for VIMP are obtained by subsampling.

Subsampling draws data without replacement, with subsample size \(m\) substantially smaller than \(n\); e.g., \[ m=\sqrt{n}=\sqrt{2231}\approx 47\ll n=2231 \]

CI are obtained using the normal approximation \[ \texttt{vimp} \pm z_{\alpha/2} \hat{\sigma} \] where \(\hat{\sigma}\) is the subsampled standard error estimator [2, 3].

## example using peakVO2
data(peakVO2, package = "randomForestSRC")
o <- rfsrc(Surv(ttodead, died)~., peakVO2, importance="permute")
oo <- subsample(o)
plot.vimp.ci(oo, alpha=.05)
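
The subsample size and the normal-approximation interval can be worked through in a few lines of base R. This is a sketch only: the helper vimp.ci and the numeric inputs (a VIMP of 0.02 with standard error 0.005) are hypothetical and are not taken from the peakVO2 analysis.

```r
## subsample size m = sqrt(n), rounded down to a whole number of cases
n <- 2231
m <- floor(sqrt(n))

## normal-approximation interval: vimp +/- z_{alpha/2} * se
## (vimp.ci is a hypothetical helper, not a package function)
vimp.ci <- function(vimp, se, alpha = 0.05) {
  z <- qnorm(1 - alpha / 2)
  c(lower = vimp - z * se, upper = vimp + z * se)
}

## made-up inputs, for illustration only
ci <- vimp.ci(vimp = 0.02, se = 0.005)
ci
```

An interval whose lower limit stays above zero would suggest that the variable's importance is distinguishable from noise at the chosen level.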

VIMP for regression: Iowa housing

VIMP for classification example: Glioma

General call to subsample

Minimal depth

Measures importance of a variable by how close it splits to the root

Pros:

  1. Much faster than permutation VIMP
  2. Works in all settings for any type of tree
  3. Doesn’t depend on prediction error

Cons:

  1. Threshold value can be sensitive and relies on assumptions
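
The idea of "how close to the root a variable splits" can be illustrated with a toy table of splits. This sketch is not the max.subtree algorithm: the table is hand-made, depth is counted from 0 at the root, and the case of a variable that never splits in a tree is ignored.

```r
## Toy sketch: minimal depth from a hand-made table of splits.
## Each row is one internal node of a tree; depth 0 is the root.
splits <- data.frame(
  tree  = c(1, 1, 1, 2, 2, 2),
  depth = c(0, 1, 1, 0, 1, 2),
  var   = c("peak.vo2", "bun", "age", "bun", "peak.vo2", "age")
)

## minimal depth of a variable in a tree = smallest depth at which it splits
md.tree <- aggregate(depth ~ tree + var, data = splits, FUN = min)

## average over trees; small values indicate splits close to the root
md <- sort(tapply(md.tree$depth, md.tree$var, mean))
md
```

Variables with small average minimal depth split near the root and would be ranked as more important.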

General call to max.subtree

Minimal depth illustration using peakVO2

md <- max.subtree(o)$order[, 1]
barplot(sort(md), 
        las=2, horiz = TRUE, col = "cadetblue3")

## guide random feature selection with number of times variable splits
xvar.used <- predict(o, 
               var.used="all.trees")$var.used
os <- rfsrc(Surv(ttodead, died)~., peakVO2, xvar.wt = xvar.used)
mds <- max.subtree(os)$order[, 1]
barplot(sort(mds), 
        las=2, horiz = TRUE, col = "cadetblue3")

VarPro

Permutation VIMP creates “artificial” data which can impact performance in correlated settings

Consider a patient in the peakVO2 dataset

age betablok dilver nifed acei angioten.II anti.arrhy anti.coag aspirin digoxin nitrates vasodilator diuretic.loop diuretic.thiazide diuretic.potassium.spar lipidrx.statin insulin surgery.pacemaker surgery.cabg surgery.pci surgery.aicd.implant resting.systolic.bp resting.hr smknow q.wave.mi bmi niddm lvef.metabl peak.rer peak.vo2 interval cad died ttodead bun sodium hgb glucose male black crcl
61 1 0 1 1 0 0 0 1 0 1 0 0 1 0 1 0 0 1 1 0 100 75 0 1 28.68956 0 30 0.97 4.2 360 1 0 2.354552 20 141 14.3 90 1 0 71.24107
bun interval peak.vo2
20 360 4.2

Suppose peak.vo2 is permuted for calculating VIMP

age betablok dilver nifed acei angioten.II anti.arrhy anti.coag aspirin digoxin nitrates vasodilator diuretic.loop diuretic.thiazide diuretic.potassium.spar lipidrx.statin insulin surgery.pacemaker surgery.cabg surgery.pci surgery.aicd.implant resting.systolic.bp resting.hr smknow q.wave.mi bmi niddm lvef.metabl peak.rer peak.vo2 interval cad died ttodead bun sodium hgb glucose male black crcl
61 1 0 1 1 0 0 0 1 0 1 0 0 1 0 1 0 0 1 1 0 100 75 0 1 28.68956 0 30 0.97 43.8 360 1 0 2.354552 20 141 14.3 90 1 0 71.24107
bun interval peak.vo2
20 360 43.8
> summary(peakVO2[,c("bun","interval", "peak.vo2")])
      bun             interval         peak.vo2    
 Min.   :  4.322   Min.   :  21.0   Min.   : 4.20  
 1st Qu.: 17.000   1st Qu.: 345.0   1st Qu.:12.80  
 Median : 23.000   Median : 480.0   Median :15.70  
 Mean   : 25.278   Mean   : 503.3   Mean   :16.27  
 3rd Qu.: 29.334   3rd Qu.: 641.0   3rd Qu.:19.30  
 Max.   :129.000   Max.   :1415.0   Max.   :43.80  

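Permuting the data can create implausible values: here the patient's observed peak.vo2 of 4.2 (the sample minimum) is replaced by 43.8 (the sample maximum), while bun and interval keep their observed values.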
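
A small simulation shows why this matters in correlated settings. The sketch below uses made-up variables x1 and x2 (not peakVO2 columns): permuting one of two strongly correlated variables destroys their joint relationship, so the permuted rows no longer resemble real cases.

```r
## Simulated illustration: permutation breaks the dependence between variables.
set.seed(1)
n  <- 500
x1 <- rnorm(n)
x2 <- x1 + rnorm(n, sd = 0.3)      # x2 is strongly correlated with x1

x2.perm <- sample(x2)              # permute x2, as permutation VIMP would

## correlation with x1 before and after permuting x2
cors <- c(before = cor(x1, x2), after = cor(x1, x2.perm))
cors
```

The marginal distribution of x2 is unchanged by the permutation, but pairs (x1, x2.perm) fall in regions where no observed data exist.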

VarPro works directly with the observed data to construct a test statistic.


For each tree branch (a rule), the rule is released along coordinate \(j\), meaning the constraints it places on variable \(j\) are dropped. Importance is the difference between the test statistic for the original rule and that for the released rule.
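
The release idea can be sketched on simulated data. This is a toy illustration, not the varPro implementation: it assumes a regression setting, takes the test statistic to be the mean response over the cases satisfying a rule, and uses a single hand-written rule on made-up variables x1 and x2.

```r
## Toy sketch of releasing a rule along one coordinate.
## Assumption: the test statistic is the mean response within the rule's region.
set.seed(1)
n  <- 1000
x1 <- runif(n)
x2 <- runif(n)
y  <- 2 * (x1 > 0.5) + rnorm(n, sd = 0.1)   # y depends on x1 only

rule   <- x1 > 0.5 & x2 > 0.5   # original rule (a tree branch)
rel.x1 <- x2 > 0.5              # rule released along x1: x1 constraint dropped
rel.x2 <- x1 > 0.5              # rule released along x2: x2 constraint dropped

## importance = |statistic for original rule - statistic for released rule|
imp <- c(x1 = abs(mean(y[rule]) - mean(y[rel.x1])),
         x2 = abs(mean(y[rule]) - mean(y[rel.x2])))
imp
```

Releasing the signal variable x1 changes the statistic markedly, while releasing the noise variable x2 leaves it nearly unchanged; only observed cases are used, with no permuted values.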

VarPro: pros and cons

Pros:

  1. Much faster
  2. Much better for correlated problems
  3. User specified test statistics
  4. Doesn’t depend on prediction error
  5. Tree guided rules can be specified by the user
  6. Sparsity property

Cons:

  1. Importance value does not have an intuitive interpretation

General call to varpro

VarPro canonical illustration

## varpro canonical call
library(varPro)
data(peakVO2, package = "randomForestSRC")
o <- varpro(Surv(ttodead, died)~., peakVO2)
importance(o)

## cv.varpro canonical call
o.cv <- cv.varpro(Surv(ttodead, died)~., peakVO2)
o.cv
> importance(o)
                            mean          std          z    zcenter selected
bun                 0.4978149504 0.0337152554 14.7652730 12.7652730        1
interval            0.3185356007 0.0295842524 10.7670661  8.7670661        1
peak.vo2            0.4312874502 0.0402880827 10.7050875  8.7050875        1
male                0.3505938059 0.0538524448  6.5102672  4.5102672        1
age                 0.2446044542 0.0396020571  6.1765593  4.1765593        1
betablok            0.2686558074 0.0479230810  5.6059794  3.6059794        1
sodium              0.2528347809 0.0512493678  4.9334224  2.9334224        1
crcl                0.1797879324 0.0437915811  4.1055365  2.1055365        1
lvef.metabl         0.2381646158 0.0611984740  3.8916757  1.8916757        1
hgb                 0.1725410259 0.0542089828  3.1828862  1.1828862        1
resting.systolic.bp 0.1968886781 0.0678923150  2.9000142  0.9000142        1
resting.hr          0.1597187409 0.0691719891  2.3090089  0.3090089        1
digoxin             0.0775483548 0.1211743434  0.6399734 -1.3600266        0
peak.rer            0.0398680722 0.0692507796  0.5757058 -1.4242942        0
glucose             0.0290473617 0.0652009690  0.4455051 -1.5544949        0
diuretic.thiazide   0.0002443196 0.0007146353  0.3418801 -1.6581199        0
aspirin             0.0000000000 0.0000000000        NaN        NaN        0
bmi                 0.0000000000 0.0000000000        NaN        NaN        0
cad                 0.0000000000 0.0000000000        NaN        NaN        0
insulin             0.0000000000 0.0000000000        NaN        NaN        0

> o.cv
$imp
      variable         z
1          bun 12.765690
2     interval  9.084990
3     peak.vo2  8.284743
4         male  5.918617
5          age  5.537937
6     betablok  3.600902
7       sodium  3.408508
8  lvef.metabl  2.776717
9   resting.hr  2.590353
10        crcl  2.476236
11     digoxin  1.289895

$imp.conserve
      variable         z
1          bun 12.765690
2     interval  9.084990
3     peak.vo2  8.284743
4         male  5.918617
5          age  5.537937
6     betablok  3.600902
7       sodium  3.408508
8  lvef.metabl  2.776717
9   resting.hr  2.590353
10        crcl  2.476236

$imp.liberal
              variable          z
1                  bun 12.7656902
2             interval  9.0849901
3             peak.vo2  8.2847426
4                 male  5.9186167
5                  age  5.5379366
6             betablok  3.6009023
7               sodium  3.4085085
8          lvef.metabl  2.7767175
9           resting.hr  2.5903526
10                crcl  2.4762361
11             digoxin  1.2898949
12 resting.systolic.bp  1.2290245
13                 hgb  0.7475712
14   diuretic.thiazide  0.4964589

$err
          zcut nvar       err          sd
[1,] 0.1000000   14 0.3120663 0.007186951
[2,] 0.5265306   13 0.3137940 0.004717065
[3,] 0.7591837   12 0.3141295 0.004068593
[4,] 1.2632653   11 0.3099739 0.003270832
[5,] 1.3020408   10 0.3123732 0.004680203

$zcut
[1] 1.263265

$zcut.conserve
[1] 1.302041

$zcut.liberal
[1] 0.1
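
In this output, $zcut coincides with the zcut of the row of $err that has the smallest error, and $imp lists 11 variables, the nvar of that row. The check below reproduces this from the printed values only; the rule behind zcut.conserve and zcut.liberal is not shown in the slides and is not reconstructed here.

```r
## values copied from the printed $err table
err <- data.frame(
  zcut = c(0.1000000, 0.5265306, 0.7591837, 1.2632653, 1.3020408),
  nvar = c(14, 13, 12, 11, 10),
  err  = c(0.3120663, 0.3137940, 0.3141295, 0.3099739, 0.3123732)
)

## row with the smallest cross-validated error
best <- err[which.min(err$err), ]
best
```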

VarPro canonical illustration

VIMP

o <- rfsrc(Surv(ttodead, died)~., peakVO2, importance="permute")
oo <- subsample(o)
plot.vimp.ci(oo, alpha=.05)

VarPro

o.cv <- cv.varpro(Surv(ttodead, died)~., peakVO2)
barplot(o.cv$imp.liberal$z, names.arg=o.cv$imp.liberal$variable,
        las=2, horiz = TRUE, col = "coral2")

VarPro high-dimensional example

van de Vijver Microarray Breast Cancer

Gene expression profiling for predicting clinical outcome of breast cancer [5]. The microarray breast cancer data set contains 4707 expression values on 78 patients with survival information.

Time Censoring AA555029_RC AA598803_RC AB002301 AB002308 AB002331 AB002351 AB002445 AB002448 AB004064 AB004857 AB006625 AB006628 AB006746 AB007458 AB007855 AB007857 AB007883 AB007888
12.53 0 -0.5049331 -0.2425008 -0.199315682 0.90024251 0.6311663 -0.6012690 0.44181645 -0.26575425 -2.80370736 0.13287713 -1.1427432 -0.52818656 -0.97000301 0.32222703 0.41856295 1.1294556 0.4218849 -0.16941833
6.44 0 -0.5879813 0.4384945 -0.621200562 0.09301399 0.9500715 -1.4849019 0.21924725 0.14616483 0.08637013 0.59130323 1.1859283 -0.70424873 -0.62120056 0.33883667 0.05979471 0.1129456 0.4484603 -0.05979471
10.66 0 -0.3521244 -0.2258911 0.006643856 0.13952097 -0.7275022 -1.0430855 -0.27572003 -0.49496728 0.56140584 0.71089262 1.2424011 -0.77068734 1.12613368 -0.38534367 -0.23253496 -0.6079128 -0.6909611 -1.13609946
13.00 0 -0.4750357 0.5016111 -0.671029449 0.76404345 -0.9234960 -1.9267184 -0.01993157 0.09301399 0.41524100 0.24250075 -0.9002425 0.07972627 -0.43517259 0.05315085 -1.56795001 -0.9102083 -0.2823639 -0.12623326
11.98 0 -0.1660964 0.1361991 0.989934564 0.62452251 0.9168522 -1.4550045 0.07972627 0.77733117 0.19267184 -0.37205595 -0.9700030 -0.41856295 -0.07640435 -0.20263761 0.54147428 -0.1926718 0.4584261 0.71089262
11.16 0 -0.8935987 -0.1361991 0.438494503 0.56140584 -0.4551041 -1.5679500 -0.42852873 0.28568581 -0.25911039 0.47835764 2.2356577 -0.59130323 -0.83712590 0.48500150 -0.03321928 1.2457230 -0.1195894 -0.27239811
10.14 0 -0.4916454 -0.5746936 -0.289007753 0.16941833 0.9135302 -0.1129456 0.20928147 -0.47171378 0.52818656 0.03321928 -0.8138724 -0.68431717 -0.52486461 0.07640435 0.69096106 1.3686343 -0.1959938 -0.19599375
8.80 0 -0.4650699 -0.2956516 -0.385343671 -0.26907617 -0.3587682 -1.4815799 -0.97664684 -0.87366706 0.60126901 -0.42188486 -1.6543202 -0.57137161 1.12613368 -0.57137161 -0.98661262 -1.3088397 -1.0497292 -0.52486461
data(vdv, package = "randomForestSRC")  
dim(vdv)  
> [1] 78  4707

VarPro high-dimensional example

## van de Vijver Microarray Breast Cancer
## high dimensional survival example using different split-weights
## illustrates guided trees 

data(vdv, package = "randomForestSRC")
f <- as.formula(Surv(Time, Censoring)~.)
     
## lasso only
importance(varpro(f, vdv, split.weight.method = "lasso"))

## lasso and vimp
importance(varpro(f, vdv, split.weight.method = "lasso vimp"))

## lasso, vimp and shallow trees
importance(varpro(f, vdv, split.weight.method = "lasso vimp tree"))

## assumes the original vdv 70 gene signature is stored in the character vector nms (not defined here)
## compare methods using 25 runs:
rO <- lapply(1:25, function(b) {
  cat("replication:", b, "\n")
  o1 <- varpro(f, vdv, split.weight.method = "lasso")
  o2 <- varpro(f, vdv, split.weight.method = "lasso vimp")
  o3 <- varpro(f, vdv, split.weight.method = "lasso vimp tree")
  o4 <- varpro(f, vdv, split.weight.method = "lasso vimp", sparse = FALSE)
  list("lasso"=intersect(nms,get.orgvimp(o1)$variable),
       "lasso.vimp"=intersect(nms,get.orgvimp(o2)$variable),
       "lasso.vimp.tree"=intersect(nms,get.orgvimp(o3)$variable),
       "lasso.vimp.sparseoff"=intersect(nms,get.orgvimp(o4)$variable))
})

Intersection with VDV 70 gene signature

$lasso
[1] "AF201951"       "AL080059"      
[3] "Contig25991"    "Contig28552_RC"
[5] "NM_000436"      "NM_003748"     
[7] "NM_005915"      "NM_006681"     
[9] "NM_020974"     

$lasso.vimp
 [1] "AL137718"       "Contig25991"   
 [3] "Contig28552_RC" "Contig51464_RC"
 [5] "Contig55377_RC" "NM_000436"     
 [7] "NM_003239"      "NM_003748"     
 [9] "NM_005915"      "NM_006681"     
[11] "NM_016448"      "NM_020974"     

$lasso.vimp.tree
 [1] "AA555029_RC"    "AF201951"      
 [3] "AL080059"       "Contig25991"   
 [5] "Contig28552_RC" "Contig55377_RC"
 [7] "NM_000436"      "NM_003748"     
 [9] "NM_005915"      "NM_006117"     
[11] "NM_006681"      "NM_016448"     
[13] "NM_020974"     

$lasso.vimp.sparseoff
 [1] "AF201951"       "AF257175"      
 [3] "AL080059"       "AL137718"      
 [5] "Contig25991"    "Contig28552_RC"
 [7] "Contig40831_RC" "Contig48328_RC"
 [9] "Contig51464_RC" "Contig55377_RC"
[11] "Contig63102_RC" "NM_002916"     
[13] "NM_003239"      "NM_003748"     
[15] "NM_005915"      "NM_006117"     
[17] "NM_006681"      "NM_016448"     
[19] "NM_020974"     
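
A list such as rO, with one entry per replication and one character vector per method, can be summarized in base R. The sketch below uses a toy stand-in (rO.toy, with made-up gene labels and only two methods) so that it runs on its own; it is one possible way to tabulate the runs, not necessarily how the listing above was produced.

```r
## toy stand-in for rO: 3 replications, 2 methods, made-up gene labels
rO.toy <- list(
  list(lasso = c("g1", "g2"), lasso.vimp = c("g1", "g2", "g3")),
  list(lasso = c("g1"),       lasso.vimp = c("g1", "g3")),
  list(lasso = c("g1", "g4"), lasso.vimp = c("g1", "g2", "g3"))
)
methods <- names(rO.toy[[1]])

## average number of signature genes recovered per run, by method
avg.hits <- sapply(methods, function(m)
  mean(sapply(rO.toy, function(r) length(r[[m]]))))

## genes recovered in every run, by method
always <- lapply(setNames(methods, methods), function(m)
  Reduce(intersect, lapply(rO.toy, function(r) r[[m]])))

avg.hits
always
```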

Outline

Part I: Training

  1. Quick start
  2. Data structures allowed
  3. Training (grow) with examples (regression, classification, survival)

Part II: Inference and Prediction

  1. Inference (OOB)
  2. Prediction Error
  3. Prediction
  4. Restore
  5. Partial Plots

Part III: Variable Selection

  1. VIMP
  2. Subsampling (Confidence Intervals)
  3. Minimal Depth
  4. VarPro

Part IV: Advanced Examples

  1. Class Imbalanced Data
  2. Competing Risks
  3. Multivariate Forests
  4. Missing data imputation

References

1. Ishwaran H, Lu M, Kogalur UB. randomForestSRC: Variable importance (VIMP) with subsampling inference vignette. 2021. http://randomforestsrc.org/articles/vimp.html.
2. Ishwaran H, Lu M. Standard errors and confidence intervals for variable importance in random forest regression, classification, and survival. Statistics in Medicine. 2019;38:558–82. https://ishwaran.org/papers/IL.StatMed.2019.pdf.
3. Politis DN, Romano JP. Large sample confidence regions based on subsamples under minimal assumptions. The Annals of Statistics. 1994;22:2031–50.
4. Ishwaran H, Chen X, Minn AJ, Lu M, Lauer MS, Kogalur UB. randomForestSRC: Minimal depth vignette. 2021. http://randomforestsrc.org/articles/minidep.html.
5. Van’t Veer LJ, Dai H, Van De Vijver MJ, He YD, Hart AA, Mao M, et al. Gene expression profiling predicts clinical outcome of breast cancer. Nature. 2002;415:530–6.