自定义神经网络时的注意事项

news/2024/5/5 10:27:23/文章来源:https://blog.csdn.net/qq_57390446/article/details/137617478

问题描述

`

通过继承tf.keras.Model自定义神经网络模型时遇到的一系列问题。

代码如下,

class STFT_ConV2D(tf.keras.Model):def __init__(self, *args, **kwargs):super().__init__(*args, **kwargs)self.pre_layer = tf.keras.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(768, activation='relu')])self.add = tf.keras.layers.Add()self.output_dense = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):x, y = inputsx = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_x)(x)x = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_x)(x)x = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_x)(x)x = self.pre_layer(x)y = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_y)(y)y = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_y)(y)y = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_y)(y)y = self.pre_layer(y)output = self.add([x, y])output = self.output_dense(output)return output

产生的bug为,

  ValueError: Exception encountered when calling layer 'sequential' (type Sequential).Input 0 of layer "dense" is incompatible with the layer: expected axis -1 of input shape to have value 11368, but received input with shape (None, 210680)

x输入和y输入都使用了成员变量pre_layer,共享了pre_layer层,也就共享了pre_layer层的参数矩阵和结构。
由于x先经过三层卷积层后shape由原来的shape=(360, 256, 109, 1)变成了shape=(360, 203, 56, 1)
再经过pre_layer层里的Flatten时,除“ batchsize ”轴(axis=0)外,其余轴被铺平,输出shape=(360,11368)。接着处理y输入,经过三层卷积层后,shape由原来的shape=(360, 511, 513, 1)变成了shape=(360,458, 460, 1),之后执行到y = self.pre_layer(y)时,如果执行成功,则输出shape=(360,21068),此时与x的shape=(360,11368)维度冲突,从而产生异常。

要点归纳:

  1. 通过继承tf.keras.Model写神经网络模型时,每一个神经网络层只能被同一个输入占有。
  2. 所有tf.keras.layers下的层对象不能直接出现在call()方法中,必须以成员变量的形式在构造器中定义,然后在call()方法中通过self.成员变量的方式调用
  3. 卷积层tf.keras.layers.Conv2D()当神经网络第一层时,必须通过参数input_shape指定输入shape,该shape中不能包含“ batchsize ”轴,例如输入x的shape为(a, b, c, d),其中a代表样本数,b代表行像素,c代表列像素,d代表通道数。则应该指定input_shape=x.shape[1:],去除a所在轴,以免卷积层对该轴造成影响。

解决方案:

class STFT_ConV2D(tf.keras.Model):def __init__(self, *args, **kwargs):super().__init__(*args, **kwargs)self.conV2d_x1 = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_x)self.conV2d_x2 = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_x)self.conV2d_x3 = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_x)self.conV2d_y1 = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_y)self.conV2d_y2 = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_y)self.conV2d_y3 = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_y)self.flatten_x = tf.keras.layers.Flatten()self.flatten_y = tf.keras.layers.Flatten()self.dense_x = tf.keras.layers.Dense(768, activation='relu')self.dense_y = tf.keras.layers.Dense(768, activation='relu')self.add = tf.keras.layers.Add()self.output_dense = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# x.shape = (360, 256, 109, 1) , y.shape = (360, 511, 513, 1)# inputs = (x, y)x, y = inputs  x = self.conV2d_x1(x) # (360, 249, 102, 3)x = self.conV2d_x2(x) # (360, 234, 87, 3)x = self.conV2d_x3(x) # (360, 203, 56, 1)x = self.flatten_x(x) # (360, 11368)x = self.dense_x(x)  # (360, 768)y = self.conV2d_y1(y)y = self.conV2d_y2(y)y = self.conV2d_y3(y)y = self.flatten_y(y)y = self.dense_y(y)output = self.add([x, y]) # (360, 768)output = self.output_dense(output)return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1045934.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Adobe After Effects 2024 v24.3 macOS 视频合成及特效制作软件 兼容 M1/M2/M3

Adobe After Effects 是一款适用于视频合成及特效制作软件,是制作动态影像设计不可或缺的辅助工具,是视频后期合成处理的专业非线性编辑软件。 macOS 12.0及以上版本可用 应用介绍 Adobe After Effects简称 AE 是一款适用于视频合成及特效制作软件,是制作动态影像设计不可或缺…

计算分数和-第12届蓝桥杯选拔赛Python真题精选

[导读]:超平老师的Scratch蓝桥杯真题解读系列在推出之后,受到了广大老师和家长的好评,非常感谢各位的认可和厚爱。作为回馈,超平老师计划推出《Python蓝桥杯真题解析100讲》,这是解读系列的第48讲。 计算分数和&#…

STM32 H7系列学习笔记

必备的API知识 第 1 步:系统上电复位,进入启动文件 startup_stm32h743xx.s,在这个文件里面执行复位中断服务程序。 在复位中断服务程序里面执行函数 SystemInit,在system_stm32h7xx.c 里面。*之后是调用编译器封装好的函数&…

Kubesphere 在 devops 部署项目的时候下载 maven 依赖卡住

Kubesphere 在 devops 部署项目的时候下载 maven 依赖卡住 我下载 下面这段 maven 依赖一直卡住&#xff1a; <build><plugins><plugin><groupId>org.jacoco</groupId><artifactId>jacoco-maven-plugin</artifactId><version>…

LeetCode 80—— 删除有序数组中的重复项 II

阅读目录 1. 题目2. 解题思路3. 代码实现 1. 题目 2. 解题思路 让 index指向删除重复元素后数组的新长度&#xff1b;让 st_idx 指向重复元素的起始位置&#xff0c;而 i 指向重复元素的结束位置&#xff0c;duplicate_num代表重复元素的个数&#xff1b;一段重复元素结束后&am…

Java Web-分层解耦

三层架构 当我们所有代码都写在一起时&#xff0c;代码的复用性差&#xff0c;并且难以维护。就像我们要修改一下服务端获取数据的方式&#xff0c;从文本文档获取改为到数据库中获取&#xff0c;就难以修改&#xff0c;而使用三层架构能很好的解决这个问题。 controller: 控…

PHP 中的 $2y$10,PHP 字符串加密函数 password_hash

PHP 用户密码加密函数 password_hash 自PHP5.5.0之后&#xff0c;新增加了密码散列算法函数(password_hash)&#xff0c;password_hash() 使用足够强度的单向散列算法创建密码的散列 Hash。 password_hash() 兼容 crypt()。 所以&#xff0c; crypt() 创建的密码散列也可用于 …

flask 访问404

当你的项目有自己的蓝图&#xff0c;有添加自己的前缀&#xff0c;也注册了蓝图。 在访问的路由那里也使用了自己的蓝图&#xff0c;如下图 然后你访问的地址也没问题&#xff0c;但是不管怎么样访问就是返回404&#xff0c;这个时候不要怀疑你上面的哪里配置错误&#xff0c;…

【网站项目】校园二手交易平台小程序

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

外包干了25天,技术倒退明显

先说情况&#xff0c;大专毕业&#xff0c;18年通过校招进入湖南某软件公司&#xff0c;干了接近6年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落&#xff01; 而我已经在一个企业干了四年的功能…

【mT5多语言翻译】之五——训练:中央日志、训练可视化、PEFT微调

请参考本系列目录&#xff1a;【mT5多语言翻译】之一——实战项目总览 [1] 模型训练与验证 在上一篇实战博客中&#xff0c;我们讲解了访问数据集中每个batch数据的方法。接下来我们介绍如何训练mT5模型进行多语言翻译微调。 首先加载模型&#xff0c;并把模型设置为训练状态&a…

网络安全指南:安全访问 Facebook 的技巧

在当今数字化时代&#xff0c;网络安全问题越来越受到人们的关注。尤其是在社交媒体平台上&#xff0c;如 Facebook 这样的巨头&#xff0c;用户的个人信息和隐私更容易受到威胁。为了保护自己的在线安全&#xff0c;我们需要采取一些措施来确保在使用 Facebook 时能够安全可靠…

C语言进阶|顺序表

✈顺序表的概念及结构 线性表&#xff08;linear list&#xff09;是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使 用的数据结构&#xff0c;常见的线性表&#xff1a;顺序表、链表、栈、队列、字符串.. 线性表在逻辑上是线性结构&#xff0c;也就说是连…

大话设计模式——23.备忘录模式(Memento Pattern)

简介 又称快照模式&#xff0c;在不破坏封装性的前提下&#xff0c;捕获一个对象的内部状态&#xff0c;并且该对象之外保存这个状态。这样以后就可将该对象恢复到原先保存的状态 UML图 应用场景 允许用户取消不确定或者错误的操作&#xff0c;能够恢复到原先的状态游戏存档、…

深度学习架构(CNN、RNN、GAN、Transformers、编码器-解码器架构)的友好介绍。

一、说明 本博客旨在对涉及卷积神经网络 &#xff08;CNN&#xff09;、递归神经网络 &#xff08;RNN&#xff09;、生成对抗网络 &#xff08;GAN&#xff09;、转换器和编码器-解码器架构的深度学习架构进行友好介绍。让我们开始吧&#xff01;&#xff01; 二、卷积神经网络…

【动手学深度学习】15_汉诺塔问题

注&#xff1a; 本系列仅为个人学习笔记&#xff0c;学习内容为《算法小讲堂》&#xff08;视频传送门&#xff09;&#xff0c;通俗易懂适合编程入门小白&#xff0c;需要具备python语言基础&#xff0c;本人小白&#xff0c;如内容有误感谢您的批评指正 汉诺塔&#xff08;To…

c/c++ |游戏后端开发之skynet

作者眼中的skynet 有一点要说明的是&#xff0c;云风至始也没有公开说skynet专门为游戏开发&#xff0c;换句话&#xff0c;skynet 引擎也可以用于web 开发 贴贴我的笔记 skynet 核心解决什么问题 愿景&#xff1a;游戏服务器能够充分利用多核优势&#xff0c;将不同的业务放在…

【随笔】Git 高级篇 -- 本地栈式提交 rebase | cherry-pick(十七)

&#x1f48c; 所属专栏&#xff1a;【Git】 &#x1f600; 作  者&#xff1a;我是夜阑的狗&#x1f436; &#x1f680; 个人简介&#xff1a;一个正在努力学技术的CV工程师&#xff0c;专注基础和实战分享 &#xff0c;欢迎咨询&#xff01; &#x1f496; 欢迎大…

QT Creator概览

&#x1f40c;博主主页&#xff1a;&#x1f40c;​倔强的大蜗牛&#x1f40c;​ &#x1f4da;专栏分类&#xff1a;QT❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 目录 一、Qt Creator 概览 ①&#xff1a;菜单栏 ②&#xff1a;模式选择 ③&#xff1a;构建套件选择器…

在图片上画出mask和pred

画出论文中《Variance-aware attention U-Net for multi-organ segmentation》的图1&#xff0c;也就是在原图上画出mask和pred的位置。 新建一个文件夹 然后运行代码&#xff1a; import cv2 import os from os.path import splitext####第一次&#xff1a;把GT&#xff08…