首頁 後端開發 Python教學 python如何實作決策樹演算法?(程式碼)

python如何實作決策樹演算法?(程式碼)

Oct 10, 2018 pm 05:16 PM
python 人工智慧 大數據 機器學習

這篇文章帶給大家的內容是關於python如何實現決策樹演算法?(程式碼),有一定的參考價值,有需要的朋友可以參考一下,希望對你有幫助。

資料描述

每個資料項目儲存在清單中,最後一列儲存結果
多條資料項目形成資料集

data=[[d1,d2,d3...dn,result],
      [d1,d2,d3...dn,result],
                .
                .
      [d1,d2,d3...dn,result]]
登入後複製

決策樹資料結構

class DecisionNode:
    '''决策树节点
    '''
    
    def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
        '''初始化决策树节点
        
        args:        
        col -- 按数据集的col列划分数据集
        value -- 以value作为划分col列的参照
        result -- 只有叶子节点有,代表最终划分出的子数据集结果统计信息。{‘结果’:结果出现次数}
        rb,fb -- 代表左右子树
        '''
        self.col=col
        self.value=value
        self.results=results
        self.tb=tb
        self.fb=fb
登入後複製

決策樹分類的最終結果是將資料項目劃分出了若干子集,其中每個子集的結果都一樣,所以這裡採用{'結果':結果出現次數}的方式表達每個子集

def pideset(rows,column,value):
    '''依据数据集rows的column列的值,判断其与参考值value的关系对数据集进行拆分
       返回两个数据集
    '''
    split_function=None
    #value是数值类型
    if isinstance(value,int) or isinstance(value,float):
        #定义lambda函数当row[column]>=value时返回true
        split_function=lambda row:row[column]>=value
    #value是字符类型
    else:
        #定义lambda函数当row[column]==value时返回true
        split_function=lambda row:row[column]==value
    #将数据集拆分成两个
    set1=[row for row in rows if split_function(row)]
    set2=[row for row in rows if not split_function(row)]
    #返回两个数据集
    return (set1,set2)

def uniquecounts(rows):
    '''计算数据集rows中有几种最终结果,计算结果出现次数,返回一个字典
    '''
    results={}
    for row in rows:
        r=row[len(row)-1]
        if r not in results: results[r]=0
        results[r]+=1
    return results

def giniimpurity(rows):
    '''返回rows数据集的基尼不纯度
    '''
    total=len(rows)
    counts=uniquecounts(rows)
    imp=0
    for k1 in counts:
        p1=float(counts[k1])/total
        for k2 in counts:
            if k1==k2: continue
            p2=float(counts[k2])/total
            imp+=p1*p2
    return imp

def entropy(rows):
    '''返回rows数据集的熵
    '''
    from math import log
    log2=lambda x:log(x)/log(2)  
    results=uniquecounts(rows)
    ent=0.0
    for r in results.keys():
        p=float(results[r])/len(rows)
        ent=ent-p*log2(p)
    return ent

def build_tree(rows,scoref=entropy):
    '''构造决策树
    '''
    if len(rows)==0: return DecisionNode()
    current_score=scoref(rows)

    # 最佳信息增益
    best_gain=0.0
    #
    best_criteria=None
    #最佳划分
    best_sets=None

    column_count=len(rows[0])-1
    #遍历数据集的列,确定分割顺序
    for col in range(0,column_count):
        column_values={}
        # 构造字典
        for row in rows:
            column_values[row[col]]=1
        for value in column_values.keys():
            (set1,set2)=pideset(rows,col,value)
            p=float(len(set1))/len(rows)
            # 计算信息增益
            gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
            if gain>best_gain and len(set1)>0 and len(set2)>0:
                best_gain=gain
                best_criteria=(col,value)
                best_sets=(set1,set2)
    # 如果划分的两个数据集熵小于原数据集,进一步划分它们
    if best_gain>0:
        trueBranch=build_tree(best_sets[0])
        falseBranch=build_tree(best_sets[1])
        return DecisionNode(col=best_criteria[0],value=best_criteria[1],
                        tb=trueBranch,fb=falseBranch)
    # 如果划分的两个数据集熵不小于原数据集,停止划分
    else:
        return DecisionNode(results=uniquecounts(rows))

def print_tree(tree,indent=''):
    if tree.results!=None:
        print(str(tree.results))
    else:
        print(str(tree.col)+':'+str(tree.value)+'? ')
        print(indent+'T->',end='')
        print_tree(tree.tb,indent+'  ')
        print(indent+'F->',end='')
        print_tree(tree.fb,indent+'  ')


def getwidth(tree):
    if tree.tb==None and tree.fb==None: return 1
    return getwidth(tree.tb)+getwidth(tree.fb)

