Decision tree and automatic rule generation
Post date: Mar 3, 2014 3:41:28 AM
We can automatically create a collection of rules for any node in a decision tree. In this example I use the R package rpart to build the tree, and develop a function that returns a collection of rules based on the decision tree. The rpart object my.tree is given below:
> print(my.tree)
n= 891
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 891 342 0 (0.61616162 0.38383838)
2) Sex=male 577 109 0 (0.81109185 0.18890815)
4) Age>=6.5 553 93 0 (0.83182640 0.16817360) *
5) Age< 6.5 24 8 1 (0.33333333 0.66666667) *
3) Sex=female 314 81 1 (0.25796178 0.74203822)
6) Pclass>=2.5 144 72 0 (0.50000000 0.50000000)
12) Fare>=23.35 27 3 0 (0.88888889 0.11111111) *
13) Fare< 23.35 117 48 1 (0.41025641 0.58974359)
26) Embarked=S 63 31 0 (0.50793651 0.49206349)
52) Fare< 10.825 37 15 0 (0.59459459 0.40540541) *
53) Fare>=10.825 26 10 1 (0.38461538 0.61538462) *
27) Embarked=C,Q 54 16 1 (0.29629630 0.70370370) *
7) Pclass< 2.5 170 9 1 (0.05294118 0.94705882) *
For demonstration purposes, let's extract the rule for node #5 from the tree above. In the end, the output rule will look like this:
"!(is.na(df.train$Sex)) & df.train$Sex %in% c('male') & !(is.na(df.train$Age)) & df.train$Age<6.5"
The example of the code is shown below:
# =============================================================================
# PART 1
# -----------------------------------------------------------------------------
# Train an rpart classification tree, plot and prune it, then extract the rule
# for a single node and check it against the tree summary.
# Expects train.RData (providing df.train) and test.RData (providing df.test
# and df.test.label) in the working directory, plus tree_functions.r, which
# defines autoGenerateRule().
# =============================================================================

# Remove every non-function object from the workspace before starting
rm(list = setdiff(ls(), lsf.str()))
# NOTE: machine-specific path -- adjust to wherever the demo files live
setwd("C:/Users/kittipat/research/rpart_demo")
library(rpart)
load("train.RData")
# names(df.train)

# =========================
# -- Training the tree --
# =========================
# Named tree.control (not rpart.control) so the variable does not mask the
# rpart.control() function
tree.control <- rpart.control(minsplit = 30, cp = 0.01)
# Identifier-like columns are excluded from the predictors
excluded.cols <- c("PassengerId", "Name", "Ticket", "Cabin")
my.tree <- rpart(
  formula = Survived ~ .,
  data = df.train[, !(names(df.train) %in% excluded.cols)],
  na.action = na.pass,
  method = "class",
  control = tree.control,
  x = TRUE,
  y = TRUE
)

# -- check which variables are used (and not) in the model --
colnames(my.tree$x) # variables used in the model

# ================================
# --- see the plot of the tree ---
# ================================
library(partykit)
png(filename = "plot_tree.png", width = 12, height = 6, units = "in", res = 300)
plot(as.party(my.tree))
dev.off()

# -- write the full tree in text --
tree_text_full <- capture.output(print(my.tree))
write(x = tree_text_full, file = "tree_text_full.txt", append = FALSE, sep = ",")

# =================================================
# --- apply the trained tree to a new test data ---
# =================================================
load("test.RData")
y.test <- predict(
  object = my.tree,
  newdata = df.test,
  type = "class",
  na.action = na.pass
)
table(y.test, df.test.label$survived, useNA = "ifany")

# ==============================================================
# --- Pruning the tree ---
# When you want to have a simpler tree, you prune the tree
# by increasing the complexity parameter cp--the bigger cp, the simpler tree
# ==============================================================
my.tree.prune <- prune(tree = my.tree, cp = 0.02)

