Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Iurii Kemaev committed Sep 18, 2018
1 parent 90b01f5 commit 0caab13
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 38 deletions.
Empty file removed 123
Empty file.
41 changes: 11 additions & 30 deletions Cart.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,27 @@ def __init__(self):
#type : 0 - list, 1 - node

def sum(self, X, p):
s = 0
for i in range(0, len(X)):
s += float(X[i][p])
return s
return sum(map(float, X[:][p]))

def create_list(self, X, y):
self.type = 0
self.attribute = [0, 0, 0, 0]
self.answer = sum(y) / float(len(X))


def is_list_criteria(self, depth):
if depth < 3: # Maximum DEPTH
return 0
else:
return 1
return depth >= 3 # Maximal DEPTH

def get_params(self, deep = True):
return {}

#We divide all members into two classes
#In first current attribute i < average of all, in second >=
# We divide all members into two classes.
# In the first one a current attribute _i_ < average of all, in the second >=.
@staticmethod
def gini_impurity(self, X, y, depth):
if self.is_list_criteria(depth):
self.create_list(X, y)
return self

self.type = 1
self.attribute[3] = -1
for i in range(0, len(X[0])):
Expand Down Expand Up @@ -75,17 +71,8 @@ def gini_impurity(self, X, y, depth):
self.create_list(X, y)
return self

def build_tree(self, X, y):
self = self.gini_impurity(self, X, y, 0)
return self

def fit(self, X, y):
t = self.build_tree(X, y)
self.attribute = t.attribute
self.answer = t.answer
self.type = t.type
self.left = t.left
self.right = t.right
return self.gini_impurity(self, X, y, 0)

def predict_user(self, user):
if self.type:
Expand All @@ -97,13 +84,7 @@ def predict_user(self, user):
return self.answer

def predict(self, X):
res = []
for i in range(0, len(X)):
res.append(self.predict_user(X[i]))
return res

def get_params(self, deep = True):
return {}
return list(map(self.predict_user, X))

def score(self, X, y):
y_prediction = numpy.array(self.predict(X)) - numpy.array(y)
Expand Down Expand Up @@ -147,7 +128,7 @@ def launch(flearn, ftest):
tk = line.split(',')
t = [float(tk[3]), float(tk[4]), float(tk[6]), float(tk[7]), float(tk[8])]
temp = CART.predict_user(t)
print(str([tk + ' prediction is ' + temp]))
print(''.join(map(str, [tk, ' prediction is ', temp])))
sumsq += (float(tk[-1]) - temp) ** 2
i += 1
line = f.readline()
Expand All @@ -156,7 +137,7 @@ def launch(flearn, ftest):
tk = tk[2].split(',')
t = [float(tk[1]), float(tk[2]), float(tk[4]), float(tk[5]), float(tk[6])]
temp = CART.predict_user(t)
print(str([tk + ' prediction is ' + temp]))
print(''.join(map(str, [tk, ' prediction is ', temp])))
sumsq += (float(tk[-1]) - temp) ** 2
i += 1
line = f.readline()
Expand Down
14 changes: 6 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
Decision_tree-CART-
===================
### Decision tree impelentation, *CART*

Homework for Data Mining
* learn.csv - train set
* test.csv - test set
* users.csv - the full data

learn.csv - обучающая выборка test.csv - тестовая выборка
Run and test: `python3 Cart.py learn.csv users.csv`

users.csv - общий набор данных

обучить и протестировать: python Cart.py learn.csv userfilt.csv
кросс-валидация: python Cart.py cross-val users.csv
Cross-validate: `python3 Cart.py cross-val users.csv`

0 comments on commit 0caab13

Please sign in to comment.