def getdepth(tree):
    if tree.tb==None and tree.fb==None: return 0
    return max(getdepth(tree.tb),getdepth(tree.fb))+1


def drawtree(tree,jpeg='tree.jpg'):
    w=getwidth(tree)*100
    h=getdepth(tree)*100+120

    img=Image.new('RGB',(w,h),(255,255,255))
    draw=ImageDraw.Draw(img)

    drawnode(draw,tree,w/2,20)
    img.save(jpeg,'JPEG')

def drawnode(draw,tree,x,y):
    if tree.results==None:
        # Get the width of each branch
        w1=getwidth(tree.fb)*100
        w2=getwidth(tree.tb)*100

        # Determine the total space required by this node
        left=x-(w1+w2)/2
        right=x+(w1+w2)/2

        # Draw the condition string
        draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))

        # Draw links to the branches
        draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
        draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))
    
        # Draw the branch nodes
        drawnode(draw,tree.fb,left+w1/2,y+100)
        drawnode(draw,tree.tb,right-w2/2,y+100)
    else:
        txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
        draw.text((x-20,y),txt,(0,0,0))
登入後複製

對測試資料進行分類(附帶處理缺失資料)

#
def mdclassify(observation,tree):
    '''对缺失数据进行分类
    
    args:
    observation -- 发生信息缺失的数据项
    tree -- 训练完成的决策树
    
    返回代表该分类的结果字典
    '''

    # 判断数据是否到达叶节点
    if tree.results!=None:
        # 已经到达叶节点,返回结果result
        return tree.results
    else:
        # 对数据项的col列进行分析
        v=observation[tree.col]

        # 若col列数据缺失
        if v==None:
            #对tree的左右子树分别使用mdclassify,tr是左子树得到的结果字典,fr是右子树得到的结果字典
            tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb)

            # 分别以结果占总数比例计算得到左右子树的权重
            tcount=sum(tr.values())
            fcount=sum(fr.values())
            tw=float(tcount)/(tcount+fcount)
            fw=float(fcount)/(tcount+fcount)
            result={}

            # 计算左右子树的加权平均
            for k,v in tr.items(): 
                result[k]=v*tw
            for k,v in fr.items(): 
                # fr的结果k有可能并不在tr中,在result中初始化k
                if k not in result: 
                    result[k]=0 
                # fr的结果累加到result中  
                result[k]+=v*fw
            return result

        # col列没有缺失,继续沿决策树分类
        else:
            if isinstance(v,int) or isinstance(v,float):
                if v>=tree.value: branch=tree.tb
                else: branch=tree.fb
            else:
                if v==tree.value: branch=tree.tb
                else: branch=tree.fb
            return mdclassify(observation,branch)

tree=build_tree(my_data)
print(mdclassify(['google',None,'yes',None],tree))
print(mdclassify(['google','France',None,None],tree))
登入後複製

決策樹剪枝

def prune(tree,mingain):
    '''对决策树进行剪枝
    
    args:
    tree -- 决策树
    mingain -- 最小信息增益
    
   返回
    '''
    # 修剪非叶节点
    if tree.tb.results==None:
        prune(tree.tb,mingain)
    if tree.fb.results==None:
        prune(tree.fb,mingain)
    #合并两个叶子节点
    if tree.tb.results!=None and tree.fb.results!=None:
        tb,fb=[],[]
        for v,c in tree.tb.results.items():
            tb+=[[v]]*c
        for v,c in tree.fb.results.items():
            fb+=[[v]]*c
        #计算熵减少情况
        delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)
        #熵的增加量小于mingain,可以合并分支
        if delta<mingain:
            tree.tb,tree.fb=None,None
            tree.results=uniquecounts(tb+fb)
登入後複製

以上是python如何實作決策樹演算法?(程式碼)的詳細內容。更多資訊請關注PHP中文網其他相關文章!

本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

AI Hentai Generator

AI Hentai Generator

免費產生 AI 無盡。

熱門文章

R.E.P.O.能量晶體解釋及其做什麼(黃色晶體)
3 週前 By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳圖形設置
3 週前 By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您聽不到任何人,如何修復音頻
3 週前 By 尊渡假赌尊渡假赌尊渡假赌
WWE 2K25:如何解鎖Myrise中的所有內容
4 週前 By 尊渡假赌尊渡假赌尊渡假赌

熱工具

記事本++7.3.1

記事本++7.3.1

好用且免費的程式碼編輯器

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用

禪工作室 13.0.1

禪工作室 13.0.1

強大的PHP整合開發環境

Dreamweaver CS6

Dreamweaver CS6

視覺化網頁開發工具

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

mysql 是否要付費 mysql 是否要付費 Apr 08, 2025 pm 05:36 PM

