python如何實作決策樹演算法?(程式碼)
這篇文章帶給大家的內容是關於python如何實現決策樹演算法?(程式碼),有一定的參考價值,有需要的朋友可以參考一下,希望對你有幫助。
資料描述
每個資料項目儲存在清單中,最後一列儲存結果
多條資料項目形成資料集
data=[[d1,d2,d3...dn,result], [d1,d2,d3...dn,result], . . [d1,d2,d3...dn,result]]
決策樹資料結構
class DecisionNode: '''决策树节点 ''' def __init__(self,col=-1,value=None,results=None,tb=None,fb=None): '''初始化决策树节点 args: col -- 按数据集的col列划分数据集 value -- 以value作为划分col列的参照 result -- 只有叶子节点有,代表最终划分出的子数据集结果统计信息。{‘结果’:结果出现次数} rb,fb -- 代表左右子树 ''' self.col=col self.value=value self.results=results self.tb=tb self.fb=fb
決策樹分類的最終結果是將資料項目劃分出了若干子集,其中每個子集的結果都一樣,所以這裡採用{'結果':結果出現次數}的方式表達每個子集
def pideset(rows,column,value): '''依据数据集rows的column列的值,判断其与参考值value的关系对数据集进行拆分 返回两个数据集 ''' split_function=None #value是数值类型 if isinstance(value,int) or isinstance(value,float): #定义lambda函数当row[column]>=value时返回true split_function=lambda row:row[column]>=value #value是字符类型 else: #定义lambda函数当row[column]==value时返回true split_function=lambda row:row[column]==value #将数据集拆分成两个 set1=[row for row in rows if split_function(row)] set2=[row for row in rows if not split_function(row)] #返回两个数据集 return (set1,set2) def uniquecounts(rows): '''计算数据集rows中有几种最终结果,计算结果出现次数,返回一个字典 ''' results={} for row in rows: r=row[len(row)-1] if r not in results: results[r]=0 results[r]+=1 return results def giniimpurity(rows): '''返回rows数据集的基尼不纯度 ''' total=len(rows) counts=uniquecounts(rows) imp=0 for k1 in counts: p1=float(counts[k1])/total for k2 in counts: if k1==k2: continue p2=float(counts[k2])/total imp+=p1*p2 return imp def entropy(rows): '''返回rows数据集的熵 ''' from math import log log2=lambda x:log(x)/log(2) results=uniquecounts(rows) ent=0.0 for r in results.keys(): p=float(results[r])/len(rows) ent=ent-p*log2(p) return ent def build_tree(rows,scoref=entropy): '''构造决策树 ''' if len(rows)==0: return DecisionNode() current_score=scoref(rows) # 最佳信息增益 best_gain=0.0 # best_criteria=None #最佳划分 best_sets=None column_count=len(rows[0])-1 #遍历数据集的列,确定分割顺序 for col in range(0,column_count): column_values={} # 构造字典 for row in rows: column_values[row[col]]=1 for value in column_values.keys(): (set1,set2)=pideset(rows,col,value) p=float(len(set1))/len(rows) # 计算信息增益 gain=current_score-p*scoref(set1)-(1-p)*scoref(set2) if gain>best_gain and len(set1)>0 and len(set2)>0: best_gain=gain best_criteria=(col,value) best_sets=(set1,set2) # 如果划分的两个数据集熵小于原数据集,进一步划分它们 if best_gain>0: trueBranch=build_tree(best_sets[0]) falseBranch=build_tree(best_sets[1]) return DecisionNode(col=best_criteria[0],value=best_criteria[1], tb=trueBranch,fb=falseBranch) # 如果划分的两个数据集熵不小于原数据集,停止划分 else: return DecisionNode(results=uniquecounts(rows)) def print_tree(tree,indent=''): if tree.results!=None: print(str(tree.results)) else: print(str(tree.col)+':'+str(tree.value)+'? ') print(indent+'T->',end='') print_tree(tree.tb,indent+' ') print(indent+'F->',end='') print_tree(tree.fb,indent+' ') def getwidth(tree): if tree.tb==None and tree.fb==None: return 1 return getwidth(tree.tb)+getwidth(tree.fb) def getdepth(tree): if tree.tb==None and tree.fb==None: return 0 return max(getdepth(tree.tb),getdepth(tree.fb))+1 def drawtree(tree,jpeg='tree.jpg'): w=getwidth(tree)*100 h=getdepth(tree)*100+120 img=Image.new('RGB',(w,h),(255,255,255)) draw=ImageDraw.Draw(img) drawnode(draw,tree,w/2,20) img.save(jpeg,'JPEG') def drawnode(draw,tree,x,y): if tree.results==None: # Get the width of each branch w1=getwidth(tree.fb)*100 w2=getwidth(tree.tb)*100 # Determine the total space required by this node left=x-(w1+w2)/2 right=x+(w1+w2)/2 # Draw the condition string draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0)) # Draw links to the branches draw.line((x,y,left+w1/2,y+100),fill=(255,0,0)) draw.line((x,y,right-w2/2,y+100),fill=(255,0,0)) # Draw the branch nodes drawnode(draw,tree.fb,left+w1/2,y+100) drawnode(draw,tree.tb,right-w2/2,y+100) else: txt=' \n'.join(['%s:%d'%v for v in tree.results.items()]) draw.text((x-20,y),txt,(0,0,0))
對測試資料進行分類(附帶處理缺失資料)
def mdclassify(observation,tree): '''对缺失数据进行分类 args: observation -- 发生信息缺失的数据项 tree -- 训练完成的决策树 返回代表该分类的结果字典 ''' # 判断数据是否到达叶节点 if tree.results!=None: # 已经到达叶节点,返回结果result return tree.results else: # 对数据项的col列进行分析 v=observation[tree.col] # 若col列数据缺失 if v==None: #对tree的左右子树分别使用mdclassify,tr是左子树得到的结果字典,fr是右子树得到的结果字典 tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb) # 分别以结果占总数比例计算得到左右子树的权重 tcount=sum(tr.values()) fcount=sum(fr.values()) tw=float(tcount)/(tcount+fcount) fw=float(fcount)/(tcount+fcount) result={} # 计算左右子树的加权平均 for k,v in tr.items(): result[k]=v*tw for k,v in fr.items(): # fr的结果k有可能并不在tr中,在result中初始化k if k not in result: result[k]=0 # fr的结果累加到result中 result[k]+=v*fw return result # col列没有缺失,继续沿决策树分类 else: if isinstance(v,int) or isinstance(v,float): if v>=tree.value: branch=tree.tb else: branch=tree.fb else: if v==tree.value: branch=tree.tb else: branch=tree.fb return mdclassify(observation,branch) tree=build_tree(my_data) print(mdclassify(['google',None,'yes',None],tree)) print(mdclassify(['google','France',None,None],tree))
決策樹剪枝
def prune(tree,mingain): '''对决策树进行剪枝 args: tree -- 决策树 mingain -- 最小信息增益 返回 ''' # 修剪非叶节点 if tree.tb.results==None: prune(tree.tb,mingain) if tree.fb.results==None: prune(tree.fb,mingain) #合并两个叶子节点 if tree.tb.results!=None and tree.fb.results!=None: tb,fb=[],[] for v,c in tree.tb.results.items(): tb+=[[v]]*c for v,c in tree.fb.results.items(): fb+=[[v]]*c #计算熵减少情况 delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2) #熵的增加量小于mingain,可以合并分支 if delta<mingain: tree.tb,tree.fb=None,None tree.results=uniquecounts(tb+fb)
以上是python如何實作決策樹演算法?(程式碼)的詳細內容。更多資訊請關注PHP中文網其他相關文章!

