目錄
一、论文概述
1.分类嵌入(Categorical Embeddings)
2.上下文嵌入(Contextual Embeddings)
3.TabTransformer架构
4.转换器
5.多头注意力机制(Multi-head-attention)
6.簡短回顧
7.試驗結果展示
二、建立我們自己的範例程式
1.資料預處理
2.建構TabTransformer模型
3.評價
三、结论
译者介绍
首頁 科技週邊 人工智慧 TabTransformer轉換器提升多層感知機效能深度解析

TabTransformer轉換器提升多層感知機效能深度解析

Apr 17, 2023 pm 03:25 PM
機器學習 轉換器 nlp

​如今,转换器(Transformers)成为大多数先进的自然语言处理(NLP)和计算机视觉(CV)体系结构中的关键模块。然而,表格式数据领域仍然主要以梯度提升决策树(GBDT)算法为主导。于是,有人试图弥合这一差距。其中,第一篇基于转换器的表格数据建模论文是由Huang等人于2020年发表的论文《TabTransformer:使用上下文嵌入的表格数据建模》。

本文旨在提供该论文内容的基本展示,同时将深入探讨TabTransformer模型的实现细节,并向您展示如何针对我们自己的数据来具体使用TabTransformer。

一、论文概述

上述论文的主要思想是,如果使用转换器将常规的分类嵌入转换为上下文嵌入,那么,常规的多层感知器(MLP)的性能将会得到显著提高。接下来,让我们更为深入地理解这一描述。

1.分类嵌入(Categorical Embeddings)

在深度学习模型中,使用分类特征的经典方法是训练其嵌入性。这意味着,每个类别值都有一个唯一的密集型向量表示,并且可以传递给下一层。例如,由下图您可以看到,每个分类特征都使用一个四维数组表示。然后,这些嵌入与数字特征串联,并用作MLP的输入。

TabTransformer轉換器提升多層感知機效能深度解析

带有分类嵌入的MLP

2.上下文嵌入(Contextual Embeddings)

论文作者认为,分类嵌入缺乏上下文含义,即它们并没有对分类变量之间的任何交互和关系信息进行编码。为了将嵌入内容更加具体化,有人建议使用NLP领域当前所使用的转换器来实现这一目的。

TabTransformer轉換器提升多層感知機效能深度解析

TabTransformer转换器中的上下文嵌入

为了以可视化方式形象地展示上述想法,我们不妨考虑下面这个训练后得到的上下文嵌入图像。其中,突出显示了两个分类特征:关系(黑色)和婚姻状况(蓝色)。这些特征是相关的;所以,“已婚(Married)”、“丈夫(Husband)”和“妻子(Wife)”的值应该在向量空间中彼此接近,即使它们来自不同的变量。

TabTransformer轉換器提升多層感知機效能深度解析

经训练后的TabTransformer转换器嵌入结果示例

通过上图中经过训练的上下文嵌入结果,我们可以看到,“已婚(Married)”的婚姻状况更接近“丈夫(Husband)”和“妻子(Wife)”的关系水平,而“未结婚(non-married)”的分类值则来自右侧的单独数据簇。这种类型的上下文使这样的嵌入更加有用,而使用简单形式的类别嵌入技术是不可能实现这种效果的。

3.TabTransformer架构

为了达到上述目的,论文作者提出了以下架构:

TabTransformer轉換器提升多層感知機效能深度解析

TabTransformer转换器架构示意图

(摘取自Huang等人2020年发表的论文)

我们可以将此体系结构分解为5个步骤:

  • 标准化数字特征并向前传递
  • 嵌入分类特征
  • 嵌入经过N次转换器块处理,以便获得上下文嵌入
  • 把上下文分类嵌入与数字特征进行串联
  • 通过MLP进行串联获得所需的预测

虽然模型架构非常简单,但论文作者表示,添加转换器层可以显著提高计算性能。当然,所有的“魔术”发生在这些转换器块内部;所以,接下来让我们更加详细地研究一下其中的实现过程。

4.转换器

TabTransformer轉換器提升多層感知機效能深度解析

转换器(Transformer)架构示意

(选自Vaswani等人于2017年发表的论文)

