# Chunk options: echo code into the rendered document.
knitr::opts_chunk$set(echo = TRUE)
# vtreat: statistically sound variable treatment / encoding for modeling.
library('vtreat')
library('WVPlots') # see: https://github.com/WinVector/WVPlots
## Loading required package: ggplot2
## Loading required package: grid
## Loading required package: gridExtra
## Loading required package: reshape2
## Loading required package: ROCR
## Loading required package: gplots
## 
## Attaching package: 'gplots'
## The following object is masked from 'package:stats':
## 
##     lowess
## Loading required package: plyr
## Loading required package: stringr
## Loading required package: mgcv
## Loading required package: nlme
## This is mgcv 1.8-12. For overview type 'help("mgcv-package")'.
# xgboost: gradient-boosted tree learner used for the churn model below.
library('xgboost')
# Start a PSOCK cluster on all detected cores; passed to vtreat as
# parallelCluster and its size reused as xgboost's nthread.
# NOTE(review): the cluster is never stopped in this view — confirm a
# later chunk calls parallel::stopCluster(cl).
ncore <- parallel::detectCores()
cl <- parallel::makeCluster(ncore)
# see: https://github.com/WinVector/PreparingDataWorkshop/tree/master/KDD2009
# Load the KDD2009 "orange small" training features.  Both literal 'NA' and
# the empty string are treated as missing; strings are kept as character.
d <- read.table('orange_small_train.data.gz',
                header = TRUE, sep = '\t', na.strings = c('NA', ''),
                strip.white = TRUE,
                stringsAsFactors = FALSE)
# Churn labels: one unnamed column (V1), one row per data row.
churn <- read.table('orange_small_train_churn.labels.txt',
                    header = FALSE, sep = '\t',
                    strip.white = TRUE,
                    stringsAsFactors = FALSE)
d$churn <- churn$V1
set.seed(729375)
# Random ~90/10 row split into modeling and evaluation sets.
rgroup <- runif(nrow(d))
dTrain <- d[rgroup <= 0.9, ]  # set for building models
dTest <- d[rgroup > 0.9, ]    # set for evaluation
rm(list = c('d', 'churn'))
# KDD2009 defines three outcomes; only churn is joined and modeled here,
# but all three names are excluded from the candidate variable list.
outcomes <- c('churn', 'appetency', 'upselling')
nonvars <- c(outcomes, 'rgroup')
vars <- setdiff(colnames(dTrain),
                nonvars)
yName <- 'churn'
yTarget <- 1
# Build the vtreat data treatments.
set.seed(239525)

# Cross-frame experiment: designs treatments for a categorical (binary)
# outcome and simultaneously produces a cross-validated treated training
# frame, avoiding nested-model over-fit on the impact-coded variables.
trainPlan <- mkCrossFrameCExperiment(dTrain,
                                     vars, yName, yTarget,
                                     smFactor = 2.0,
                                     parallelCluster = cl)
print(trainPlan$method)
## [1] "kwaycrossystratified"
treatmentsC <- trainPlan$treatments
treatedTrainM <- trainPlan$crossFrame

#kddSig = 1/nrow(treatmentsC$scoreFrame)
# Treated columns available for modeling (outcome columns excluded).
selvars <- setdiff(colnames(treatedTrainM), outcomes)
# Recode the outcome to logical: TRUE exactly when churn equals yTarget.
treatedTrainM[[yName]] <- treatedTrainM[[yName]] == yTarget

# Apply the learned treatment plan to the held-out test set; no pruning
# here (pruneSig = NULL), variable choice is handled via varRestriction.
treatedTest <- prepare(treatmentsC,
                       dTest,
                       varRestriction = selvars,
                       pruneSig = NULL,
                       parallelCluster = cl)
treatedTest[[yName]] <- treatedTest[[yName]] == yTarget
mname <- 'predxgboost'

# Variable selection: keep treated variables whose vtreat significance
# beats a Bonferroni-style threshold of 1/(number of candidate variables).
goodvars <- treatmentsC$scoreFrame$varName[
  treatmentsC$scoreFrame$sig < 1 / nrow(treatmentsC$scoreFrame)]
# Formula string kept for reference / possible later use (not consumed by
# the xgboost matrix interface below).
formulaS <- paste(yName, paste(goodvars, collapse = ' + '), sep = ' ~ ')

# Hoist the loop-invariant design matrices out of the ntrees loop: the
# original rebuilt them on every iteration (and the train matrix twice per
# iteration — once for the DMatrix, once for predict).
trainMatrix <- as.matrix(treatedTrainM[, goodvars, drop = FALSE])
testMatrix <- as.matrix(treatedTest[, goodvars, drop = FALSE])

