class Tree(object):
    """Recursive implementation of decision tree."""

-    def __init__(self, regression=False, criterion=None):
+    def __init__(self, regression=False, criterion=None, n_classes=None):
        self.regression = regression
        self.impurity = None
        self.threshold = None
        self.column_index = None
        self.outcome = None
        self.criterion = criterion
        self.loss = None
+        self.n_classes = n_classes  # Only for classification

        self.left_child = None
        self.right_child = None
@@ -64,6 +65,42 @@ def _find_best_split(self, X, target, n_features):
                    max_col, max_val, max_gain = column, value, gain
        return max_col, max_val, max_gain

+    def _train(self, X, target, max_features=None, min_samples_split=10, max_depth=None, minimum_gain=0.01):
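+        """Recursively grow the tree; any failed assertion below turns this node into a leaf."""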
+        try:
+            # Exit from recursion using assert syntax
+            assert X.shape[0] > min_samples_split
+            assert max_depth > 0
+
+            if max_features is None:
+                max_features = X.shape[1]
+
+            column, value, gain = self._find_best_split(X, target, max_features)
+            assert gain is not None
+            if self.regression:
+                assert gain != 0
+            else:
+                assert gain > minimum_gain
+
+            self.column_index = column
+            self.threshold = value
+            self.impurity = gain
+
+            # Split dataset
+            left_X, right_X, left_target, right_target = split_dataset(X, target, column, value)
+
+            # Grow left and right child
+            self.left_child = Tree(self.regression, self.criterion, self.n_classes)
+            self.left_child._train(
+                left_X, left_target, max_features, min_samples_split, max_depth - 1, minimum_gain
+            )
+
+            self.right_child = Tree(self.regression, self.criterion, self.n_classes)
+            self.right_child._train(
+                right_X, right_target, max_features, min_samples_split, max_depth - 1, minimum_gain
+            )
+        except AssertionError:
+            self._calculate_leaf_value(target)
+
    def train(self, X, target, max_features=None, min_samples_split=10, max_depth=None, minimum_gain=0.01, loss=None):
        """Build a decision tree from training set.

@@ -93,40 +130,12 @@ def train(self, X, target, max_features=None, min_samples_split=10, max_depth=None, minimum_gain=0.01, loss=None):
        if loss is not None:
            self.loss = loss

-        try:
-            # Exit from recursion using assert syntax
-            assert X.shape[0] > min_samples_split
-            assert max_depth > 0
-
-            if max_features is None:
-                max_features = X.shape[1]
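+        # The class count is recorded once at the root so that every leaf's
+        # probability vector has the same length, even when a class is absent
+        # from the subset of targets that reaches that leaf.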
+        if not self.regression:
+            self.n_classes = len(np.unique(target['y']))

-            column, value, gain = self._find_best_split(X, target, max_features)
-            assert gain is not None
-            if self.regression:
-                assert gain != 0
-            else:
-                assert gain > minimum_gain
+        self._train(X, target, max_features=max_features, min_samples_split=min_samples_split,
+                    max_depth=max_depth, minimum_gain=minimum_gain)

-            self.column_index = column
-            self.threshold = value
-            self.impurity = gain
-
-            # Split dataset
-            left_X, right_X, left_target, right_target = split_dataset(X, target, column, value)
-
-            # Grow left and right child
-            self.left_child = Tree(self.regression, self.criterion)
-            self.left_child.train(
-                left_X, left_target, max_features, min_samples_split, max_depth - 1, minimum_gain, loss
-            )
-
-            self.right_child = Tree(self.regression, self.criterion)
-            self.right_child.train(
-                right_X, right_target, max_features, min_samples_split, max_depth - 1, minimum_gain, loss
-            )
-        except AssertionError:
-            self._calculate_leaf_value(target)

    def _calculate_leaf_value(self, targets):
        """Find optimal value for leaf."""
@@ -140,7 +149,7 @@ def _calculate_leaf_value(self, targets):
                self.outcome = np.mean(targets["y"])
            else:
                # Probability for classification task
-                self.outcome = stats.itemfreq(targets["y"])[:, 1] / float(targets["y"].shape[0])
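+                # bincount with minlength pads classes that never reach this leaf,
+                # so the outcome always has exactly n_classes probabilities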
152+ self .outcome = np . bincount (targets ["y" ], minlength = self . n_classes ) / targets ["y" ].shape [0 ]
144153
145154 def predict_row (self , row ):
146155 """Predict single row."""