ホームページ > バックエンド開発 > Python チュートリアル > Tensorflow 分類子プロジェクトのカスタム データを読み取る方法の紹介 (コード例)

Tensorflow 分類子プロジェクトのカスタム データを読み取る方法の紹介 (コード例)

不言
リリース: 2019-02-11 10:28:55
転載
2602 人が閲覧しました

この記事では、Tensorflow 分類子プロジェクトのカスタム データを読み取る方法 (コード例) を紹介します。これには特定の参考値があります。必要な友人はそれを参照できます。お役に立てば幸いです。あなた。

Tensorflow 分類子プロジェクトのカスタム データの読み取り

Tensorflow 公式 Web サイトのデモに従って分類子プロジェクトのコードを入力した後、操作は成功しました。悪い。しかし、最終的には自分でデータをトレーニングする必要があるため、カスタム データを読み込む準備をしようとしましたが、fashion_mnist.load_data() は詳細な読み込みプロセスなしでデモに登場するだけでした。読み取りプロセス。ここに記録されます。

まず、使用する必要があるモジュールについて言及します:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
ログイン後にコピー

画像分類子プロジェクト。まず、処理する画像の解像度を決定します。ここでの例は 30 ピクセルです:

IMG_SIZE_X = 30
IMG_SIZE_Y = 30
ログイン後にコピー

次に、写真のディレクトリを決定します:

image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
# 你也可以使用相对路径的方式
# image_path =os.path.join(path, "set")
ログイン後にコピー

ディレクトリの下の構造は次のとおりです:

Tensorflow 分類子プロジェクトのカスタム データを読み取る方法の紹介 (コード例)

対応する label.txt は次のとおりです。次のように:

动漫
风景
美女
物语
樱花
ログイン後にコピー

Next は次のように label.txt に接続されます:

label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))
ログイン後にコピー

わかりやすくするために、numpy のloadtxt 関数を直接使用して直接ロードします。

その後、画像データを正式に加工し、中にコメントを書きます:

re_load = False
re_build = False
# re_load = True
re_build = True

data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)

count = 0

# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
if not os.path.exists(data_path) or re_load:
    labels = []
    images = []
    print('Handle images')
    # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
    for index, name in enumerate(class_names):
        # 这里是拼接后的子目录path
        classpath = os.path.join(image_path, name)
        # 先判断一下是否是目录
        if not os.path.isdir(classpath):
            continue
        # limit是测试时候用的这里可以去除
        limit = 0
        for image_name in os.listdir(classpath):
            if limit >= max_size:
                break
            # 这里是拼接后的待处理的图片path
            imagepath = os.path.join(classpath, image_name)
            count = count + 1
            limit = limit + 1
            # 利用Image打开图片
            img = Image.open(imagepath)
            # 缩放到你最初确定要处理的图片分辨率大小
            img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))
            # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
            img = img.convert("L")
            # 转为numpy数组
            img = np.array(img)
            # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
            img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))
            # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
            labels.append([index])
            # 添加到images中,最后统一处理
            images.append(img)
            # 循环中一些状态的输出,可以去除
            print("{} class: {} {} limit: {} {}"
                  .format(count, index + 1, class_names[index], limit, imagepath))
    # 最后一次性将images和labels都转换成numpy数组
    npy_data = np.array(images)
    npy_labels = np.array(labels)
    # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
    np.savez(data_path, x=npy_data, y=npy_labels)
    print("Save images by npz")
else:
    # 如果存在序列化号的数据,便直接读取,提高速度
    npy_data = np.load(data_path)["x"]
    npy_labels = np.load(data_path)["y"]
    print("Load images by npz")
image_data = npy_data
labels_data = npy_labels
ログイン後にコピー

ここまでで、元データの加工と前処理は完了です。デモと同様に、 fashion_mnist.load_data() 返される結果は同じです。コードは次のとおりです:

# 最后一步就是将原始数据分成训练数据和测试数据
train_images, test_images, train_labels, test_labels = \
    train_test_split(image_data, labels_data, test_size=0.2, random_state=6)
ログイン後にコピー

関連情報を出力するメソッドもここに添付されています:

print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")

print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")
ログイン後にコピー

その後は正規化することを忘れないでください:

print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0
ログイン後にコピー

最後に、関連情報を印刷する方法が添付されています: データを定義する完全なコード:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
re_load = False
re_build = False
# re_load = True
re_build = True
epochs = 50
batch_size = 5
count = 0
max_size = 2000000000
IMG_SIZE_X = 30
IMG_SIZE_Y = 30
np.random.seed(9277)
image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)
label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))
print('Load class names')
if not os.path.exists(data_path) or re_load:
    labels = []
    images = []
    print('Handle images')
    for index, name in enumerate(class_names):
        classpath = os.path.join(image_path, name)
        if not os.path.isdir(classpath):
            continue
        limit = 0
        for image_name in os.listdir(classpath):
            if limit >= max_size:
                break
            imagepath = os.path.join(classpath, image_name)
            count = count + 1
            limit = limit + 1
            img = Image.open(imagepath)
            img = img.resize((30, 30))
            img = img.convert("L")
            img = np.array(img)
            img = np.reshape(img, (1, 30, 30))
            # img = skimage.io.imread(imagepath, as_grey=True)
            # if img.shape[2] != 3:
            #     print("{} shape is {}".format(image_name, img.shape))
            #     continue
            # data = transform.resize(img, (IMG_SIZE_X, IMG_SIZE_Y))
            labels.append([index])
            images.append(img)
            print("{} class: {} {} limit: {} {}"
                  .format(count, index + 1, class_names[index], limit, imagepath))
    npy_data = np.array(images)
    npy_labels = np.array(labels)
    np.savez(data_path, x=npy_data, y=npy_labels)
    print("Save images by npz")
else:
    npy_data = np.load(data_path)["x"]
    npy_labels = np.load(data_path)["y"]
    print("Load images by npz")
image_data = npy_data
labels_data = npy_labels
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")
train_images, test_images, train_labels, test_labels = \
    train_test_split(image_data, labels_data, test_size=0.2, random_state=6)
print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

# 归一化
# 我们将这些值缩小到 0 到 1 之间,然后将其馈送到神经网络模型。为此,将图像组件的数据类型从整数转换为浮点数,然后除以 255。以下是预处理图像的函数:
# 务必要以相同的方式对训练集和测试集进行预处理:
print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0
ログイン後にコピー

以上がTensorflow 分類子プロジェクトのカスタム データを読み取る方法の紹介 (コード例)の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

関連ラベル:
ソース:segmentfault.com
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート