首頁 > 後端開發 > Python教學 > Python中的神經網路實例

Python中的神經網路實例

WBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWBOYWB
發布: 2023-06-10 13:21:45
原創
2400 人瀏覽過

Python一向以其簡捷、靈活的語法和強大的生態系統和庫被廣泛使用和喜愛,其中包括科學計算和機器學習這樣的領域。神經網路在機器學習領域有著至關重要的作用,可用於電腦視覺、自然語言處理、推薦系統等多個領域。本文將介紹Python中的神經網絡,並給予一些例子。

什麼是神經網路

神經網路是一種深度學習的模型,具有模擬動物神經系統的特性。神經網路由多個神經元組成,每個神經元相當於一個函數,其輸入是來自其他神經元的輸出,透過激活函數進行處理後產生輸出。神經網路利用反向傳播演算法不斷調整權重和偏置,使模型能夠更好地適應數據並進行預測或分類。

TensorFlow

TensorFlow是Google推出的一款流行的深度學習框架,用於建立神經網路和其他機器學習演算法。 TensorFlow最初是為內部Google研究人員開發的,在開源之後,它迅速成為了最受歡迎的深度學習框架之一。

在TensorFlow中,我們可以使用以下步驟來建立神經網路:

  1. #準備資料集:必須將資料集分為訓練集和測試集兩部分。訓練集是用來訓練模型的,而測試集是用來測試模型的準確性的。
  2. 建立神經網路:可以使用Python編寫實作神經網路。可以使用TensorFlow API建構神經網路模型。
  3. 訓練模型:對神經網路模型進行訓練,以預測新資料的輸出。可以透過使用隨機梯度下降演算法並進行反向傳播,並不斷更新模型的權重和偏壓。
  4. 測試模型:使用測試集對模型進行測試,以確定其準確性。可以使用不同的指標來評估模型的效果,例如精確度、召回率和F1得分。

現在,我們將介紹兩個使用TensorFlow實作的神經網路範例。

神經網路實例1:手寫數字辨識

手寫數字辨識是電腦視覺領域的重要問題,神經網路在這個問題上取得了很好的效果。在TensorFlow中,可以使用MNIST資料集來訓練神經網絡,該資料集包含60000張28x28的灰階影像和對應的標籤。

首先,我們要安裝TensorFlow和NumPy函式庫。以下是手寫數字辨識的完整程式碼:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

#创建模型
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)

#定义损失函数和优化器
y_actual = tf.placeholder(tf.float32, [None, 10])
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_actual, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

#初始化变量
init = tf.global_variables_initializer()

#训练模型
sess = tf.Session()
sess.run(init)
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_actual: batch_ys})

#评估模型
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_actual,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_actual: mnist.test.labels}))
登入後複製

在這個實例中,我們先準備了資料集MNIST,然後建立了一個包含784個輸入和10個輸出的簡單神經網路模型。接下來,我們定義了損失函數和最佳化器,將訓練資料輸入模型進行訓練。最後,我們對測試數據進行測試評估,並獲得了準確性為92.3%的結果。

神經網路實例2:垃圾郵件過濾器

現在幾乎每個人都使用郵件系統,但所有人都面臨垃圾郵件問題。垃圾郵件過濾器是一種能夠檢查一封電子郵件是否為垃圾郵件的程式。讓我們看看如何使用神經網路來建立一個垃圾郵件過濾器。

首先,我們需要準備一個垃圾郵件資料集,包括已標記為垃圾郵件和非垃圾郵件的郵件。請注意,在建立垃圾郵件過濾器時,將有兩個類別的郵件:非垃圾郵件和垃圾郵件。

以下是垃圾郵件過濾器的完整程式碼:

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

#读取数据集
data = pd.read_csv('spam.csv')
data = data.drop(['Unnamed: 2', 'Unnamed: 3', 'Unnamed: 4'], axis=1)

#转换标签
data['v1'] = data['v1'].map({'ham': 0, 'spam': 1})

#划分数据集
X_train, X_test, y_train, y_test = train_test_split(data['v2'], data['v1'], test_size=0.33, random_state=42)

#创建神经网络模型
max_words = 1000
tokenize = tf.keras.preprocessing.text.Tokenizer(num_words=max_words, char_level=False)
tokenize.fit_on_texts(X_train)

x_train = tokenize.texts_to_matrix(X_train)
x_test = tokenize.texts_to_matrix(X_test)

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(512, input_shape=(max_words,), activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(256, activation='sigmoid'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

#训练模型
model.fit(x_train, y_train,
          batch_size=32,
          epochs=5,
          validation_data=(x_test, y_test))

#评估模型
y_predict = model.predict(x_test)
print("Accuracy:", accuracy_score(y_test, y_predict.round()))
登入後複製

在這個實例中,我們使用的是sklearn的train_test_split()方法用於劃分資料集,然後使用Keras庫中的文字預處理工具將資料集轉換為矩陣(one-hot編碼)。接下來,我們使用Sequential來聲明神經元並設定其參數。最後,我們使用訓練後的模型對測試資料進行預測,並評估得到了一個準確度為98.02%的結果。

結論

Python中的神經網路是一種功能強大的技術,可用於多種應用,例如圖像識別,垃圾郵件過濾器等。使用TensorFlow,我們可以輕鬆地建立、訓練和測試神經網路模型,並獲得令人滿意的結果。隨著人們對機器學習需求的增長,神經網路技術將成為更重要的工具,在未來的應用情境中得到更廣泛的應用。

以上是Python中的神經網路實例的詳細內容。更多資訊請關注PHP中文網其他相關文章!

相關標籤:
來源:php.cn
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
最新問題
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板