您可能以前见过转换器架构,但为了快速介绍起见,请记住该转换器是由编码器和解码器两部分组成(见上图)。对于TabTransformer,我们只关心将输入的嵌入内容上下文化的编码器部分(解码器部分将这些嵌入内容转换为最终输出结果)。但它到底是如何做到的呢?答案是——多头注意力机制。

5.多头注意力机制(Multi-head-attention)

引用我最喜歡的關於注意力機制的文章的描述,是這樣的:

#「自我關注(self attention)背後的關鍵概念是,這種機制允許神經網路學習如何在輸入序列的各個片段之間以最好的路由方案進行資訊調度。」

換句話說,自我關注(self-attention)有助於模型找出在表示某個單字/類別時,輸入的哪些部分更重要,哪些部分相對不重要。為此,我強烈建議您閱讀一下上面引用的這篇文章,以便對自我關注為什麼如此有效有一個更直觀的理解。

TabTransformer轉換器提升多層感知機效能深度解析

多頭注意力機制

(選自Vaswani等人於2017年發表的論文)

#注意力是透過3個學習過的矩陣來計算的-Q、K和V,它們代表查詢(Query)、鍵(Key)和值(Value)。首先,我們將矩陣Q和K相乘得到注意力矩陣。此矩陣被縮放並通過softmax層傳遞。然後,我們將其乘以V矩陣,得出最終值。為了更直觀地理解起見,請考慮下面的示意圖,它顯示了我們如何使用矩陣Q、K和V實現從輸入嵌入轉換到上下文嵌入。

TabTransformer轉換器提升多層感知機效能深度解析

自我關注流程視覺化

透過重複流程h次(使用不同的Q、K 、V矩陣),我們就能夠得到多個脈絡嵌入,它們形成我們最終的多頭注意力。

6.簡短回顧

讓我們總結一下上面所介紹的內容:

  • 簡單的分類嵌入不包含上下文訊息
  • 透過轉換器編碼器傳遞分類嵌入,我們就能夠將嵌入上下文化
  • 轉換器部分能夠將嵌入上下文化,因為它使用了多頭注意力機制
  • 多頭注意力機制在編碼變數時使用矩陣Q、K和V來尋找有用的交互作用和相關性資訊
  • 在TabTransformer中,被上下文化的嵌入與數位輸入相串聯,並透過一個簡單的MLP輸出預測

#雖然TabTransformer背後的想法很簡單,但您可能需要一些時間才能掌握注意力機制。因此,我強烈建議您重新閱讀以上解釋。如果您感到有些迷茫,請認真閱讀本文中所有建議的連結相關內容。我保證,做到這些後,您就不難搞明白注意力機制的原理了。

7.試驗結果展示

TabTransformer轉換器提升多層感知機效能深度解析

#結果資料(選自Huang等人2020年發表的論文)

根據報告的結果,TabTransformer轉換器優於所有其他深度學習表格模型,此外,它接近GBDT的性能水平,這非常令人鼓舞。該模型對缺失資料和雜訊資料也相對穩健,並且在半監督環境下優於其他模型。然而,這些資料集顯然不是詳盡無遺的,正如以後發表的一些相關論文所證實的那樣,仍有很大的改進空間。

二、建立我們自己的範例程式

#現在,讓我們最終來確定如何將模型應用於我們自己的資料。接下來的範例數據取自著名的Tabular Playground Kaggle比賽。為了方便使用TabTransformer轉換器,我建立了一個tabtransformertf套件。它可以使用以下pip命令進行安裝:

pip install tabtransformertf
登入後複製

並允許我們使用該模型,而無需進行大量預處理。

1.資料預處理

第一步是設定適當的資料類型,並將我們的訓練和驗證資料轉換為TF數據集。其中,前面安裝的軟體包中就提供了一個很好的實用程式可以做到這一點。

from tabtransformertf.utils.preprocessing import df_to_dataset, build_categorical_prep

# 设置数据类型
train_data[CATEGORICAL_FEATURES] = train_data[CATEGORICAL_FEATURES].astype(str)
val_data[CATEGORICAL_FEATURES] = val_data[CATEGORICAL_FEATURES].astype(str)

train_data[NUMERIC_FEATURES] = train_data[NUMERIC_FEATURES].astype(float)
val_data[NUMERIC_FEATURES] = val_data[NUMERIC_FEATURES].astype(float)