# -- write the prune tree in text --
tree_text_prune <- capture.output(print(my.tree.prune))
write(x = tree_text_prune, file = "tree_text_prune.txt", append = FALSE, sep = ",")

# ======================================
# --- Automatic rule generating ---
# ======================================
# This is the function to generate a rule automatically from any node in the tree
source("tree_functions.r")

# Get all the rules in the path to a node
expr <- path.rpart(tree = my.tree, nodes = 5)[[1]]

# Convert all the rules to R syntax
expr.out <- autoGenerateRule(expr, df.train, df.name = "df.train", na.check = TRUE)

# combine all the rules in the path into a single rule
extracted.rule.text <- paste(expr.out$rule.extract, collapse = " & ")
# The output extracted rule would look something like this:
# "!(is.na(df.train$Sex)) & df.train$Sex %in% c('male') & !(is.na(df.train$Age)) & df.train$Age<6.5"

# =================
# evaluate the rule
# =================
# The rule text is generated by autoGenerateRule() above, so evaluating it as
# code is acceptable here; never do this with text from an untrusted source.
outcome <- eval(parse(text = extracted.rule.text))

# How many observations fall into this bucket
table(outcome, useNA = "ifany")
# outcome
# FALSE  TRUE
#   867    24

# In that bucket, how many are positive?
table(df.train$Survived[outcome], useNA = "ifany")
#  0  1
#  8 16

print(my.tree)
# Note that this gives the same answer as appeared in the tree summary
# 1) root 891 342 0 (0.61616162 0.38383838)
#   2) Sex=male 577 109 0 (0.81109185 0.18890815)
#     5) Age< 6.5 24 8 1 (0.33333333 0.66666667) *
Created by Pretty R at inside-R.org
If you are interested in seeing a comprehensive report of the rules extracted from every node of the tree, you can run part 2 of the code:
# =============================================================================
# PART 2
# -----------------------------------------------------------------------------
# Report rule and yield from every node in the tree for both train and test set
# Requires my.tree, df.train, df.test, df.test.label and autoGenerateRule()
# to already exist in the workspace.
# =============================================================================
node.id.list <- as.numeric(row.names(my.tree$frame)) # list of nodes in the tree
# node.id.list <- node.id.list[order(node.id.list)]
# node.id.list <- setdiff(node.id.list,1) # remove the root node

# Count how a fired / not-fired rule lines up with the 0/1 labels.
#   is.fired: logical vector, TRUE where the rule fires
#   label:    labels of the same length; "1" is the positive class, "0" the
#             negative class (anything else, including NA, is not counted)
# Returns a list with vol.invg (number of fired rows with a 0/1 label),
# perc.yield (100 * tp / (tp + fp)), perc.cover (100 * tp / (tp + miss)) and
# the counts tp.cnt, fp.cnt, tn.cnt and miss. Counts are 0 (never NA) when a
# cell of the confusion matrix is empty.
summarizeRule <- function(is.fired, label) {
  label <- as.character(label)
  fired <- !is.na(is.fired) & is.fired
  not.fired <- !is.na(is.fired) & !is.fired
  is.pos <- !is.na(label) & label == "1"
  is.neg <- !is.na(label) & label == "0"

  tp.cnt <- sum(fired & is.pos)
  fp.cnt <- sum(fired & is.neg)
  tn.cnt <- sum(not.fired & is.neg)
  fn.cnt <- sum(not.fired & is.pos)

  list(
    vol.invg = tp.cnt + fp.cnt,
    perc.yield = 100 * tp.cnt / (tp.cnt + fp.cnt),
    perc.cover = 100 * tp.cnt / (tp.cnt + fn.cnt),
    tp.cnt = tp.cnt,
    fp.cnt = fp.cnt,
    tn.cnt = tn.cnt,
    miss = fn.cnt
  )
}