MySQL 有免費的社區版和收費的企業版。社區版可免費使用和修改,但支持有限,適合穩定性要求不高、技術能力強的應用。企業版提供全面商業支持,適合需要穩定可靠、高性能數據庫且願意為支持買單的應用。選擇版本時考慮的因素包括應用關鍵性、預算和技術技能。沒有完美的選項,只有最合適的方案,需根據具體情況謹慎選擇。

mysql安裝後怎麼使用 mysql安裝後怎麼使用 Apr 08, 2025 am 11:48 AM

文章介紹了MySQL數據庫的上手操作。首先,需安裝MySQL客戶端,如MySQLWorkbench或命令行客戶端。 1.使用mysql-uroot-p命令連接服務器,並使用root賬戶密碼登錄;2.使用CREATEDATABASE創建數據庫,USE選擇數據庫;3.使用CREATETABLE創建表,定義字段及數據類型;4.使用INSERTINTO插入數據,SELECT查詢數據,UPDATE更新數據,DELETE刪除數據。熟練掌握這些步驟,並學習處理常見問題和優化數據庫性能,才能高效使用MySQL。

mysql 需要互聯網嗎 mysql 需要互聯網嗎 Apr 08, 2025 pm 02:18 PM

MySQL 可在無需網絡連接的情況下運行,進行基本的數據存儲和管理。但是,對於與其他系統交互、遠程訪問或使用高級功能(如復制和集群)的情況,則需要網絡連接。此外,安全措施(如防火牆)、性能優化(選擇合適的網絡連接)和數據備份對於連接到互聯網的 MySQL 數據庫至關重要。

如何針對高負載應用程序優化 MySQL 性能? 如何針對高負載應用程序優化 MySQL 性能? Apr 08, 2025 pm 06:03 PM

MySQL數據庫性能優化指南在資源密集型應用中,MySQL數據庫扮演著至關重要的角色,負責管理海量事務。然而,隨著應用規模的擴大,數據庫性能瓶頸往往成為製約因素。本文將探討一系列行之有效的MySQL性能優化策略,確保您的應用在高負載下依然保持高效響應。我們將結合實際案例,深入講解索引、查詢優化、數據庫設計以及緩存等關鍵技術。 1.數據庫架構設計優化合理的數據庫架構是MySQL性能優化的基石。以下是一些核心原則:選擇合適的數據類型選擇最小的、符合需求的數據類型,既能節省存儲空間,又能提升數據處理速度

HadiDB:Python 中的輕量級、可水平擴展的數據庫 HadiDB:Python 中的輕量級、可水平擴展的數據庫 Apr 08, 2025 pm 06:12 PM

HadiDB:輕量級、高水平可擴展的Python數據庫HadiDB(hadidb)是一個用Python編寫的輕量級數據庫,具備高度水平的可擴展性。安裝HadiDB使用pip安裝:pipinstallhadidb用戶管理創建用戶:createuser()方法創建一個新用戶。 authentication()方法驗證用戶身份。 fromhadidb.operationimportuseruser_obj=user("admin","admin")user_obj.

Navicat查看MongoDB數據庫密碼的方法 Navicat查看MongoDB數據庫密碼的方法 Apr 08, 2025 pm 09:39 PM

直接通過 Navicat 查看 MongoDB 密碼是不可能的,因為它以哈希值形式存儲。取回丟失密碼的方法:1. 重置密碼;2. 檢查配置文件(可能包含哈希值);3. 檢查代碼(可能硬編碼密碼)。

mysql workbench 可以連接到 mariadb 嗎 mysql workbench 可以連接到 mariadb 嗎 Apr 08, 2025 pm 02:33 PM

MySQL Workbench 可以連接 MariaDB,前提是配置正確。首先選擇 "MariaDB" 作為連接器類型。在連接配置中,正確設置 HOST、PORT、USER、PASSWORD 和 DATABASE。測試連接時,檢查 MariaDB 服務是否啟動,用戶名和密碼是否正確,端口號是否正確,防火牆是否允許連接,以及數據庫是否存在。高級用法中,使用連接池技術優化性能。常見錯誤包括權限不足、網絡連接問題等,調試錯誤時仔細分析錯誤信息和使用調試工具。優化網絡配置可以提升性能

mysql 需要服務器嗎 mysql 需要服務器嗎 Apr 08, 2025 pm 02:12 PM

對於生產環境,通常需要一台服務器來運行 MySQL,原因包括性能、可靠性、安全性和可擴展性。服務器通常擁有更強大的硬件、冗餘配置和更嚴格的安全措施。對於小型、低負載應用,可在本地機器運行 MySQL,但需謹慎考慮資源消耗、安全風險和維護成本。如需更高的可靠性和安全性,應將 MySQL 部署到雲服務器或其他服務器上。選擇合適的服務器配置需要根據應用負載和數據量進行評估。

See all articles