Python实现决策树(Decision Tree)分类
关于决策树的简介可以参考: http://blog.csdn.net/fengbingchun/article/details/78880934
在 https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/ 中给出了CART(Classification and Regression Trees,分类回归树算法,简称CART)算法的Python实现,采用的数据集为Banknote Dataset,关于此数据集的介绍可以参考:http://blog.csdn.net/fengbingchun/article/details/78624358 ,这里在原作者的基础上,进行了略微改动,使其可以直接执行,code如下:
# reference: https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/
# http://zhuanlan.51cto.com/art/201702/531945.htm
# using CART(Classification and Regression Trees,分类回归树算法,简称CART算法)) for classification# CART on the Bank Note dataset
from random import seed
from random import randrange
from csv import reader# Load a CSV file
def load_csv(filename):file = open(filename, "r")lines = reader(file)dataset = list(lines)return dataset# Convert string column to float
def str_column_to_float(dataset, column):for row in dataset:row[column] = float(row[column].strip())# Split a dataset into k folds
def cross_validation_split(dataset, n_folds):dataset_split = list()dataset_copy = list(dataset)fold_size = int(len(dataset) / n_folds)for i in range(n_folds):fold = list()while len(fold) < fold_size:index = randrange(len(dataset_copy))fold.append(dataset_copy.pop(index))dataset_split.append(fold)return dataset_split# Calculate accuracy percentage
def accuracy_metric(actual, predicted):correct = 0for i in range(len(actual)):if actual[i] == predicted[i]:correct += 1return correct / float(len(actual)) * 100.0# Evaluate an algorithm using a cross validation split
def evaluate_algorithm(dataset, algorithm, n_folds, *args):folds = cross_validation_split(dataset, n_folds)scores = list()for fold in folds:train_set = list(folds)train_set.remove(fold)train_set = sum(train_set, [])test_set = list()for row in fold:row_copy = list(row)test_set.append(row_copy)row_copy[-1] = Nonepredicted = algorithm(train_set, test_set, *args)actual = [row[-1] for row in fold]accuracy = accuracy_metric(actual, predicted)scores.append(accuracy)return scores# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):left, right = list(), list()for row in dataset:if row[index] < value:left.append(row)else:right.append(row)return left, right# Calculate the Gini index for a split dataset
def gini_index(groups, classes):# count all samples at split pointn_instances = float(sum([len(group) for group in groups])) # 计算总的样本数# sum weighted Gini index for each groupgini = 0.0for group in groups:size = float(len(group))# avoid divide by zeroif size == 0:continuescore = 0.0# score the group based on the score for each classfor class_val in classes:p = [row[-1] for row in group].count(class_val) / size # row[-1]指每个样本(一行)中最后一列即类别score += p * p# weight the group score by its relative sizegini += (1.0 - score) * (size / n_instances)return gini# Select the best split point for a dataset
def get_split(dataset):class_values = list(set(row[-1] for row in dataset)) # class_values的值为: [0, 1]b_index, b_value, b_score, b_groups = 999, 999, 999, Nonefor index in range(len(dataset[0])-1): # index的值为: [0, 1, 2, 3]for row in dataset:groups = test_split(index, row[index], dataset)gini = gini_index(groups, class_values)if gini < b_score:b_index, b_value, b_score, b_groups = index, row[index], gini, groupsreturn {'index':b_index, 'value':b_value, 'groups':b_groups} # 返回字典数据类型# Create a terminal node value
def to_terminal(group):outcomes = [row[-1] for row in group]return max(set(outcomes), key=outcomes.count)# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):left, right = node['groups']del(node['groups'])# check for a no splitif not left or not right:node['left'] = node['right'] = to_terminal(left + right)return# check for max depthif depth >= max_depth:node['left'], node['right'] = to_terminal(left), to_terminal(right)return# process left childif len(left) <= min_size:node['left'] = to_terminal(left)else:node['left'] = get_split(left)split(node['left'], max_depth, min_size, depth+1)# process right childif len(right) <= min_size:node['right'] = to_terminal(right)else:node['right'] = get_split(right)split(node['right'], max_depth, min_size, depth+1)# Build a decision tree
def build_tree(train, max_depth, min_size):root = get_split(train)split(root, max_depth, min_size, 1)return root# Make a prediction with a decision tree
def predict(node, row):if row[node['index']] < node['value']:if isinstance(node['left'], dict):return predict(node['left'], row)else:return node['left']else:if isinstance(node['right'], dict):return predict(node['right'], row)else:return node['right']# Classification and Regression Tree Algorithm
def decision_tree(train, test, max_depth, min_size):tree = build_tree(train, max_depth, min_size)predictions = list()for row in test:prediction = predict(tree, row)predictions.append(prediction)return(predictions)# Test CART on Bank Note dataset
seed(1)
# load and prepare data
filename = '../../../data/database/BacknoteDataset/data_banknote_authentication.csv'
dataset = load_csv(filename)
# convert string attributes to integers
for i in range(len(dataset[0])):str_column_to_float(dataset, i) # dataset为嵌套列表的列表,类型为float# evaluate algorithm
n_folds = 5
max_depth = 5
min_size = 10
scores = evaluate_algorithm(dataset, decision_tree, n_folds, max_depth, min_size)
print('Scores: %s' % scores)
print('Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores))))
执行结果如下:GitHub: https://github.com/fengbingchun/NN_Test
相关文章:

顶尖技术专家严选,15场前沿论坛思辨,2019中国大数据技术大会邀您共赴
扫码了解2019中国大数据技术大会(https://t.csdnimg.cn/IaHb)更多详情。2019中国大数据技术大会(BDTC 2019)将于12月5日-7日在北京长城饭店举办,本届大会将聚焦智能时代,大数据技术的发展曲线以及大数据与社…

jQuery 加法计算 使用+号即强转类型
1 var value1 $("#txt1").val(); 2 var value2 $("#txt2").val(); 3 //数值前添加号 number加号和数值加号需要用空格隔开 即实现加法运算 4 $("#txt3").val(value1 value2); 转载于:https://www.cnblogs.com/xiemin-minmin/p/11026784.…

Android Volley 库通过网络获取 JSON 数据
本文内容 什么是 Volley 库 Volley 能做什么 Volley 架构 环境 演示 Volley 库通过网络获取 JSON 数据 参考资料 Android 关于网络操作一般都会介绍 HttpClient 以及 HttpConnection 这两个包。前者是 Apache 开源库,后者是 Android 自带 API。企业级应用࿰…

哪些开发问题最让程序员“头秃”?我们分析了Stack Overflow的11000个问题
作者 | Nick Roberts编译 | AI科技大本营(ID:rgznai100)自 2008 年成立以来,Stack Overflow 一直在拯救所有类型的开发人员。自那时以来,开发人员提出了数百万个关于开发领域的问题。但是,迫使开发者转向 Stack Overfl…
OpenCV3.3中决策树(Decision Tree)接口简介及使用
OpenCV 3.3中给出了决策树Decision Tres算法的实现,即cv::ml::DTrees类,此类的声明在include/opencv2/ml.hpp文件中,实现在modules/ml/src/tree.cpp文件中。其中:(1)、cv::ml::DTrees类:继承自cv::ml::StateModel&…

ARM 寄存器 和 工作模式了解
一. ARM 工作模式 1. ARM7,ARM9,ARM11,处理器有 7 种工作模式;Cortex-A 多了一个监视模式(Monitor) 2. 用户模式:非特权模式,大部分任务执行在这种模式,它运行在操作系…

英文版PDF不能显示中文PDF文件的解决方法
首先,PDF如果是英文版本的话,先装一个与之对应的PDF中文包。装上之后要检查的两项:1、PDF本身打开Adobe pdf选择“edit”"Preference""Internet"将"internet"下的三个勾全部勾上"OK"2、IE设置打开IE…

Linux下__attribute__((visibility (default)))的使用
在Linux下动态库(.so)中,通过GCC的C visibility属性可以控制共享文件导出符号。在GCC 4.0及以上版本中,有个visibility属性,可见属性可以应用到函数、变量、模板以及C类。 限制符号可见性的原因:从动态库中尽可能少地输出符号是一…

java web学习项目20套源码完整版
java web学习项目20套源码完整版 自己收集的各行各业的都有,这一套源码吃遍所有作业项目! 1、BBS论坛系统(jspsql)2、ERP管理系统(jspservlet)3、OA办公自动化管理系统(Struts1.2Hibernate3.0Spring2DWR)4、…

360金融携手上海交大共建AI实验室,开启人才战略新布局
10月16日,上海交通大学计算机科学系—360金融人工智能联合实验室成立仪式在上海交通大学闵行校区举行,联合实验室致力于AI技术在新金融领域的应用探索。成立仪式上,360金融CEO吴海生宣布了“未来科学家”计划,这是360金融在人工智…
wxWidgets刚開始学习的人导引(3)——wxWidgets应用程序初体验
wxWidgets刚開始学习的人导引全文件夹 PDF版及附件下载1 前言2 下载、安装wxWidgets3 wxWidgets应用程序初体验4 wxWidgets学习资料及利用方法指导5 用wxSmith进行可视化设计附:学习材料清单3 wxWidgets应用程序初体验本文中全部的体验,在Code::Blocks…

C++中extern的使用
在C中,extern主要有两个作用:(1)、extern声明一个变量或函数;(2)、extern与”C”一起连用,用于链接指定。关于extern “C”的使用可以参考: http://blog.csdn.net/fengbingchun/article/details/78634831 ,…

Python识别文字,实现看图说话 | CSDN博文精选
作者 | 张小腿来源 | CSDN博客现在写文件很多网站都不让复制了,所以每次都是截图然后发到QQ上然后用手机QQ的文字识别再发回电脑。感觉有点小麻烦了,所以想自己写一个小软件方便方便自己,就有了这篇了:首先语言是Python࿰…

Oracle Hints具体解释
在向大家具体介绍Oracle Hints之前,首先让大家了解下Oracle Hints是什么,然后全面介绍Oracle Hints,希望对大家实用。基于代价的优化器是非常聪明的,在绝大多数情况下它会选择正确的优化器,减轻了DBA的负担。但有时它也…
主成分分析(PCA)简介
主成分分析(Principal Components Analysis, PCA)是一个简单的机器学习算法,可以通过基础的线性代数知识推导。假设在Rn空间中我们有m个点{x(1),…,x(m)},我们希望对这些点进行有损压缩。有损压缩表示我们使用更少的内存,但损失一些精度去存储…

01-HTML基础与进阶-day6-录像281
04css选择器.html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Document</title><style type"text/css">/* p div 标签选择器*/p {color: red; /* k:v color表示样式属性 颜…

百度CTO王海峰:深度学习如何大规模产业化?
编者按:10月17日-19日,2019年中国计算机大会(CNCC2019)在苏州举办。百度首席技术官王海峰在会上发表题为《深度学习平台支撑产业智能化》的演讲,分享了百度关于深度学习技术推动人工智能发展及产业化应用的思考。以下为…

Kali Linux***测试
Kali Linux***测试实战 第一章http://drops.wooyun.org/tips/826 1.1 Kali Linux简介如果您之前使用过或者了解BackTrack系列Linux的话,那么我只需要简单的说,Kali是BackTrack的升级换代产品,从Kali开始,BackTrack将成为历史。如果…

一站式解决:隐马尔可夫模型(HMM)全过程推导及实现
作者 | 永远在你身后转载自知乎用户永远在你身后【导读】隐马尔可夫模型(Hidden Markov Model,HMM)是关于时许的概率模型,是一个生成模型,描述由一个隐藏的马尔科夫链随机生成不可观测的状态序列,每个状态生…
CUDA Samples: Image Process: BGR to BGR565
图像像素格式BGR565是每一个像素占2个字节,其中Blue占5位,Green占6位,Red占5位。在OpenCV中,BGR到BGR565的每一个像素的计算公式是:unsigned short dst (unsigned short)((B >> 3) | ((G & ~3) << 3)…

NoSQL数据库探讨 - 为什么要用非关系数据库?
源地址:http://robbin.javaeye.com/blog/524977 随着互联网web2.0网站的兴起,非关系型的数据库现在成了一个极其热门的新领域,非关系数据库产品的发展非常迅速。而传统的关系数据库在应付web2.0网站,特别是超大规模和高并发的SNS类…

手机内存RAM、ROM简介
手机内存包含两个:一个是运行内存(RAM),一个是机身内存(ROM)。两者的功能有所不同,运行内存是对手机操作系统和其它程序运行过程中,产生的临时数据进行存储的媒介。如果手机运行的程序比较多,占用运行内存空间较大&…

一个月入门Python爬虫,轻松爬取大规模数据
如果你仔细观察,就不难发现,懂爬虫、学习爬虫的人越来越多,一方面,互联网可以获取的数据越来越多,另一方面,像 Python这样一个月入门Python爬虫,轻松爬的编程语言提供越来越多的优秀工具&#x…

软件包管理 之 软件在线升级更新yum 图形工具介绍
作者:北南南北来自:LinuxSir.Org提要:yum 是Fedora/Redhat 软件包管理工具,包括文本命令行模式和图形模式;图形模式的yum也是基于文本模式的;目前yum图形前端程序主要有 yumex和kyum ; 正文一、…

[PHPUnit]自动生成PHPUnit测试骨架脚本-提供您的开发效率【2015升级版】
2019独角兽企业重金招聘Python工程师标准>>> 场景 在编写PHPUnit单元测试代码时,其实很多都是对各个类的各个外部调用的函数进行测试验证,检测代码覆盖率,验证预期效果。为避免增加开发量,可以使用PHPUnit提供的phpuni…
ORL Faces Database介绍
ORL人脸数据集共包含40个不同人的400张图像,是在1992年4月至1994年4月期间由英国剑桥的Olivetti研究实验室创建。此数据集下包含40个目录,每个目录下有10张图像,每个目录表示一个不同的人。所有的图像是以PGM格式存储,灰度图&…

张俊林:BERT和Transformer到底学到了什么 | AI ProCon 2019
演讲嘉宾 | 张俊林(新浪微博机器学习团队AI Lab负责人)编辑 | Jane出品 | AI科技大本营(ID:rgznai100)【导读】BERT提出的这一年,也是NLP领域迅速发展的一年。学界不断提出新的预训练模型,刷新各…

Eclipse创建web工程时,报错Dynamic Web Module 3.0 requires Java 1.6 or newer.
报错: 解决方案: 1.打开eclipse工具栏window->preferences 2.打开java->compiler 3.选择compiler compliance level在1.6以上版本(此处选择1.8) 4.点击apply and close保存修改,即可 转载于:https://www.cnblogs…

Maven学习总结(八)——使用Maven构建多模块项目
2019独角兽企业重金招聘Python工程师标准>>> Maven学习总结(八)——使用Maven构建多模块项目 在平时的Javaweb项目开发中为了便于后期的维护,我们一般会进行分层开发,最常见的就是分为domain(域模型层)、dao࿰…

哈工大、清华、CSDN、嵌入式视觉联盟合办的 AIoT 盛会,你怎么舍得错过?!
2019 嵌入式智能国际大会即将来袭!随着物联网和人工智能技术的飞速发展与相互渗透,万物智联的新赛道已经开始显现。据中商产业研究院《2016—2021年中国物联网产业市场研究报告》显示,预计到2020年,中国物联网的整体规模将达2.2万…