# ----------------------------
# prepare node dataframe report
# ----------------------------
# tree node number in depth-first search
tree.df.report <- data.frame("node.id" = paste0("node#", row.names(my.tree$frame)))
tree.df.report$splitted.by <- my.tree$frame$var
tree.df.report$level <- floor(log2(node.id.list))
# one dash per level of depth, e.g. level 2 -> "--"
tree.df.report$level.display <- strrep("-", tree.df.report$level)

# prepare room for all metrics (train columns first, then test columns)
metric.names <- c(
  "vol.invg", "perc.yield", "perc.cover",
  "tp.cnt", "fp.cnt", "tn.cnt", "miss"
)
for (col.name in c(paste0(metric.names, ".train"), paste0(metric.names, ".test"))) {
  tree.df.report[[col.name]] <- NA
}

# ------------------------------------------------------------
# For-loop for each node in depth-first search, skipping the root
# ------------------------------------------------------------
# seq_along(...)[-1] is empty for a root-only tree, unlike 2:length(...)
for (i in seq_along(node.id.list)[-1]) {
  node.id <- node.id.list[i]
  expr <- path.rpart(tree = my.tree, nodes = node.id, print.it = FALSE)[[1]]
  expr.out <- autoGenerateRule(expr, df.train, df.name = "df.train", na.check = TRUE)
  extracted.rule.text <- paste(expr.out$rule.extract, collapse = " & ")

  # -- evaluate the rule for train dataset --
  is.fired <- eval(parse(text = extracted.rule.text))
  train <- summarizeRule(is.fired, df.train$Survived)

  # -- evaluate the rule for test dataset --
  # fixed = TRUE: "df.train" is a literal name, not a regular expression
  extracted.rule.text <- gsub(
    pattern = "df.train",
    replacement = "df.test",
    x = extracted.rule.text,
    fixed = TRUE
  )
  is.fired <- eval(parse(text = extracted.rule.text))
  test <- summarizeRule(is.fired, df.test.label$survived)

  # -- make dataframe report --
  for (metric in metric.names) {
    tree.df.report[[paste0(metric, ".train")]][i] <- train[[metric]]
    tree.df.report[[paste0(metric, ".test")]][i] <- test[[metric]]
  }

  # -- report result --
  # isTRUE() guards against NaN yield/coverage (rule never fires, or no
  # positives at all), which would otherwise make the condition NA
  if (isTRUE(train$perc.yield > 0 && train$perc.cover > 0)) {
    cat(
      sprintf(
        "node#%i -- train set -- invg vol:%i yield:%.2f%% coverage:%.2f%% tp:%i fp:%i tn:%i miss:%i test set -- invg vol:%i yield:%.2f%% coverage:%.2f%% tp:%i fp:%i tn:%i miss:%i Rule: %s \n\n",
        node.id,
        train$vol.invg, train$perc.yield, train$perc.cover,
        train$tp.cnt, train$fp.cnt, train$tn.cnt, train$miss,
        test$vol.invg, test$perc.yield, test$perc.cover,
        test$tp.cnt, test$fp.cnt, test$tn.cnt, test$miss,
        extracted.rule.text
      ),
      file = "all_rule_summary.txt",
      append = TRUE
    )
  }
}

write.csv(tree.df.report, file = "tree_report.csv", row.names = FALSE)
Created by Pretty R at inside-R.org
The function is shown below
# --- This is the function to generate a collection of rules from a single
# --- path to a node in a decision tree
autoGenerateRule <- function(expr, df, df.name = "yourDFName", na.check = TRUE) {
  # version 1.0
  # Kittipat "Bot" Kampa -- kittipat AT gmail
  # ========================================
  # expr: the split conditions along the path, one text per element, e.g.
  #       c("root", "Sex=male", "Age< 6.5"); may be a character vector, or a
  #       list / data.frame whose first element holds those texts
  # df: the dataframe object containing the variables used in the rules
  #     (only used to look up each variable's type)
  # df.name: The name of dataframe to appear on the extracted rule
  # na.check: TRUE if we want to include NA-check expression in the extracted
  #           rule, for instance, !(is.na(df.name$Age))
  #
  # Returns a list with one entry per split condition in each element:
  #   var.names, var.arith, var.values : variable, operator and value as parsed
  #   var.type     : "factor" for any factor column, else the first class
  #   var.arith2, var.values2 : operator and value rewritten as R syntax
  #   rule.extract : the rule text for each condition
  #   check.na     : the NA-check text for each condition
  # Elements without a comparison operator (such as "root") are dropped, so a
  # root-only path yields zero-length vectors.
  # Limitation: a factor level that itself contains a comma cannot be told
  # apart from a list of levels and will be split.
  # ========================================

  # accept list / data.frame input by taking its first element
  if (is.list(expr)) {
    expr <- expr[[1]]
  }
  expr <- as.character(expr)

  # Split each condition at its FIRST comparison operator. Everything after
  # it is the value, so levels containing <, >, = or | are kept intact.
  op.match <- regexpr(pattern = "[<>]=?|=", text = expr)
  op.start <- as.integer(op.match)
  op.len <- attr(op.match, "match.length")
  has.op <- !is.na(op.start) & op.start > 0L

  expr <- expr[has.op]
  op.start <- op.start[has.op]
  op.end <- op.start + op.len[has.op] - 1L

  var.names <- trimws(substr(expr, 1L, op.start - 1L))
  var.arith <- substr(expr, op.start, op.end)
  var.values <- trimws(substr(expr, op.end + 1L, nchar(expr)))

  missing.vars <- setdiff(var.names, names(df))
  if (length(missing.vars) > 0) {
    stop(
      "Variable(s) not found in df: ", paste(missing.vars, collapse = ", "),
      call. = FALSE
    )
  }

  # type of each variable
  var.type <- vapply(
    var.names,
    function(v) if (is.factor(df[[v]])) "factor" else class(df[[v]])[1L],
    character(1)
  )
  is.categorical <- var.type %in% c("factor", "character")

  # make a more proper format for the values
  var.values2 <- var.values
  var.arith2 <- var.arith

  # A bare negative value would turn "x< -1.5" into "x<-1.5", which R parses
  # as an assignment; parentheses keep it a comparison.
  is.negative <- !is.categorical & grepl("^-", var.values)
  if (any(is.negative)) {
    var.values2[is.negative] <- paste0("(", var.values[is.negative], ")")
  }

  if (any(is.categorical)) {
    # escape backslashes and single quotes so each level stays a valid
    # single-quoted R string, then turn "a,b" into c('a','b')
    level.text <- gsub("\\", "\\\\", var.values[is.categorical], fixed = TRUE)
    level.text <- gsub("'", "\\'", level.text, fixed = TRUE)
    var.values2[is.categorical] <- paste0(
      "c('", gsub(",", "','", level.text, fixed = TRUE), "')"
    )
    var.arith2[is.categorical] <- " %in% " # change "=" to "%in%"
  }

  # combine to generate rule path
  if (length(var.names) == 0) {
    # paste0() would recycle a zero-length input into one bogus rule
    check.na <- character(0)
    rule.extract <- character(0)
  } else {
    check.na <- paste0("!(is.na(", df.name, "$", var.names, "))")
    rule.only <- paste0(df.name, "$", var.names, var.arith2, var.values2)
    if (isTRUE(na.check)) {
      # rule + checking NA case
      rule.extract <- paste0(check.na, " & ", rule.only)
    } else {
      # only the rule
      rule.extract <- rule.only
    }
  }

  list(
    "var.names" = var.names,
    "var.arith" = var.arith,
    "var.values" = var.values,
    "var.type" = var.type,
    "var.arith2" = var.arith2,
    "var.values2" = var.values2,
    "rule.extract" = rule.extract,
    "check.na" = check.na
  )
}
Created by Pretty R at inside-R.org
Note that this version does not include logic optimization. For example, the tree can produce a redundant rule like "a > 1 & b > 6 & a > 3", which would be better represented as "a > 3 & b > 6".