首页 后端开发 Python教程 python如何实现决策树算法?(代码)

python如何实现决策树算法?(代码)

Oct 10, 2018 pm 05:16 PM
python 人工智能 大数据 机器学习

本篇文章给大家带来的内容是关于python如何实现决策树算法?(代码),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。

数据描述

每条数据项储存在列表中,最后一列储存结果
多条数据项形成数据集

data=[[d1,d2,d3...dn,result],
      [d1,d2,d3...dn,result],
                .
                .
      [d1,d2,d3...dn,result]]
登录后复制

决策树数据结构

class DecisionNode:
    '''决策树节点
    '''
    
    def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
        '''初始化决策树节点
        
        args:        
        col -- 按数据集的col列划分数据集
        value -- 以value作为划分col列的参照
        result -- 只有叶子节点有,代表最终划分出的子数据集结果统计信息。{‘结果’:结果出现次数}
        rb,fb -- 代表左右子树
        '''
        self.col=col
        self.value=value
        self.results=results
        self.tb=tb
        self.fb=fb
登录后复制

决策树分类的最终结果是将数据项划分出了若干子集,其中每个子集的结果都一样,所以这里采用{‘结果’:结果出现次数}的方式表达每个子集

def pideset(rows,column,value):
    '''依据数据集rows的column列的值,判断其与参考值value的关系对数据集进行拆分
       返回两个数据集
    '''
    split_function=None
    #value是数值类型
    if isinstance(value,int) or isinstance(value,float):
        #定义lambda函数当row[column]>=value时返回true
        split_function=lambda row:row[column]>=value
    #value是字符类型
    else:
        #定义lambda函数当row[column]==value时返回true
        split_function=lambda row:row[column]==value
    #将数据集拆分成两个
    set1=[row for row in rows if split_function(row)]
    set2=[row for row in rows if not split_function(row)]
    #返回两个数据集
    return (set1,set2)

def uniquecounts(rows):
    '''计算数据集rows中有几种最终结果,计算结果出现次数,返回一个字典
    '''
    results={}
    for row in rows:
        r=row[len(row)-1]
        if r not in results: results[r]=0
        results[r]+=1
    return results

def giniimpurity(rows):
    '''返回rows数据集的基尼不纯度
    '''
    total=len(rows)
    counts=uniquecounts(rows)
    imp=0
    for k1 in counts:
        p1=float(counts[k1])/total
        for k2 in counts:
            if k1==k2: continue
            p2=float(counts[k2])/total
            imp+=p1*p2
    return imp

def entropy(rows):
    '''返回rows数据集的熵
    '''
    from math import log
    log2=lambda x:log(x)/log(2)  
    results=uniquecounts(rows)
    ent=0.0
    for r in results.keys():
        p=float(results[r])/len(rows)
        ent=ent-p*log2(p)
    return ent

def build_tree(rows,scoref=entropy):
    '''构造决策树
    '''
    if len(rows)==0: return DecisionNode()
    current_score=scoref(rows)

    # 最佳信息增益
    best_gain=0.0
    #
    best_criteria=None
    #最佳划分
    best_sets=None

    column_count=len(rows[0])-1
    #遍历数据集的列,确定分割顺序
    for col in range(0,column_count):
        column_values={}
        # 构造字典
        for row in rows:
            column_values[row[col]]=1
        for value in column_values.keys():
            (set1,set2)=pideset(rows,col,value)
            p=float(len(set1))/len(rows)
            # 计算信息增益
            gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
            if gain>best_gain and len(set1)>0 and len(set2)>0:
                best_gain=gain
                best_criteria=(col,value)
                best_sets=(set1,set2)
    # 如果划分的两个数据集熵小于原数据集,进一步划分它们
    if best_gain>0:
        trueBranch=build_tree(best_sets[0])
        falseBranch=build_tree(best_sets[1])
        return DecisionNode(col=best_criteria[0],value=best_criteria[1],
                        tb=trueBranch,fb=falseBranch)
    # 如果划分的两个数据集熵不小于原数据集,停止划分
    else:
        return DecisionNode(results=uniquecounts(rows))

def print_tree(tree,indent=''):
    if tree.results!=None:
        print(str(tree.results))
    else:
        print(str(tree.col)+':'+str(tree.value)+'? ')
        print(indent+'T->',end='')
        print_tree(tree.tb,indent+'  ')
        print(indent+'F->',end='')
        print_tree(tree.fb,indent+'  ')


def getwidth(tree):
    if tree.tb==None and tree.fb==None: return 1
    return getwidth(tree.tb)+getwidth(tree.fb)

def getdepth(tree):
    if tree.tb==None and tree.fb==None: return 0
    return max(getdepth(tree.tb),getdepth(tree.fb))+1