# simple default, production model would require hyperparameter search
for (ntrees in c(50, 100, 200)) {
  modelxg <- xgboost(data = xgb.DMatrix(trainMatrix,
                                        label = treatedTrainM[[yName]]),
                     objective = 'binary:logistic',
                     nrounds = ntrees,
                     nthread = ncore)
  # prepare plotting frames: outcome column plus model score
  treatedTrainP <- treatedTrainM[, yName, drop = FALSE]
  treatedTestP <- treatedTest[, yName, drop = FALSE]
  treatedTrainP[[mname]] <- as.numeric(predict(modelxg, trainMatrix))
  treatedTestP[[mname]] <- as.numeric(predict(modelxg, testMatrix))
  print(WVPlots::ROCPlot(treatedTrainP, mname, yName,
                         paste0('prediction on train, ntree=', ntrees)))
  print(WVPlots::ROCPlot(treatedTestP, mname, yName,
                         paste0('prediction on test, ntree=', ntrees)))
}
## [0]  train-error:0.071800
## [1]  train-error:0.071866
## [2]  train-error:0.071422
## [3]  train-error:0.071844
## [4]  train-error:0.071822
## [5]  train-error:0.071711
## [6]  train-error:0.071378
## [7]  train-error:0.071111
## [8]  train-error:0.070645
## [9]  train-error:0.070556
## [10] train-error:0.070223
## [11] train-error:0.069979
## [12] train-error:0.069801
## [13] train-error:0.069712
## [14] train-error:0.068935
## [15] train-error:0.068824
## [16] train-error:0.068668
## [17] train-error:0.068602
## [18] train-error:0.068202
## [19] train-error:0.068135
## [20] train-error:0.067847
## [21] train-error:0.067447
## [22] train-error:0.067203
## [23] train-error:0.066781
## [24] train-error:0.066470
## [25] train-error:0.066203
## [26] train-error:0.066003
## [27] train-error:0.065870
## [28] train-error:0.065848
## [29] train-error:0.065737
## [30] train-error:0.065515
## [31] train-error:0.065404
## [32] train-error:0.064960
## [33] train-error:0.064804
## [34] train-error:0.064626
## [35] train-error:0.064493
## [36] train-error:0.064204
## [37] train-error:0.064005
## [38] train-error:0.063938
## [39] train-error:0.063827
## [40] train-error:0.063805
## [41] train-error:0.063405
## [42] train-error:0.063405
## [43] train-error:0.063383
## [44] train-error:0.063338
## [45] train-error:0.063205
## [46] train-error:0.062961
## [47] train-error:0.062761
## [48] train-error:0.062494
## [49] train-error:0.062317

## [0]  train-error:0.071800
## [1]  train-error:0.071866
## [2]  train-error:0.071422
## [3]  train-error:0.071844
## [4]  train-error:0.071822
## [5]  train-error:0.071711
## [6]  train-error:0.071378
## [7]  train-error:0.071111
## [8]  train-error:0.070645
## [9]  train-error:0.070556
## [10] train-error:0.070223
## [11] train-error:0.069979
## [12] train-error:0.069801
## [13] train-error:0.069712
## [14] train-error:0.068935
## [15] train-error:0.068824
## [16] train-error:0.068668
## [17] train-error:0.068602
## [18] train-error:0.068202
## [19] train-error:0.068135
## [20] train-error:0.067847
## [21] train-error:0.067447
## [22] train-error:0.067203
## [23] train-error:0.066781
## [24] train-error:0.066470
## [25] train-error:0.066203
## [26] train-error:0.066003
## [27] train-error:0.065870
## [28] train-error:0.065848
## [29] train-error:0.065737
## [30] train-error:0.065515
## [31] train-error:0.065404
## [32] train-error:0.064960
## [33] train-error:0.064804
## [34] train-error:0.064626
## [35] train-error:0.064493
## [36] train-error:0.064204
## [37] train-error:0.064005
## [38] train-error:0.063938
## [39] train-error:0.063827
## [40] train-error:0.063805
## [41] train-error:0.063405
## [42] train-error:0.063405
## [43] train-error:0.063383
## [44] train-error:0.063338
## [45] train-error:0.063205
## [46] train-error:0.062961
## [47] train-error:0.062761
## [48] train-error:0.062494
## [49] train-error:0.062317
## [50] train-error:0.062095
## [51] train-error:0.061939
## [52] train-error:0.061717
## [53] train-error:0.061451
## [54] train-error:0.061473
## [55] train-error:0.061406
## [56] train-error:0.061051
## [57] train-error:0.060807
## [58] train-error:0.060696
## [59] train-error:0.060385
## [60] train-error:0.060074
## [61] train-error:0.059718
## [62] train-error:0.059519
## [63] train-error:0.059185
## [64] train-error:0.059052
## [65] train-error:0.058897
## [66] train-error:0.058586
## [67] train-error:0.058275
## [68] train-error:0.058119
## [69] train-error:0.058031
## [70] train-error:0.057764
## [71] train-error:0.057231
## [72] train-error:0.057120
## [73] train-error:0.056742
## [74] train-error:0.056565
## [75] train-error:0.056409
## [76] train-error:0.056432
## [77] train-error:0.056321
## [78] train-error:0.055965
## [79] train-error:0.055565
## [80] train-error:0.054921
## [81] train-error:0.054544
## [82] train-error:0.054211
## [83] train-error:0.053922
## [84] train-error:0.053123
## [85] train-error:0.052967
## [86] train-error:0.052612
## [87] train-error:0.052456
## [88] train-error:0.052190
## [89] train-error:0.051901
## [90] train-error:0.051457
## [91] train-error:0.050680
## [92] train-error:0.050191
## [93] train-error:0.049791
## [94] train-error:0.049436
## [95] train-error:0.049014
## [96] train-error:0.048681
## [97] train-error:0.048525
## [98] train-error:0.048548
## [99] train-error:0.047970