# 转换成TF数据集
train_dataset = df_to_dataset(train_data[FEATURES + [LABEL]], LABEL, batch_size=1024)
val_dataset = df_to_dataset(val_data[FEATURES + [LABEL]], LABEL, shuffle=False, batch_size=1024)
登入後複製

下一步是為分類資料準備預處理層。該分類資料稍後將傳遞給我們的主模型。

from tabtransformertf.utils.preprocessing import build_categorical_prep

category_prep_layers = build_categorical_prep(train_data, CATEGORICAL_FEATURES)

# 输出结果是一个字典结构,其中键部分是特征名称,值部分是StringLookup层
# category_prep_layers ->
# {'product_code': <keras.layers.preprocessing.string_lookup.StringLookup at 0x7f05d28ee4e0>,
#'attribute_0': <keras.layers.preprocessing.string_lookup.StringLookup at 0x7f05ca4fb908>,
#'attribute_1': <keras.layers.preprocessing.string_lookup.StringLookup at 0x7f05ca4da5f8>}
登入後複製

這就是預處理!現在,我們可以開始建立模型了。

2.建構TabTransformer模型

#初始化模型很容易。其中,有幾個參數需要指定,但最重要的幾個參數是:embeding_dim、depth和heads。所有參數都是在超參數調整後選擇的。

from tabtransformertf.models.tabtransformer import TabTransformer

tabtransformer = TabTransformer(
numerical_features = NUMERIC_FEATURES,# 带有数字特征名称的列表
categorical_features = CATEGORICAL_FEATURES, # 带有分类特征名称的列表
categorical_lookup=category_prep_layers, # 带StringLookup层的Dict
numerical_discretisers=None,# None代表我们只是简单地传递数字特征
embedding_dim=32,# 嵌入维数
out_dim=1,# Dimensionality of output (binary task)
out_activatinotallow='sigmoid',# 输出层激活
depth=4,# 转换器块层的个数
heads=8,# 转换器块中注意力头的个数
attn_dropout=0.1,# 在转换器块中的丢弃率
ff_dropout=0.1,# 在最后MLP中的丢弃率
mlp_hidden_factors=[2, 4],# 我们为每一层划分最终嵌入的因子
use_column_embedding=True,#如果我们想使用列嵌入,设置此项为真
)

# 模型运行中摘要输出:
# 总参数个数: 1,778,884
# 可训练的参数个数: 1,774,064
# 不可训练的参数个数: 4,820
登入後複製

模型初始化後,我們可以像其他Keras模型一樣安裝它。訓練參數也可以調整,所以可以隨意調整學習速度和提前停止。

LEARNING_RATE = 0.0001
WEIGHT_DECAY = 0.0001
NUM_EPOCHS = 1000