def drawtree(tree,jpeg='tree.jpg'):
    w=getwidth(tree)*100
    h=getdepth(tree)*100+120

    img=Image.new('RGB',(w,h),(255,255,255))
    draw=ImageDraw.Draw(img)

    drawnode(draw,tree,w/2,20)
    img.save(jpeg,'JPEG')

def drawnode(draw,tree,x,y):
    if tree.results==None:
        # Get the width of each branch
        w1=getwidth(tree.fb)*100
        w2=getwidth(tree.tb)*100

        # Determine the total space required by this node
        left=x-(w1+w2)/2
        right=x+(w1+w2)/2

        # Draw the condition string
        draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))

        # Draw links to the branches
        draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
        draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))
    
        # Draw the branch nodes
        drawnode(draw,tree.fb,left+w1/2,y+100)
        drawnode(draw,tree.tb,right-w2/2,y+100)
    else:
        txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
        draw.text((x-20,y),txt,(0,0,0))
登录后复制

对测试数据进行分类(附带处理缺失数据)

def mdclassify(observation,tree):
    '''对缺失数据进行分类
    
    args:
    observation -- 发生信息缺失的数据项
    tree -- 训练完成的决策树
    
    返回代表该分类的结果字典
    '''

    # 判断数据是否到达叶节点
    if tree.results!=None:
        # 已经到达叶节点,返回结果result
        return tree.results
    else:
        # 对数据项的col列进行分析
        v=observation[tree.col]

        # 若col列数据缺失
        if v==None:
            #对tree的左右子树分别使用mdclassify,tr是左子树得到的结果字典,fr是右子树得到的结果字典
            tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb)

            # 分别以结果占总数比例计算得到左右子树的权重
            tcount=sum(tr.values())
            fcount=sum(fr.values())
            tw=float(tcount)/(tcount+fcount)
            fw=float(fcount)/(tcount+fcount)
            result={}

            # 计算左右子树的加权平均
            for k,v in tr.items(): 
                result[k]=v*tw
            for k,v in fr.items(): 
                # fr的结果k有可能并不在tr中,在result中初始化k
                if k not in result: 
                    result[k]=0 
                # fr的结果累加到result中  
                result[k]+=v*fw
            return result

        # col列没有缺失,继续沿决策树分类
        else:
            if isinstance(v,int) or isinstance(v,float):
                if v>=tree.value: branch=tree.tb
                else: branch=tree.fb
            else:
                if v==tree.value: branch=tree.tb
                else: branch=tree.fb
            return mdclassify(observation,branch)

tree=build_tree(my_data)
print(mdclassify(['google',None,'yes',None],tree))
print(mdclassify(['google','France',None,None],tree))
登录后复制

决策树剪枝

def prune(tree,mingain):
    '''对决策树进行剪枝
    
    args:
    tree -- 决策树
    mingain -- 最小信息增益
    
   返回
    '''
    # 修剪非叶节点
    if tree.tb.results==None:
        prune(tree.tb,mingain)
    if tree.fb.results==None:
        prune(tree.fb,mingain)
    #合并两个叶子节点
    if tree.tb.results!=None and tree.fb.results!=None:
        tb,fb=[],[]
        for v,c in tree.tb.results.items():
            tb+=[[v]]*c
        for v,c in tree.fb.results.items():
            fb+=[[v]]*c
        #计算熵减少情况
        delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)
        #熵的增加量小于mingain,可以合并分支
        if delta<mingain:
            tree.tb,tree.fb=None,None
            tree.results=uniquecounts(tb+fb)
登录后复制

以上是python如何实现决策树算法?(代码)的详细内容。更多信息请关注PHP中文网其他相关文章!

本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Clothoff.io

Clothoff.io

AI脱衣机

AI Hentai Generator

AI Hentai Generator

免费生成ai无尽的。

热门文章

R.E.P.O.能量晶体解释及其做什么(黄色晶体)
3 周前 By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳图形设置
3 周前 By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您听不到任何人,如何修复音频
3 周前 By 尊渡假赌尊渡假赌尊渡假赌
WWE 2K25:如何解锁Myrise中的所有内容
4 周前 By 尊渡假赌尊渡假赌尊渡假赌

热工具

记事本++7.3.1

记事本++7.3.1

好用且免费的代码编辑器

SublimeText3汉化版

SublimeText3汉化版

中文版,非常好用

禅工作室 13.0.1

禅工作室 13.0.1

功能强大的PHP集成开发环境

Dreamweaver CS6

Dreamweaver CS6

视觉化网页开发工具

SublimeText3 Mac版

SublimeText3 Mac版