熱AI工具

Undresser.AI Undress
人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover
用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool
免費脫衣圖片

Clothoff.io
AI脫衣器

AI Hentai Generator
免費產生 AI 無盡。

熱門文章

熱工具

記事本++7.3.1
好用且免費的程式碼編輯器

SublimeText3漢化版
中文版,非常好用

禪工作室 13.0.1
強大的PHP整合開發環境

Dreamweaver CS6
視覺化網頁開發工具

SublimeText3 Mac版
神級程式碼編輯軟體(SublimeText3)

熱門話題

MySQL 有免費的社區版和收費的企業版。社區版可免費使用和修改,但支持有限,適合穩定性要求不高、技術能力強的應用。企業版提供全面商業支持,適合需要穩定可靠、高性能數據庫且願意為支持買單的應用。選擇版本時考慮的因素包括應用關鍵性、預算和技術技能。沒有完美的選項,只有最合適的方案,需根據具體情況謹慎選擇。

文章介紹了MySQL數據庫的上手操作。首先,需安裝MySQL客戶端,如MySQLWorkbench或命令行客戶端。 1.使用mysql-uroot-p命令連接服務器,並使用root賬戶密碼登錄;2.使用CREATEDATABASE創建數據庫,USE選擇數據庫;3.使用CREATETABLE創建表,定義字段及數據類型;4.使用INSERTINTO插入數據,SELECT查詢數據,UPDATE更新數據,DELETE刪除數據。熟練掌握這些步驟,並學習處理常見問題和優化數據庫性能,才能高效使用MySQL。

