Object Detection with RetinaNet and the TensorFlow Object Detection API (with Source Code)

2024/4/28 14:38:00 / Source: https://blog.csdn.net/liuqiker/article/details/130830921

Contents

  • I. How RetinaNet Works
  • II. Implementing RetinaNet
    • 1. Introduction to tf.train.Checkpoint
    • 2. TensorFlow Source Code for RetinaNet

I. How RetinaNet Works

To be added.

II. Implementing RetinaNet

1. Introduction to tf.train.Checkpoint

To be added.

2. TensorFlow Source Code for RetinaNet

  Step 1: Install the TensorFlow 2 Object Detection API and its dependencies. The code in this article is written for a Colab/Jupyter notebook. The three blocks below are separate notebook cells, because %%writefile must be the first line of its cell.

# Remove any existing models directory
!rm -rf ./models/
# Clone the TensorFlow Model Garden
!git clone --depth 1 https://github.com/tensorflow/models/
# Compile the Object Detection API protocol buffers
!cd models/research/ && protoc object_detection/protos/*.proto --python_out=.

%%writefile models/research/setup.py
import os
from setuptools import find_packages
from setuptools import setup

REQUIRED_PACKAGES = [
    'tf-models-official==2.8.0',
    'tensorflow_io==0.24.0',
    'numpy==1.21.5'
]

setup(
    name='object_detection',
    version='0.1',
    install_requires=REQUIRED_PACKAGES,
    include_package_data=True,
    packages=(
        [p for p in find_packages() if p.startswith('object_detection')] +
        find_packages(where=os.path.join('.', 'slim'))),
    package_dir={
        'datasets': os.path.join('slim', 'datasets'),
        'nets': os.path.join('slim', 'nets'),
        'preprocessing': os.path.join('slim', 'preprocessing'),
        'deployment': os.path.join('slim', 'deployment'),
        'scripts': os.path.join('slim', 'scripts'),
    },
    description='Tensorflow Object Detection Library',
    python_requires='>3.6',
)

# Run the setup script you just wrote
!python -m pip install models/research

  Step 2: Import the packages

import matplotlib
import matplotlib.pyplot as plt

import os
import random
import io
import imageio
import glob
import scipy.misc
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display, Javascript
from IPython.display import Image as IPyImage

import tensorflow as tf

from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import colab_utils
from object_detection.builders import model_builder

%matplotlib inline

  Step 3: Define helper functions for loading images and plotting

def load_image_into_numpy_array(path):
    """Load an image from file into a numpy array.

    Puts image into numpy array to feed into tensorflow graph.
    Note that by convention we put it into a numpy array with shape
    (height, width, channels), where channels=3 for RGB.

    Args:
      path: a file path.

    Returns:
      uint8 numpy array with shape (img_height, img_width, 3)
    """
    img_data = tf.io.gfile.GFile(path, 'rb').read()
    image = Image.open(BytesIO(img_data))
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)


def plot_detections(image_np,
                    boxes,
                    classes,
                    scores,
                    category_index,
                    figsize=(12, 16),
                    image_name=None):
    """Wrapper function to visualize detections.

    Args:
      image_np: uint8 numpy array with shape (img_height, img_width, 3)
      boxes: a numpy array of shape [N, 4]
      classes: a numpy array of shape [N]. Note that class indices are 1-based,
        and match the keys in the label map.
      scores: a numpy array of shape [N] or None.  If scores=None, then
        this function assumes that the boxes to be plotted are groundtruth
        boxes and plot all boxes as black with no classes or scores.
      category_index: a dict containing category dictionaries (each holding
        category index `id` and category name `name`) keyed by category indices.
      figsize: size for the figure.
      image_name: a name for the image file.
    """
    image_np_with_annotations = image_np.copy()
    viz_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_annotations,
        boxes,
        classes,
        scores,
        category_index,
        use_normalized_coordinates=True,
        min_score_thresh=0.8)
    if image_name:
        plt.imsave(image_name, image_np_with_annotations)
    else:
        plt.imshow(image_np_with_annotations)

  Step 4: Download the training images (the training-zombie set is used as the example here)

# download the images
!wget --no-check-certificate \
    https://storage.googleapis.com/tensorflow-3-public/datasets/training-zombie.zip \
    -O ./training-zombie.zip

import zipfile

# unzip to a local directory
local_zip = './training-zombie.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('./training')
zip_ref.close()

  Step 5: Point to the training image directory, build the list of training images, and display the samples

train_image_dir = './training'
train_image_name = 'training-zombie'

# Load images and visualize
train_images_np = []
for i in range(1, 6):
    image_path = os.path.join(train_image_dir, train_image_name + str(i) + '.jpg')
    train_images_np.append(load_image_into_numpy_array(image_path))

plt.rcParams['axes.grid'] = False
plt.rcParams['xtick.labelsize'] = False
plt.rcParams['ytick.labelsize'] = False
plt.rcParams['xtick.top'] = False
plt.rcParams['xtick.bottom'] = False
plt.rcParams['ytick.left'] = False
plt.rcParams['ytick.right'] = False
plt.rcParams['figure.figsize'] = [14, 7]

for idx, train_image_np in enumerate(train_images_np):
    plt.subplot(2, 3, idx+1)  # 2, 3 -> 1, 5
    plt.imshow(train_image_np)
plt.show()  # display the samples

  The samples are displayed as follows:
[Figure: the five training images]
  Step 6: Initialize the bounding boxes (the ground-truth box coordinates are specified by hand and used for training)

gt_boxes = [
    np.array([[0.27333333, 0.41500586, 0.74333333, 0.57678781]], dtype=np.float32),
    np.array([[0.29833333, 0.45955451, 0.75666667, 0.61078546]], dtype=np.float32),
    np.array([[0.40833333, 0.18288394, 0.945, 0.34818288]], dtype=np.float32),
    np.array([[0.16166667, 0.61899179, 0.8, 0.91910903]], dtype=np.float32),
    np.array([[0.28833333, 0.12543962, 0.835, 0.35052755]], dtype=np.float32),
]
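
The boxes above follow the Object Detection API convention of [ymin, xmin, ymax, xmax], normalized to the range [0, 1] by image height and width (plot_detections in Step 3 passes use_normalized_coordinates=True). As a minimal sketch of how such values could be derived from pixel annotations, consider the helper below. The name normalize_box is hypothetical (it is not part of the API), and the 200x100 image size is only an assumed example.

```python
def normalize_box(ymin_px, xmin_px, ymax_px, xmax_px, img_height, img_width):
    """Convert a pixel-space box to [ymin, xmin, ymax, xmax] in the range [0, 1]."""
    if img_height <= 0 or img_width <= 0:
        raise ValueError("image dimensions must be positive")
    return [
        ymin_px / img_height,
        xmin_px / img_width,
        ymax_px / img_height,
        xmax_px / img_width,
    ]


# Assumed example: a 200x100 (height x width) image with a box around its centre.
example_box = normalize_box(50, 25, 150, 75, img_height=200, img_width=100)
print(example_box)  # [0.25, 0.25, 0.75, 0.75]
```

Note that the y coordinates are divided by the height and the x coordinates by the width; swapping them is an easy mistake that produces boxes in the wrong place.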

  Step 7: Initialize the label and class of the target object. Only one kind of object is detected, so the number of classes is 1.

zombie_class_id = 1
num_classes = 1

category_index = {zombie_class_id: {'id': zombie_class_id, 'name': 'zombie'}}

  Step 8: Convert the training data to tensors (the data format TensorFlow works with)

label_id_offset = 1
train_image_tensors = []
gt_classes_one_hot_tensors = []
gt_box_tensors = []
for (train_image_np, gt_box_np) in zip(train_images_np, gt_boxes):
    # Add a batch dimension to each image: [H, W, 3] -> [1, H, W, 3]
    train_image_tensors.append(tf.expand_dims(tf.convert_to_tensor(
        train_image_np, dtype=tf.float32), axis=0))
    gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))
    # Shift the 1-based class ids to 0-based indices before one-hot encoding
    zero_indexed_groundtruth_classes = tf.convert_to_tensor(
        np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) - label_id_offset)
    gt_classes_one_hot_tensors.append(tf.one_hot(
        zero_indexed_groundtruth_classes, num_classes))
print('Done prepping data.')
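
The loop above shifts the 1-based class ids down by label_id_offset before one-hot encoding, so class id 1 maps to index 0. The pure-Python sketch below mimics that arithmetic without TensorFlow. one_hot_from_class_ids is a hypothetical helper used only for illustration, and the three-class case is an invented comparison, not part of this tutorial.

```python
def one_hot_from_class_ids(class_ids, num_classes, label_id_offset=1):
    """Mimic tf.one_hot(class_ids - label_id_offset, num_classes) with plain lists."""
    rows = []
    for class_id in class_ids:
        index = class_id - label_id_offset  # 1-based id -> 0-based index
        # An index outside [0, num_classes) simply yields an all-zero row.
        rows.append([1.0 if i == index else 0.0 for i in range(num_classes)])
    return rows


# One zombie box with class id 1 and a single class, as in this tutorial.
single_class = one_hot_from_class_ids([1], num_classes=1)
print(single_class)  # [[1.0]]

# A hypothetical three-class case, for comparison.
three_class = one_hot_from_class_ids([1, 3], num_classes=3)
print(three_class)  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
```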

  Step 9: Display the prepared training images with their boxes (during preprocessing, check often that the data is correct)

dummy_scores = np.array([1.0], dtype=np.float32)  # give boxes a score of 100%

plt.figure(figsize=(30, 15))
for idx in range(5):
    plt.subplot(2, 3, idx+1)
    plot_detections(
        train_images_np[idx],
        gt_boxes[idx],
        np.ones(shape=[gt_boxes[idx].shape[0]], dtype=np.int32),
        dummy_scores, category_index)
plt.show()

  The output looks like this:
[Figure: the training images with their ground-truth boxes drawn]
  Step 10: Download the RetinaNet model (the SSD ResNet50 V1 FPN 640x640 checkpoint)

!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!tar -xf ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!mv ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/checkpoint models/research/object_detection/test_data/

  Step 11: Load the model, modify it (mainly the number of object classes to detect), and initialize the weights (by running a prediction on dummy data)

tf.keras.backend.clear_session()

print('Building model and restoring weights for fine-tuning...', flush=True)
num_classes = 1
pipeline_config = 'models/research/object_detection/configs/tf2/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config'
checkpoint_path = 'models/research/object_detection/test_data/checkpoint/ckpt-0'

# Load pipeline config and build a detection model.
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
model_config.ssd.num_classes = num_classes
model_config.ssd.freeze_batchnorm = True
detection_model = model_builder.build(model_config=model_config, is_training=True)

# Restore only the feature extractor, the base tower layers and the box
# prediction head. The classification head is left out of the checkpoint
# objects below, so it is not restored.
fake_box_predictor = tf.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
)
fake_model = tf.train.Checkpoint(
    _feature_extractor=detection_model._feature_extractor,
    _box_predictor=fake_box_predictor)
ckpt = tf.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path)

# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')

  Step 12: Define train_step and the training loop

tf.keras.backend.set_learning_phase(True)

# Training hyperparameters
batch_size = 4
learning_rate = 0.01
num_batches = 100

# Select the variables to fine-tune
trainable_variables = detection_model.trainable_variables
to_fine_tune = []
prefixes_to_train = [
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead']
for var in trainable_variables:
    if any([var.name.startswith(prefix) for prefix in prefixes_to_train]):
        to_fine_tune.append(var)


# Define the train step.
def get_model_train_step_function(model, optimizer, vars_to_fine_tune):
    """Get a tf.function for training step."""

    @tf.function
    def train_step_fn(image_tensors,
                      groundtruth_boxes_list,
                      groundtruth_classes_list):
        """A single training iteration.

        Args:
          image_tensors: A list of [1, height, width, 3] Tensor of type tf.float32.
            Note that the height and width can vary across images, as they are
            reshaped within this function to be 640x640.
          groundtruth_boxes_list: A list of Tensors of shape [N_i, 4] with type
            tf.float32 representing groundtruth boxes for each image in the batch.
          groundtruth_classes_list: A list of Tensors of shape [N_i, num_classes]
            with type tf.float32 representing groundtruth classes for each image
            in the batch.

        Returns:
          A scalar tensor representing the total loss for the input batch.
        """
        shapes = tf.constant(batch_size * [[640, 640, 3]], dtype=tf.int32)
        model.provide_groundtruth(
            groundtruth_boxes_list=groundtruth_boxes_list,
            groundtruth_classes_list=groundtruth_classes_list)
        with tf.GradientTape() as tape:
            preprocessed_images = tf.concat(
                [model.preprocess(image_tensor)[0]
                 for image_tensor in image_tensors], axis=0)
            prediction_dict = model.predict(preprocessed_images, shapes)
            losses_dict = model.loss(prediction_dict, shapes)
            total_loss = (losses_dict['Loss/localization_loss'] +
                          losses_dict['Loss/classification_loss'])
        # Update only the selected variables
        gradients = tape.gradient(total_loss, vars_to_fine_tune)
        optimizer.apply_gradients(zip(gradients, vars_to_fine_tune))
        return total_loss

    return train_step_fn


# Optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
train_step_fn = get_model_train_step_function(detection_model, optimizer, to_fine_tune)

print('Start fine-tuning!', flush=True)

# Start training (the training loop)
for idx in range(num_batches):
    # Grab keys for a random subset of examples
    all_keys = list(range(len(train_images_np)))
    random.shuffle(all_keys)
    example_keys = all_keys[:batch_size]

    gt_boxes_list = [gt_box_tensors[key] for key in example_keys]
    gt_classes_list = [gt_classes_one_hot_tensors[key] for key in example_keys]
    image_tensors = [train_image_tensors[key] for key in example_keys]

    # Training step (forward pass + backwards pass)
    total_loss = train_step_fn(image_tensors, gt_boxes_list, gt_classes_list)

    if idx % 10 == 0:
        print('batch ' + str(idx) + ' of ' + str(num_batches)
              + ', loss=' + str(total_loss.numpy()), flush=True)

print('Done fine-tuning!')
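
Only variables whose names start with one of the prefixes_to_train entries receive gradient updates; all other variables are left unchanged. The sketch below applies the same startswith filter to plain strings so the selection can be checked without building the model. The variable names in it are made up for illustration and are not taken from the real model.

```python
prefixes_to_train = [
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead',
]

# Hypothetical variable names, for illustration only.
variable_names = [
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead/kernel:0',
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead/bias:0',
    'WeightSharedConvolutionalBoxPredictor/BaseTower/conv/kernel:0',
    'ResNet50V1_FPN/bottom_up_block/kernel:0',
]

# str.startswith accepts a tuple of prefixes, which replaces the inner any([...]).
selected_names = [name for name in variable_names
                  if name.startswith(tuple(prefixes_to_train))]

print(len(selected_names))  # 2
```

Printing the selected names before training is a cheap way to confirm that the prefixes actually match something; an empty list would mean no variable gets updated.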

  Step 13: Download the test images used to evaluate the model trained in the previous step

# delete files left over from a previous run, if any
!rm -f zombie-walk-frames.zip
!rm -rf ./zombie-walk
!rm -rf ./results

# download test images
!wget --no-check-certificate \
    https://storage.googleapis.com/tensorflow-3-public/datasets/zombie-walk-frames.zip \
    -O zombie-walk-frames.zip

# unzip test images
local_zip = './zombie-walk-frames.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('./results')
zip_ref.close()

  Step 14: Convert the test images to NumPy arrays

test_image_dir = './results/'
test_images_np = []

# load images into a numpy array. this will take a few minutes to complete.
for i in range(0, 237):
    image_path = os.path.join(test_image_dir, 'zombie-walk' + "{0:04}".format(i) + '.jpg')
    print(image_path)
    test_images_np.append(np.expand_dims(
        load_image_into_numpy_array(image_path), axis=0))
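
The frame file names are built by zero-padding the index to four digits. A quick check of that formatting, using only the standard library, might look like this (frame_path is a hypothetical helper name, not something defined elsewhere in this article):

```python
import os


def frame_path(test_image_dir, index):
    """Build the path of one test frame, e.g. index 7 -> zombie-walk0007.jpg."""
    return os.path.join(test_image_dir, 'zombie-walk' + '{0:04}'.format(index) + '.jpg')


first_name = os.path.basename(frame_path('./results/', 0))
frame_150_name = os.path.basename(frame_path('./results/', 150))
print(first_name)      # zombie-walk0000.jpg
print(frame_150_name)  # zombie-walk0150.jpg
```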

  Step 15: Define the detection function

@tf.function
def detect(input_tensor):
    """Run detection on an input image.

    Args:
      input_tensor: A [1, height, width, 3] Tensor of type tf.float32.
        Note that height and width can be anything since the image will be
        immediately resized according to the needs of the model within this
        function.

    Returns:
      A dict containing 3 Tensors (`detection_boxes`, `detection_classes`,
        and `detection_scores`).
    """
    preprocessed_image, shapes = detection_model.preprocess(input_tensor)
    prediction_dict = detection_model.predict(preprocessed_image, shapes)
    detections = detection_model.postprocess(prediction_dict, shapes)
    return detections

  Step 16: Call the detection function to check the model's accuracy

label_id_offset = 1
results = {'boxes': [], 'scores': []}

i = 150
images_np = test_images_np
# input_tensor = train_image_tensors[i]
input_tensor = tf.convert_to_tensor(images_np[i], dtype=tf.float32)
detections = detect(input_tensor)

detections['detection_boxes'][0].shape
detections['detection_classes'][0].shape
plot_detections(
    images_np[i][0],
    detections['detection_boxes'][0].numpy(),
    detections['detection_classes'][0].numpy().astype(np.uint32) + label_id_offset,
    detections['detection_scores'][0].numpy(),
    category_index, figsize=(15, 20))

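  The test result is shown below:
[Figure: test frame 150 with the detected box drawn]
  As the figure shows, the model's detections match expectations.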
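
plot_detections from Step 3 passes min_score_thresh=0.8, so only detections scoring above that threshold are drawn. To get the surviving boxes as data rather than as a picture, a filter along the following lines could be used. This is a plain-Python sketch with made-up scores and boxes, and filter_detections is a hypothetical helper, not an Object Detection API function.

```python
def filter_detections(boxes, scores, min_score_thresh=0.8):
    """Keep the (box, score) pairs whose score exceeds the threshold."""
    return [(box, score) for box, score in zip(boxes, scores)
            if score > min_score_thresh]


# Made-up detections in normalized [ymin, xmin, ymax, xmax] form.
boxes = [[0.1, 0.2, 0.5, 0.6], [0.3, 0.3, 0.9, 0.8], [0.0, 0.0, 0.2, 0.2]]
scores = [0.95, 0.42, 0.81]

kept = filter_detections(boxes, scores)
print(len(kept))  # 2
```

With real output, the inputs would be detections['detection_boxes'][0].numpy() and detections['detection_scores'][0].numpy() from Step 16.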