神级代码编辑软件(SublimeText3)

mysql 是否要付费 mysql 是否要付费 Apr 08, 2025 pm 05:36 PM

MySQL 有免费的社区版和收费的企业版。社区版可免费使用和修改,但支持有限,适合稳定性要求不高、技术能力强的应用。企业版提供全面商业支持,适合需要稳定可靠、高性能数据库且愿意为支持买单的应用。选择版本时考虑的因素包括应用关键性、预算和技术技能。没有完美的选项,只有最合适的方案,需根据具体情况谨慎选择。

mysql安装后怎么使用 mysql安装后怎么使用 Apr 08, 2025 am 11:48 AM

文章介绍了MySQL数据库的上手操作。首先,需安装MySQL客户端,如MySQLWorkbench或命令行客户端。1.使用mysql-uroot-p命令连接服务器,并使用root账户密码登录;2.使用CREATEDATABASE创建数据库,USE选择数据库;3.使用CREATETABLE创建表,定义字段及数据类型;4.使用INSERTINTO插入数据,SELECT查询数据,UPDATE更新数据,DELETE删除数据。熟练掌握这些步骤,并学习处理常见问题和优化数据库性能,才能高效使用MySQL。

mySQL下载完安装不了 mySQL下载完安装不了 Apr 08, 2025 am 11:24 AM

MySQL安装失败的原因主要有:1.权限问题,需以管理员身份运行或使用sudo命令;2.依赖项缺失,需安装相关开发包;3.端口冲突,需关闭占用3306端口的程序或修改配置文件;4.安装包损坏,需重新下载并验证完整性;5.环境变量配置错误,需根据操作系统正确配置环境变量。解决这些问题,仔细检查每个步骤,就能顺利安装MySQL。

如何针对高负载应用程序优化 MySQL 性能? 如何针对高负载应用程序优化 MySQL 性能? Apr 08, 2025 pm 06:03 PM

MySQL数据库性能优化指南在资源密集型应用中,MySQL数据库扮演着至关重要的角色,负责管理海量事务。然而,随着应用规模的扩大,数据库性能瓶颈往往成为制约因素。本文将探讨一系列行之有效的MySQL性能优化策略,确保您的应用在高负载下依然保持高效响应。我们将结合实际案例,深入讲解索引、查询优化、数据库设计以及缓存等关键技术。1.数据库架构设计优化合理的数据库架构是MySQL性能优化的基石。以下是一些核心原则:选择合适的数据类型选择最小的、符合需求的数据类型,既能节省存储空间,又能提升数据处理速度

mysql安装后怎么优化数据库性能 mysql安装后怎么优化数据库性能 Apr 08, 2025 am 11:36 AM

MySQL性能优化需从安装配置、索引及查询优化、监控与调优三个方面入手。1.安装后需根据服务器配置调整my.cnf文件,例如innodb_buffer_pool_size参数,并关闭query_cache_size;2.创建合适的索引,避免索引过多,并优化查询语句,例如使用EXPLAIN命令分析执行计划;3.利用MySQL自带监控工具(SHOWPROCESSLIST,SHOWSTATUS)监控数据库运行状况,定期备份和整理数据库。通过这些步骤,持续优化,才能提升MySQL数据库性能。

mysql 需要互联网吗 mysql 需要互联网吗 Apr 08, 2025 pm 02:18 PM

MySQL 可在无需网络连接的情况下运行,进行基本的数据存储和管理。但是,对于与其他系统交互、远程访问或使用高级功能(如复制和集群)的情况,则需要网络连接。此外,安全措施(如防火墙)、性能优化(选择合适的网络连接)和数据备份对于连接到互联网的 MySQL 数据库至关重要。

Navicat查看MongoDB数据库密码的方法 Navicat查看MongoDB数据库密码的方法 Apr 08, 2025 pm 09:39 PM

直接通过 Navicat 查看 MongoDB 密码是不可能的,因为它以哈希值形式存储。取回丢失密码的方法:1. 重置密码;2. 检查配置文件(可能包含哈希值);3. 检查代码(可能硬编码密码)。

HadiDB:Python 中的轻量级、可水平扩展的数据库 HadiDB:Python 中的轻量级、可水平扩展的数据库 Apr 08, 2025 pm 06:12 PM

HadiDB:轻量级、高水平可扩展的Python数据库HadiDB(hadidb)是一个用Python编写的轻量级数据库,具备高度水平的可扩展性。安装HadiDB使用pip安装:pipinstallhadidb用户管理创建用户:createuser()方法创建一个新用户。authentication()方法验证用户身份。fromhadidb.operationimportuseruser_obj=user("admin","admin")user_obj.

See all articles