optimizer = tfa.optimizers.AdamW(
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

tabtransformer.compile(
optimizer = optimizer,
loss = tf.keras.losses.BinaryCrossentropy(),
metrics= [tf.keras.metrics.AUC(name="PR AUC", curve='PR')],
)

out_file = './tabTransformerBasic'
checkpoint = ModelCheckpoint(
out_file, mnotallow="val_loss", verbose=1, save_best_notallow=True, mode="min"
)
early = EarlyStopping(mnotallow="val_loss", mode="min", patience=10, restore_best_weights=True)
callback_list = [checkpoint, early]

history = tabtransformer.fit(
train_dataset,
epochs=NUM_EPOCHS,
validation_data=val_dataset,
callbacks=callback_list
)
登入後複製

3.評價

競賽中最關鍵的指標是ROC AUC。因此,讓我們將其與PR AUC指標一起輸出來評估模型的表現。

val_preds = tabtransformer.predict(val_dataset)

print(f"PR AUC: {average_precision_score(val_data['isFraud'], val_preds.ravel())}")
print(f"ROC AUC: {roc_auc_score(val_data['isFraud'], val_preds.ravel())}")

# PR AUC: 0.26
# ROC AUC: 0.58
登入後複製

您也可以自己给测试集评分,然后将结果值提交给Kaggle官方。我现在选择的这个解决方案使我跻身前35%,这并不坏,但也不太好。那么,为什么TabTransfromer在上述方案中表现不佳呢?可能有以下几个原因:

  • 数据集太小,而深度学习模型以需要大量数据著称
  • TabTransformer很容易在表格式数据示例领域出现过拟合
  • 没有足够的分类特征使模型有用

三、结论

本文探讨了TabTransformer背后的主要思想,并展示了如何使用Tabtransformertf包来具体应用此转换器。

归纳起来看,TabTransformer的确是一种有趣的体系结构,它在当时的表现明显优于大多数深度表格模型。它的主要优点是将分类嵌入语境化,从而增强其表达能力。它使用在分类特征上的多头注意力机制来实现这一点,而这是在表格数据领域使用转换器的第一个应用实例。

TabTransformer体系结构的一个明显缺点是,数字特征被简单地传递到最终的MLP层。因此,它们没有语境化,它们的价值也没有在分类嵌入中得到解释。在下一篇文章中,我将探讨如何修复此缺陷并进一步提高性能。

译者介绍

朱先忠,51CTO社区编辑,51CTO专家博客、讲师,潍坊一所高校计算机教师,自由编程界老兵一枚。

原文链接:https://towardsdatascience.com/transformers-for-tabular-data-tabtransformer-deep-dive-5fb2438da820?source=collection_home---------4----------------------------

以上是TabTransformer轉換器提升多層感知機效能深度解析的詳細內容。更多資訊請關注PHP中文網其他相關文章!

本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

Video Face Swap

Video Face Swap

使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱工具

記事本++7.3.1

記事本++7.3.1

好用且免費的程式碼編輯器

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用

禪工作室 13.0.1

禪工作室 13.0.1

強大的PHP整合開發環境

Dreamweaver CS6

Dreamweaver CS6

視覺化網頁開發工具

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

熱門話題

Java教學
1662
14
CakePHP 教程
1419
52
Laravel 教程
1311
25
PHP教程
1261
29
C# 教程
1234
24
一文帶您了解SHAP:機器學習的模型解釋 一文帶您了解SHAP:機器學習的模型解釋 Jun 01, 2024 am 10:58 AM

在機器學習和資料科學領域,模型的可解釋性一直是研究者和實踐者關注的焦點。隨著深度學習和整合方法等複雜模型的廣泛應用,理解模型的決策過程變得尤為重要。可解釋人工智慧(ExplainableAI|XAI)透過提高模型的透明度,幫助建立對機器學習模型的信任和信心。提高模型的透明度可以透過多種複雜模型的廣泛應用等方法來實現,以及用於解釋模型的決策過程。這些方法包括特徵重要性分析、模型預測區間估計、局部可解釋性演算法等。特徵重要性分析可以透過評估模型對輸入特徵的影響程度來解釋模型的決策過程。模型預測區間估計

透過學習曲線辨識過擬合和欠擬合 透過學習曲線辨識過擬合和欠擬合 Apr 29, 2024 pm 06:50 PM

本文將介紹如何透過學習曲線來有效辨識機器學習模型中的過度擬合和欠擬合。欠擬合和過擬合1、過擬合如果一個模型對資料進行了過度訓練,以至於它從中學習了噪聲,那麼這個模型就被稱為過擬合。過度擬合模型非常完美地學習了每一個例子,所以它會錯誤地分類一個看不見的/新的例子。對於一個過度擬合的模型,我們會得到一個完美/接近完美的訓練集分數和一個糟糕的驗證集/測試分數。略有修改:"過擬合的原因:用一個複雜的模型來解決一個簡單的問題,從資料中提取雜訊。因為小資料集作為訓練集可能無法代表所有資料的正確表示。"2、欠擬合如

人工智慧在太空探索和人居工程中的演變 人工智慧在太空探索和人居工程中的演變 Apr 29, 2024 pm 03:25 PM

1950年代,人工智慧(AI)誕生。當時研究人員發現機器可以執行類似人類的任務,例如思考。後來,在1960年代,美國國防部資助了人工智慧,並建立了實驗室進行進一步開發。研究人員發現人工智慧在許多領域都有用武之地,例如太空探索和極端環境中的生存。太空探索是對宇宙的研究,宇宙涵蓋了地球以外的整個宇宙空間。太空被歸類為極端環境,因為它的條件與地球不同。要在太空中生存,必須考慮許多因素,並採取預防措施。科學家和研究人員認為,探索太空並了解一切事物的現狀有助於理解宇宙的運作方式,並為潛在的環境危機

通透!機器學習各大模型原理的深度剖析! 通透!機器學習各大模型原理的深度剖析! Apr 12, 2024 pm 05:55 PM

通俗來說,機器學習模型是一種數學函數,它能夠將輸入資料映射到預測輸出。更具體地說,機器學習模型是一種透過學習訓練數據,來調整模型參數,以最小化預測輸出與真實標籤之間的誤差的數學函數。在機器學習中存在多種模型,例如邏輯迴歸模型、決策樹模型、支援向量機模型等,每種模型都有其適用的資料類型和問題類型。同時,不同模型之間存在著許多共通性,或者說有一條隱藏的模型演化的路徑。將聯結主義的感知機為例,透過增加感知機的隱藏層數量,我們可以將其轉化為深度神經網路。而對感知機加入核函數的話就可以轉換為SVM。這一

使用C++實現機器學習演算法:常見挑戰及解決方案 使用C++實現機器學習演算法:常見挑戰及解決方案 Jun 03, 2024 pm 01:25 PM

C++中機器學習演算法面臨的常見挑戰包括記憶體管理、多執行緒、效能最佳化和可維護性。解決方案包括使用智慧指標、現代線程庫、SIMD指令和第三方庫,並遵循程式碼風格指南和使用自動化工具。實作案例展示如何利用Eigen函式庫實現線性迴歸演算法,有效地管理記憶體和使用高效能矩陣操作。

你所不知道的機器學習五大學派 你所不知道的機器學習五大學派 Jun 05, 2024 pm 08:51 PM

機器學習是人工智慧的重要分支,它賦予電腦從數據中學習的能力,並能夠在無需明確編程的情況下改進自身能力。機器學習在各個領域都有廣泛的應用,從影像辨識和自然語言處理到推薦系統和詐欺偵測,它正在改變我們的生活方式。機器學習領域存在著多種不同的方法和理論,其中最具影響力的五種方法被稱為「機器學習五大派」。這五大派分別為符號派、聯結派、進化派、貝葉斯派和類推學派。 1.符號學派符號學(Symbolism),又稱符號主義,強調利用符號進行邏輯推理和表達知識。該學派認為學習是一種逆向演繹的過程,透過現有的

Flash Attention穩定嗎? Meta、哈佛發現其模型權重偏差呈現數量級波動 Flash Attention穩定嗎? Meta、哈佛發現其模型權重偏差呈現數量級波動 May 30, 2024 pm 01:24 PM

MetaFAIR聯合哈佛優化大規模機器學習時所產生的資料偏差,提供了新的研究架構。據所周知,大語言模型的訓練常常需要數月的時間,使用數百甚至上千個GPU。以LLaMA270B模型為例,其訓練總共需要1,720,320個GPU小時。由於這些工作負載的規模和複雜性,導致訓練大模型存在著獨特的系統性挑戰。最近,許多機構在訓練SOTA生成式AI模型時報告了訓練過程中的不穩定情況,它們通常以損失尖峰的形式出現,例如Google的PaLM模型訓練過程中出現了多達20次的損失尖峰。數值偏差是造成這種訓練不準確性的根因,

可解釋性人工智慧:解釋複雜的AI/ML模型 可解釋性人工智慧:解釋複雜的AI/ML模型 Jun 03, 2024 pm 10:08 PM

譯者|李睿審校|重樓人工智慧(AI)和機器學習(ML)模型如今變得越來越複雜,這些模型產生的產出是黑盒子-無法向利害關係人解釋。可解釋性人工智慧(XAI)致力於透過讓利害關係人理解這些模型的工作方式來解決這個問題,確保他們理解這些模型實際上是如何做出決策的,並確保人工智慧系統中的透明度、信任度和問責制來解決這個問題。本文探討了各種可解釋性人工智慧(XAI)技術,以闡明它們的基本原理。可解釋性人工智慧至關重要的幾個原因信任度和透明度:為了讓人工智慧系統被廣泛接受和信任,使用者需要了解決策是如何做出的

See all articles