MySQL 可在無需網絡連接的情況下運行,進行基本的數據存儲和管理。但是,對於與其他系統交互、遠程訪問或使用高級功能(如復制和集群)的情況,則需要網絡連接。此外,安全措施(如防火牆)、性能優化(選擇合適的網絡連接)和數據備份對於連接到互聯網的 MySQL 數據庫至關重要。

MySQL數據庫性能優化指南在資源密集型應用中,MySQL數據庫扮演著至關重要的角色,負責管理海量事務。然而,隨著應用規模的擴大,數據庫性能瓶頸往往成為製約因素。本文將探討一系列行之有效的MySQL性能優化策略,確保您的應用在高負載下依然保持高效響應。我們將結合實際案例,深入講解索引、查詢優化、數據庫設計以及緩存等關鍵技術。 1.數據庫架構設計優化合理的數據庫架構是MySQL性能優化的基石。以下是一些核心原則:選擇合適的數據類型選擇最小的、符合需求的數據類型,既能節省存儲空間,又能提升數據處理速度

HadiDB:輕量級、高水平可擴展的Python數據庫HadiDB(hadidb)是一個用Python編寫的輕量級數據庫,具備高度水平的可擴展性。安裝HadiDB使用pip安裝:pipinstallhadidb用戶管理創建用戶:createuser()方法創建一個新用戶。 authentication()方法驗證用戶身份。 fromhadidb.operationimportuseruser_obj=user("admin","admin")user_obj.

直接通過 Navicat 查看 MongoDB 密碼是不可能的,因為它以哈希值形式存儲。取回丟失密碼的方法:1. 重置密碼;2. 檢查配置文件(可能包含哈希值);3. 檢查代碼(可能硬編碼密碼)。

MySQL Workbench 可以連接 MariaDB,前提是配置正確。首先選擇 "MariaDB" 作為連接器類型。在連接配置中,正確設置 HOST、PORT、USER、PASSWORD 和 DATABASE。測試連接時,檢查 MariaDB 服務是否啟動,用戶名和密碼是否正確,端口號是否正確,防火牆是否允許連接,以及數據庫是否存在。高級用法中,使用連接池技術優化性能。常見錯誤包括權限不足、網絡連接問題等,調試錯誤時仔細分析錯誤信息和使用調試工具。優化網絡配置可以提升性能

對於生產環境,通常需要一台服務器來運行 MySQL,原因包括性能、可靠性、安全性和可擴展性。服務器通常擁有更強大的硬件、冗餘配置和更嚴格的安全措施。對於小型、低負載應用,可在本地機器運行 MySQL,但需謹慎考慮資源消耗、安全風險和維護成本。如需更高的可靠性和安全性,應將 MySQL 部署到雲服務器或其他服務器上。選擇合適的服務器配置需要根據應用負載和數據量進行評估。
