ROBOFLOW - 使用 python 进行训练和测试

WBOY
发布: 2024-08-27 06:01:32
原创
932 人浏览过

Roboflow 是一个用于注释图像以用于对象检测 AI 的平台。

我将这个平台用于 C2SMR c2smr.fr,我的海上救援计算机视觉协会。

在本文中,我将向您展示如何使用该平台并使用 python 训练模型。

您可以在我的 github 上找到更多示例代码:https://github.com/C2SMR/Detector


I - 数据集

要创建数据集,请访问 https://app.roboflow.com/ 并开始注释您的图像,如下图所示。

在这个例子中,我绕道所有游泳者来预测他们在未来图像中的位置。
为了获得良好的结果,请裁剪所有游泳者并将边界框放置在对象后面以正确包围它。

ROBOFLOW - train & test with python

您已经可以使用公共 roboflow 数据集进行此检查 https://universe.roboflow.com/

二、培训

在训练阶段,你可以直接使用 roboflow,但到了第三次你就需要付费了,这就是为什么我向你展示如何使用笔记本电脑进行操作。

第一步是导入数据集。为此,您可以导入 Roboflow 库。

pip install roboflow
登录后复制

要创建模型,您需要使用 YOLO 算法,您可以使用 ultralytics 库导入该算法。

pip install ultralytics
登录后复制

在我的脚本中,我使用以下命令:

py train.py api-key project-workspace project-name project-version nb-epoch size_model
登录后复制

您必须获得:

  • 访问密钥
  • 工作空间
  • roboflow 项目名称
  • 项目数据集版本
  • 训练模型的纪元数
  • 神经网络大小

最初,脚本会下载 yolov8-obb​​.pt,这是带有锻炼前数据的默认 yolo 权重,以方便训练。

import sys
import os
import random
from roboflow import Roboflow
from ultralytics import YOLO
import yaml
import time


class Main:
    rf: Roboflow
    project: object
    dataset: object
    model: object
    results: object
    model_size: str

    def __init__(self):
        self.model_size = sys.argv[6]
        self.import_dataset()
        self.train()

    def import_dataset(self):
        self.rf = Roboflow(api_key=sys.argv[1])
        self.project = self.rf.workspace(sys.argv[2]).project(sys.argv[3])
        self.dataset = self.project.version(sys.argv[4]).download("yolov8-obb")

        with open(f'{self.dataset.location}/data.yaml', 'r') as file:
            data = yaml.safe_load(file)

        data['path'] = self.dataset.location

        with open(f'{self.dataset.location}/data.yaml', 'w') as file:
            yaml.dump(data, file, sort_keys=False)

    def train(self):
        list_of_models = ["n", "s", "m", "l", "x"]
        if self.model_size != "ALL" and self.model_size in list_of_models:

            self.model = YOLO(f"yolov8{self.model_size}-obb.pt")

            self.results = self.model.train(data=f"{self.dataset.location}/"
                                                 f"yolov8-obb.yaml",
                                            epochs=int(sys.argv[5]), imgsz=640)



        elif self.model_size == "ALL":
            for model_size in list_of_models:
                self.model = YOLO(f"yolov8{model_size}.pt")

                self.results = self.model.train(data=f"{self.dataset.location}"
                                                     f"/yolov8-obb.yaml",
                                                epochs=int(sys.argv[5]),
                                                imgsz=640)



        else:
            print("Invalid model size")



if __name__ == '__main__':
    Main()
登录后复制

三、显示

训练完模型后,得到文件best.py和last.py,它们对应的是权重。

使用ultralytics库,您还可以导入YOLO并加载您的体重,然后加载您的测试视频。
在此示例中,我使用跟踪功能来获取每个游泳者的 ID。

import cv2
from ultralytics import YOLO
import sys


def main():
    cap = cv2.VideoCapture(sys.argv[1])

    model = YOLO(sys.argv[2])

    while True:
        ret, frame = cap.read()
        results = model.track(frame, persist=True)
        res_plotted = results[0].plot()
        cv2.imshow("frame", res_plotted)

        if cv2.waitKey(1) == 27:
            break

    cap.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
登录后复制

为了分析预测,您可以获取模型json,如下所示。

 results = model.track(frame, persist=True)
 results_json = json.loads(results[0].tojson())
登录后复制

以上是ROBOFLOW - 使用 python 进行训练和测试的详细内容。更多信息请关注PHP中文网其他相关文章!

来源:dev.to
本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责声明 Sitemap
PHP中文网:公益在线PHP培训,帮助PHP学习者快速成长!