Classification Trees

September 18th, 2010

Decision trees are applied in situations where the data are divided into groups, rather than investigating a numerical response and its relationship to a set of descriptor variables. There are various implementations of classification trees in R, and some commonly used functions are rpart and tree.


[Embedded video demonstrating the analysis in this post]

To illustrate the use of the tree function we will use a data set from the UCI Machine Learning Repository, where the objective of the study that produced the data was to predict the cellular localization sites of proteins.

The data provided on the website is shown here:

> ecoli.df = read.csv("ecoli.txt")
> head(ecoli.df)
    Sequence  mcv  gvh  lip chg  aac alm1 alm2 class
1  AAT_ECOLI 0.49 0.29 0.48 0.5 0.56 0.24 0.35    cp
2 ACEA_ECOLI 0.07 0.40 0.48 0.5 0.54 0.35 0.44    cp
3 ACEK_ECOLI 0.56 0.40 0.48 0.5 0.49 0.37 0.46    cp
4 ACKA_ECOLI 0.59 0.49 0.48 0.5 0.52 0.45 0.36    cp
5  ADI_ECOLI 0.23 0.32 0.48 0.5 0.55 0.25 0.35    cp
6 ALKH_ECOLI 0.67 0.39 0.48 0.5 0.36 0.38 0.46    cp
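
The raw data file on the UCI website does not include a header row (the column names are only described in the accompanying ecoli.names file), so some tidying is needed to produce a file like ecoli.txt. A minimal sketch, assuming the raw whitespace-delimited file has been saved as ecoli.data:

> ecoli.df = read.table("ecoli.data", header = FALSE,
  col.names = c("Sequence", "mcv", "gvh", "lip", "chg",
                "aac", "alm1", "alm2", "class"))
> write.csv(ecoli.df, "ecoli.txt", row.names = FALSE)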

We can use the xtabs function to summarise the number of cases in each class.

> xtabs( ~ class, data = ecoli.df)
class
 cp  im imL imS imU  om omL  pp 
143  77   2   2  35  20   5  52

As noted in the comments below, the package that I used was the tree package:

> require(tree)
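
The tree package is not installed with R by default, so it may need to be downloaded from CRAN first (assuming the machine is connected to the internet):

> install.packages("tree")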

The complete classification tree using all of the variables is fitted to the data first, and then we will try to prune the tree to make it smaller.

> ecoli.tree1 = tree(class ~ mcv + gvh + lip + chg + aac + alm1 + alm2,
  data = ecoli.df)
> summary(ecoli.tree1)
 
Classification tree:
tree(formula = class ~ mcv + gvh + lip + chg + aac + alm1 + alm2, 
    data = ecoli.df)
Variables actually used in tree construction:
[1] "alm1" "mcv"  "gvh"  "aac"  "alm2"
Number of terminal nodes:  10 
Residual mean deviance:  0.7547 = 246 / 326 
Misclassification error rate: 0.122 = 41 / 336
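
As an aside, since every column other than the Sequence identifier appears in the formula, the same tree can be fitted more concisely with the . shorthand (a sketch, dropping the first column of the data frame):

> ecoli.tree1 = tree(class ~ ., data = ecoli.df[, -1])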

The tree function is used in a similar way to other modelling functions in R. The misclassification rate is shown as part of the summary of the tree. This tree can be plotted and annotated with these commands:

> plot(ecoli.tree1)
> text(ecoli.tree1, all = T)

To prune the tree we use cross-validation to identify the best point at which to prune.

> cv.tree(ecoli.tree1)
$size
 [1] 10  9  8  7  6  5  4  3  2  1
 
$dev
 [1]  463.6820  457.4463  447.9824  441.8617  455.8318  478.9234  533.5856  586.2820  713.2992 1040.3878
 
$k
 [1]      -Inf  12.16500  15.60004  19.21572  34.29868  41.10627  50.57044  64.05494 180.78800 355.67747
 
$method
[1] "deviance"
 
attr(,"class")
[1] "prune"         "tree.sequence"

This suggests a tree size of around 6, beyond which the deviance improves very little, and we can re-fit the tree:

> ecoli.tree2 = prune.misclass(ecoli.tree1, best = 6)
> summary(ecoli.tree2)
 
Classification tree:
snip.tree(tree = ecoli.tree1, nodes = c(4, 20, 7))
Variables actually used in tree construction:
[1] "alm1" "mcv"  "aac"  "gvh" 
Number of terminal nodes:  6 
Residual mean deviance:  0.9918 = 327.3 / 330 
Misclassification error rate: 0.1548 = 52 / 336

The misclassification rate has increased with the pruning of the tree, from 12.2% (41/336) to 15.5% (52/336), but not substantially.
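
To see which classes the pruned tree struggles with, a confusion table of predicted against observed classes can be built with the predict method (a sketch; type = "class" asks predict.tree for the predicted class of each case):

> table(predicted = predict(ecoli.tree2, type = "class"),
  actual = ecoli.df$class)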

Other useful resources are provided on the Supplementary Material page.

Data used in this post: Ecoli Data Set.

12 responses to “Classification Trees”

  1. Andrew Robinson says:

    Nice post — thanks!

    It would add value if you mentioned the packages that you used.

    Andrew

  2. Ralph says:

    Thanks for pointing that out Andrew – I normally indicate loading packages in my posts but must have overlooked it when writing this one!

  3. Peter Flom says:

    Another useful package is party, developed by Torsten Hothorn.

    Also, there is a brand new edition of Recursive Partitioning and Applications by Zhang and Singer. I’m reading it; so far it is quite good.

  4. Ruben says:

    Nice post, but I’m wondering if you could make the ecoli.txt file available.
    When I tried to download the data from the UCI Machine Learning repository, the file containing the data doesn’t include the column names (these are only described in ecoli.names).
    Alternatively, would you mind indicating the most efficient way of recreating the txt file used in the example?
    Many thanks in advance,
    Ruben

  5. Ralph says:

    Ruben,

    I’ve added the data file to the end of this post. Hope that it is useful. I had to do some tidying up of the data myself to get it into a suitable format for the analysis described in this post.

    Ralph

  6. Ruben says:

    Dear Ralph,
    thanks a lot for posting the data file.
    I’m looking forward to reading more interesting blog posts about statistical modeling with R.
    Regards,
    Ruben

  7. Bob Muenchen says:

    I enjoyed reading & running this example. Your video adds a very useful command that was missing from the program above. Right after you run cv.tree(ecoli.tree1), in the video you do plot( cv.tree(ecoli.tree1) ), which makes it much more obvious why 6 looks reasonable.

    Keep those tutorials coming!

    Cheers,
    Bob Muenchen

  8. Arslan says:

    Thanks for the tutorial. I have a problem loading the tree package. When I try it I get this error:

    >require(tree)
    Loading required package: tree
    Warning message:
    In library(package, lib.loc = lib.loc, character.only = TRUE, logical.return = TRUE, :
    there is no package called ‘tree’

    How can I get this “tree” package?

    Thanks in advance.

  9. Ralph says:

    I think that the tree package is not installed by default. If your machine is connected to the internet then you can get R to download and install the package. Otherwise you will need to download it from CRAN and install it manually. Hope this helps.

  10. Mateus Brum says:

    Very nice post.
    I am wondering how to choose the pruning point of a decision tree.
    Cheers.

  11. Ralph says:

    Mateus,

    Bob Muenchen (Comment 7 above – about 4:20 in the video) highlights the way to get a feel for how to prune the tree. We are looking for the point where increasing the complexity of the tree no longer leads to an improvement in the model.

    Best wishes
    Ralph

  12. loveness says:

    Is it possible to get misclassification errors per class, as with randomForest? Like this:

    Confusion matrix:
           1    2 class.error
    1    749  240   0.2426694
    2   2274 5731   0.2840725