library('randomForest')
## randomForest 4.6-10
## Type rfNews() to see new features/changes/bug fixes.
library('rpart')
library('ggplot2')
library('parallel')
library('WVPlots') # https://github.com/WinVector/WVPlots
## Loading required package: grid
## Loading required package: gridExtra
## Loading required package: reshape2
## Loading required package: ROCR
## Loading required package: gplots
## 
## Attaching package: 'gplots'
## 
## The following object is masked from 'package:stats':
## 
##     lowess
## 
## Loading required package: plyr
## Loading required package: stringr
set.seed(2362667L)

#' Build a concept that is hard for decision stumps,
#' but easy for sums of decision stumps.
#' @param nRow number of rows to generate, positive integer
#' @param nCol number of columns to generate, positive integer > 3
#' @return list with the data frame d, original variable names vars,
#'   and variable names including duplicates varsd
mkExampleRF1 <- function(nRow,nCol) {
  nSpecial <- 3
  dList <- lapply(seq_len(nCol),function(i) rnorm(nRow))
  names(dList) <- paste('v',seq_len(nCol),sep='_')
  d <- data.frame(dList)
  # compute outcome; re-draw the noise until both classes are present
  repeat {
     yScore <- rowSums(d)
     # add in more copies of special vars
     for(v in seq_len(nSpecial)) {
        yScore <- yScore + d[[v]]
     }
     y <- as.factor(as.character((yScore + rnorm(nRow))>0))
     if(length(unique(y))>1) {
       break
     }
  }
  vars <- colnames(d)
  # duplicate our most useful variables (to defeat randomForest's per-node variable selection)
  for(v in seq_len(nSpecial)) {
    dList <- lapply(seq_len(3),function(i) d[[v]])
    names(dList) <- paste('vdup',v,seq_len(length(dList)),sep='_')
    d <- cbind(d,data.frame(dList))
  }
  varsd <- colnames(d)
  d$y <- y
  list(d=d,vars=vars,varsd=varsd)
}
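
# A quick structural peek (an illustrative call with assumed small sizes,
# not part of the experiment): for nCol=10 the returned frame has the 10
# original columns v_1..v_10, 3*3=9 duplicate columns vdup_*, plus y.
ex <- mkExampleRF1(5,10)
print(dim(ex$d))        # 5 rows, 10 + 9 + 1 = 20 columns
print(head(ex$varsd,3)) # "v_1" "v_2" "v_3"
print(tail(ex$varsd,3)) # duplicates of v_3: "vdup_3_1" "vdup_3_2" "vdup_3_3"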

# accuracy: fraction of predictions agreeing with truth,
# read off the 2x2 confusion matrix
scorePred <- function(truth,pred) {
  t <- table(truth=truth,pred=pred)
  (t[1,1]+t[2,2])/sum(t)
}
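
# Tiny sanity check (illustrative only): 3 of these 4 predictions agree
# with truth, so scorePred() returns 0.75.
print(scorePred(truth=c(TRUE,TRUE,FALSE,FALSE),
                pred=c(TRUE,FALSE,FALSE,FALSE)))
## [1] 0.75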

nCol <- 100



# for one training-set size: repeat 200 times -- generate train/test data,
# fit a small random forest (maxnodes=3) and a logistic regression on the
# original variables, and record accuracy on 10000 held-out rows
runTest <- function(nRow) {
  res <- c()
  for(rep in 1:200) {
    trainStuff <- mkExampleRF1(nRow,nCol)
    trainData <- trainStuff$d
    vars <- trainStuff$vars
    varsd <- trainStuff$varsd
    testData <- mkExampleRF1(10000,nCol)$d
    
    form1 <- as.formula(paste('y',paste(vars,collapse = ' + '),sep=' ~ '))
    model <- randomForest(form1,data=trainData,maxnodes=3)
    #table(truth=trainData$y,pred=predict(model,newdata=trainData))
    #table(truth=testData$y,pred=predict(model,newdata=testData))
    si <- scorePred(truth=testData$y,pred=predict(model,newdata=testData))
    res <- rbind(res,data.frame(model='RF3',nTrain=nRow,accuracy=si,
                                stringsAsFactors = FALSE))

    model <- glm(form1,data=trainData,
                 family=binomial(link='logit'))
    pred <- predict(model,newdata=testData,type='response')
    si <- scorePred(truth=testData$y,
                    pred=pred>0.5)
    res <- rbind(res,data.frame(model='LR',nTrain=nRow,accuracy=si,
                                stringsAsFactors = FALSE))
    
#     formd <- as.formula(paste('y',paste(varsd,collapse = ' + '),sep=' ~ '))
#     model <- randomForest(formd,data=trainData,maxnodes=3)
#     #table(truth=trainData$y,pred=predict(model,newdata=trainData))
#     #table(truth=testData$y,pred=predict(model,newdata=testData))
#     si <- scorePred(truth=testData$y,pred=predict(model,newdata=testData))
#     res <- rbind(res,data.frame(model='RF3d',nTrain=nRow,accuracy=si,
#                                 stringsAsFactors = FALSE))
    
  }
  res
}
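
# For reference (describing the function above, with a hypothetical small
# nCol): runTest(nRow) returns 400 rows (200 repetitions x 2 models) with
# columns model ('RF3' or 'LR'), nTrain, and accuracy. The formula it builds
# has the form:
print(as.formula(paste('y',paste(paste('v',1:4,sep='_'),collapse=' + '),sep=' ~ ')))
## y ~ v_1 + v_2 + v_3 + v_4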


# libraries and global values the worker function will need on each node
libs <- c('randomForest','rpart')
names <- c('mkExampleRF1','nCol','runTest','scorePred')
mkWorker <- function(fn,names,libs) {
  force(fn)
  force(names)
  force(libs)
  f <- function(nRow) {
    # on the remote worker: load required libraries and rebuild the
    # needed globals in fn's environment before calling it
    for(li in libs) {
      library(li,character.only = TRUE)
    }
    for(ni in names) {
      assign(ni,get(ni),envir=environment(fn))
    }
    fn(nRow)
  }
  # capture current values of the named globals in f's own environment,
  # so they serialize and ship with f to the parallel workers
  for(ni in names) {
    assign(ni,get(ni),envir=environment(f))
  }
  f
}
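
# Toy sketch of the capture pattern (hypothetical names, illustration only):
# mkWorker() copies the named globals into environments that travel with the
# worker function, so each parallel node can rebuild them before calling fn.
demoVal <- 7
demoFn <- function(x) { x + demoVal }
demoWorker <- mkWorker(demoFn,c('demoVal'),c())
print(demoWorker(1))
## [1] 8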

workVec <- seq(25,400,25)

# leave cl as NULL to run serially; build a cluster to run in parallel
cl <- NULL

cl <- parallel::makeCluster(4)

if(is.null(cl)) {
  res <- lapply(workVec,runTest)
} else {
  worker <- mkWorker(runTest,names,libs)
  res <- parallel::parLapply(cl,workVec,worker)
}
res <- do.call(rbind,res)

if(!is.null(cl)) {
   parallel::stopCluster(cl)
   cl <- NULL
}

res$nTrain <- as.integer(res$nTrain)

ggplot(data=res,aes(x=nTrain,y=accuracy,color=model)) +
  geom_point() + geom_smooth()
## geom_smooth: method="auto" and size of largest group is >=1000, so using gam with formula: y ~ s(x, bs = "cs"). Use 'method = x' to change the smoothing method.

for(model in unique(res$model)) {
  rm <- res[res$model==model,,drop=FALSE]
  agg <- aggregate(accuracy~nTrain,data=rm,FUN=median)
  # candidate reference lines: the two midpoints between extreme per-nTrain
  # median accuracies below and above the nTrain=210 split
  cutPts <- c(210)
  for(cutPt in cutPts) {
    cuts <- c((min(agg[agg$nTrain<=cutPt,'accuracy']) + max(agg[agg$nTrain>cutPt,'accuracy']))/2,
              (max(agg[agg$nTrain<=cutPt,'accuracy']) + min(agg[agg$nTrain>cutPt,'accuracy']))/2)
    for(cut in cuts) {
      print(ScatterBoxPlot(rm,'nTrain','accuracy',title=model) + 
              geom_hline(yintercept=cut,color='blue') )
    }
  }
}