当前位置: 首页 > 编程日记 > 正文

Keras 最新《面向小数据集构建图像分类模型》

本文地址:http://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html

本文作者:Francois Chollet

  • 按照官方的文章实现过程有一些坑,彻底理解代码细节实现,理解keras的api具体使用方法
  • 也有很多人翻译这篇文章,但是有些没有具体实现细节
  • 另外keres开发者自己有本书的jupyter:Companion Jupyter notebooks for the book "Deep Learning with Python"
  • 另外我自己实验三收敛的准确率并没有0.94+,可以参考前面这本书上的实现
  • 文章一共有三个实验:
      1. 第一个实验使用自定义的神经网络对数据集进行训练,三层卷积加两层全连接,训练并验证网络的准确率;
      2. 第二个实验使用VGG16网络对数据进行训练,为了适应自定义的数据集,将VGG16网络的全连接层去掉,作者称之为 “Feature extraction”, 再在上面添加自己实现的全连接层,然后训练并验证网络准确性;
      3. 第三个实验称为 “fine-tune” ,利用第二个实验的实验模型和weight,重新训练VGG16的最后一个卷积层和自定义的全连接层,然后验证网络准确性;
  • 实验二的代码:
'''This script goes along the blog post
"Building powerful image classification models using very little data"
from blog.keras.io.
It uses data that can be downloaded at:
https://www.kaggle.com/c/dogs-vs-cats/data
In our setup, we:
- created a data/ folder
- created train/ and validation/ subfolders inside data/
- created cats/ and dogs/ subfolders inside train/ and validation/
- put the cat pictures index 0-999 in data/train/cats
- put the cat pictures index 1000-1400 in data/validation/cats
- put the dogs pictures index 12500-13499 in data/train/dogs
- put the dog pictures index 13500-13900 in data/validation/dogs
So that we have 1000 training examples for each class, and 400 validation examples for each class.
In summary, this is our directory structure:
```
data/train/dogs/dog001.jpgdog002.jpg...cats/cat001.jpgcat002.jpg...validation/dogs/dog001.jpgdog002.jpg...cats/cat001.jpgcat002.jpg...
```
'''
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dropout, Flatten, Dense
from keras import applications# dimensions of our images.
img_width, img_height = 150, 150top_model_weights_path = 'bottleneck_fc_model.h5'data_root = 'M:/dataset/dog_cat/'
train_data_dir =data_root+ 'data/train'
validation_data_dir = data_root+'data/validation'
nb_train_samples = 2000
nb_validation_samples = 800
epochs = 50
batch_size = 16def save_bottlebeck_features():datagen = ImageDataGenerator(rescale=1. / 255)# build the VGG16 networkmodel = applications.VGG16(include_top=False, weights='imagenet')generator = datagen.flow_from_directory(train_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode=None,shuffle=False)bottleneck_features_train = model.predict_generator(generator, nb_train_samples // batch_size) #####2000//batch_size!!!!!!!!!!np.save('bottleneck_features_train.npy',bottleneck_features_train)generator = datagen.flow_from_directory(validation_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode=None,shuffle=False)bottleneck_features_validation = model.predict_generator(generator, nb_validation_samples // batch_size)np.save('bottleneck_features_validation.npy',bottleneck_features_validation)def train_top_model():train_data = np.load('bottleneck_features_train.npy')train_labels = np.array([0] * int(nb_train_samples / 2) + [1] * int(nb_train_samples / 2))validation_data = np.load('bottleneck_features_validation.npy')validation_labels = np.array([0] * int(nb_validation_samples / 2) + [1] * int(nb_validation_samples / 2))model = Sequential()model.add(Flatten(input_shape=train_data.shape[1:]))model.add(Dense(256, activation='relu'))model.add(Dropout(0.5))model.add(Dense(1, activation='sigmoid'))model.compile(optimizer='rmsprop',loss='binary_crossentropy', metrics=['accuracy'])model.fit(train_data, train_labels,epochs=epochs,batch_size=batch_size,validation_data=(validation_data, validation_labels))model.save_weights(top_model_weights_path)#save_bottlebeck_features()
train_top_model()
  • 实验三代码,自己添加了一些api使用方法,也是以后可以参考的:
'''This script goes along the blog post
"Building powerful image classification models using very little data"
from blog.keras.io.
It uses data that can be downloaded at:
https://www.kaggle.com/c/dogs-vs-cats/data
In our setup, we:
- created a data/ folder
- created train/ and validation/ subfolders inside data/
- created cats/ and dogs/ subfolders inside train/ and validation/
- put the cat pictures index 0-999 in data/train/cats
- put the cat pictures index 1000-1400 in data/validation/cats
- put the dogs pictures index 12500-13499 in data/train/dogs
- put the dog pictures index 13500-13900 in data/validation/dogs
So that we have 1000 training examples for each class, and 400 validation examples for each class.
In summary, this is our directory structure:
```
data/train/dogs/dog001.jpgdog002.jpg...cats/cat001.jpgcat002.jpg...validation/dogs/dog001.jpgdog002.jpg...cats/cat001.jpgcat002.jpg...
```
'''

# thanks sove bug @http://blog.csdn.net/aggresss/article/details/78588135from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Sequential
from keras.layers import Dropout, Flatten, Dense
from keras.models import Model
from keras.regularizers import l2# path to the model weights files.
weights_path = '../keras/examples/vgg16_weights.h5'
top_model_weights_path = 'bottleneck_fc_model.h5'
# dimensions of our images.
img_width, img_height = 150, 150data_root = 'M:/dataset/dog_cat/'
train_data_dir =data_root+ 'data/train'
validation_data_dir = data_root+'data/validation'nb_train_samples = 2000
nb_validation_samples = 800
epochs = 50
batch_size = 16# build the VGG16 network
base_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=(150,150,3)) # train 指定训练大小
print('Model loaded.')# build a classifier model to put on top of the convolutional model
top_model = Sequential()
top_model.add(Flatten(input_shape=base_model.output_shape[1:]))  # base_model.output_shape[1:])
top_model.add(Dense(256, activation='relu',kernel_regularizer=l2(0.001),))
top_model.add(Dropout(0.8))
top_model.add(Dense(1, activation='sigmoid'))# note that it is necessary to start with a fully-trained
# classifier, including the top classifier,
# in order to successfully do fine-tuning
top_model.load_weights(top_model_weights_path)# add the model on top of the convolutional base
# model.add(top_model) # bugmodel = Model(inputs=base_model.input, outputs=top_model(base_model.output))# set the first 25 layers (up to the last conv block)
# to non-trainable (weights will not be updated)
for layer in model.layers[:15]:  # :25 buglayer.trainable = False# compile the model with a SGD/momentum optimizer
# and a very slow learning rate.
model.compile(loss='binary_crossentropy',optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),metrics=['accuracy'])# prepare data augmentation configuration
train_datagen = ImageDataGenerator(rescale=1. / 255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True)test_datagen = ImageDataGenerator(rescale=1. / 255)train_generator = train_datagen.flow_from_directory(train_data_dir,target_size=(img_height, img_width),batch_size=batch_size,class_mode='binary')validation_generator = test_datagen.flow_from_directory(validation_data_dir,target_size=(img_height, img_width),batch_size=batch_size,class_mode='binary')model.summary() # prints a summary representation of your model.
# let's visualize layer names and layer indices to see how many layers
# we should freeze:
for i, layer in enumerate(base_model.layers):print(i, layer.name)from keras.utils import plot_model
plot_model(model, to_file='model.png')from keras.callbacks import History
from keras.callbacks import ModelCheckpoint
import keras
history = History()
model_checkpoint = ModelCheckpoint('temp_model.hdf5', monitor='loss', save_best_only=True)
tb_cb = keras.callbacks.TensorBoard(log_dir='log', write_images=1, histogram_freq=0)
# 设置log的存储位置,将网络权值以图片格式保持在tensorboard中显示,设置每一个周期计算一次网络的
# 权值,每层输出值的分布直方图
callbacks = [history,model_checkpoint,tb_cb]
# model.fit()# fine-tune the model
history=model.fit_generator(train_generator,steps_per_epoch=nb_train_samples // batch_size,epochs=epochs,callbacks=callbacks,validation_data=validation_generator,validation_steps=nb_validation_samples // batch_size,verbose = 2)model.save('fine_tune_model.h5')
model.save_weights('fine_tune_model_weight')
print(history.history)from matplotlib import pyplot as plt
history=history
plt.plot()
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()import  numpy as np
accy=history.history['acc']
np_accy=np.array(accy)
np.savetxt('save_acc.txt',np_accy)
  • result
Model loaded.
Found 2000 images belonging to 2 classes.
Found 800 images belonging to 2 classes.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 150, 150, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 150, 150, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 150, 150, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 75, 75, 64)        0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 75, 75, 128)       73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 75, 75, 128)       147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 37, 37, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 37, 37, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 37, 37, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 37, 37, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 18, 18, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 18, 18, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 18, 18, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 18, 18, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 9, 9, 512)         0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 9, 9, 512)         2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 9, 9, 512)         2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 9, 9, 512)         2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 4, 4, 512)         0         
_________________________________________________________________
sequential_1 (Sequential)    (None, 1)                 2097665   
=================================================================
Total params: 16,812,353
Trainable params: 9,177,089
Non-trainable params: 7,635,264
_________________________________________________________________
0 input_1
1 block1_conv1
2 block1_conv2
3 block1_pool
4 block2_conv1
5 block2_conv2
6 block2_pool
7 block3_conv1
8 block3_conv2
9 block3_conv3
10 block3_pool
11 block4_conv1
12 block4_conv2
13 block4_conv3
14 block4_pool
15 block5_conv1
16 block5_conv2
17 block5_conv3
18 block5_pool
Backend TkAgg is interactive backend. Turning interactive mode on.
  • reference: 第八期 使用 Keras 训练神经网络 《显卡就是开发板》

相关文章:

Python处理XML文件

用代码记录下: import xml.dom.minidomtry:f open(filename)dom xml.dom.minidom.parseString(f.read()) finally:f.close()if dom ! None:root dom.documentElementfor element in root.getElementsByTagName("bean"):for prop in element.getElement…

李彦宏首次公布24字百度愿景,要做最懂用户的公司

编辑 | 一一 出品 | AI科技大本营 近日,李彦宏发布内部信并首次公布了 24 字百度愿景:成为最懂用户,并能帮助人们成长的全球顶级高科技公司。李彦宏表示,“这 24 个字将上承新使命、下展公司“夯实移动基础、决胜 AI 时代”的整体…

HP c3000/c7000 blade switch GBE2c 初始配置

端口概述 Port 1-16 是内联刀片的downlink口 Port 17-18 是switch互联用,默认是disable的 Port 19是给Blade On-board Administrator用 Port 20-24 是uplink口 连到交换机的Console口 To access the switch locally: 1. Connect the switch DB-9 serial connector, …

Travis CI : 最小的分布式系统(一)

(本文翻译自http://www.paperplanes.de/2013/10/18/the-smallest-distributed-system.html,由金斌_jinbin 翻译) Travis CI一开始仅仅是个想法,在当时甚至还有些理想化。在这个项目启动之前,开源社区还没有一个可用的持续集成系统。 随着作为…

Windows 7时代即将终结!

作者 | 屠敏 转载自CSDN(ID:CSDNnews) 2009 年诞生的 Windows 7 终究没能超过 Windows XP 13 岁的寿命。 2015 年 1 月 14 日,微软宣布结束对 Windows 7 操作系统的第一阶梯主流支持,同时为了给用户过渡升级的时间,…

什么是DWDM

DWDM是Dense Wavelength Division Multiplexing(密集波分复用)的缩写,这是一项用来在现有的光纤骨干网上提高带宽的激光技术。更确切地说,该技术是在一根指定的光纤中,多路复用单个光纤载波的紧密光谱间距,…

MySQL · myrocks · MyRocks之memtable切换与刷盘

概述 MyRocks的memtable默认是skiplist,其大小和个数分别由参数write_buffer_size和max_write_buffer_number控制。数据写入时先写入active memtable, 当active memtable写满时,active memtable会转化为immutable memtable. immutable memtable数据是不会…

URLRewriter在ASP.NET配置文件中的用法

<?xml version"1.0"?><configuration><configSections><sectionGroup name"system.web.extensions" type"System.Web.Configuration.SystemWebExtensionsSectionGroup, System.Web.Extensions, Version1.0.61025.0, Culturene…

Travis CI : 最小的分布式系统(二)

大约1年之前&#xff0c;我们发现当时的架构有些不合理了。尤其是Hub&#xff0c;它上面承担了太多的任务。Hub要接收新的处理请求&#xff0c;处理并推动构建日志&#xff0c;它要同步用户信息到Github&#xff0c;它要通知用户构建是否成功。它跟一大群外部API打交道&#xf…

百度开设「黄埔学院」,革新者来

1 月 19 日&#xff0c;百度宣布成立「黄埔学院」&#xff0c;开展深度学习架构师培养计划。并借鉴了黄埔军校大门对联的横批「革命者来」&#xff0c;将口号设置为「革新者来」。 首先&#xff0c;为什么叫「黄埔学院」&#xff1f; 2012 年初&#xff0c;百度开始进行深度学…

Linux-find命令应用举例-按时间筛选和删除文件

find参数说明&#xff1a; find有很多参数是以动作首字母时间的方式用于按访问、改变、更新时间来筛选文件。 动作表达&#xff1a; a(last accessed) 最近一次访问时间 c(last changed) 最近一次改变时间 m(last modified) 最近一次修改时间注意此上的c和m的区别&#xff0c;…

2007年11月网络工程师考试试题

● 若某计算机系统由两个部件串联构成&#xff0c;其中一个部件的失效率为710&#xff0d;6/小时。若不考虑其他因素的影响&#xff0c;并要求计算机系统的平均故障间隔时间为105小时&#xff0c;则另一个部件的失效率应为 &#xff08;1&#xff09; /小时。 &#xff08;1&am…

Travis CI : 最小的分布式系统(三)

日志的作用有两个&#xff1a;当构建日志的数据块通过消息队列进来时&#xff0c;更新数据库对应行&#xff0c;然后推送它到Pusher用于实时的用户界面更新。 日志块以流的形式在同一个时间从不同的进程中进来&#xff0c;然后被一个进程处理。这个进程每秒最高可处理100个消息…

windows 上rsync客户端使用方法

阅读目录 1.1 获取 windows上实现rsync的软件&#xff08;cwRsync&#xff09;1.2 cwrsync的使用方法1.3 cwrsync的使用回到顶部1.1 获取 windows上实现rsync的软件&#xff08;cwRsync&#xff09; cwRsync是Windows 客户端GUI的一个包含Rsync的包装。您可以使用cwRsync快速远…

机器学习开源项目Top10

整理 | Jane 出品 | AI科技大本营 【导语】又到了我们固定给大家推荐开源项目的时间。本期将为大家推荐 10 个机器学习开源项目&#xff0c;统计了过去一个月中 250 个机器学习开源项目&#xff0c;并从中选取了本期的 Top10。平均 1483 Stars。不知道是不是有你喜欢的欢迎大…

大规模服务设计部署经验谈

本文中提出的最佳实践&#xff0c;来自于作者多年大规模服务设计和部署的经验&#xff0c;为设计、开发对运营友好的服务提供了一系列良好的解决方案。■ 文&#xff0f;James Hamilton 译&#xff0f;赖翥翔1 引言 本文就设计和开发运营友好的服务的话题进行总结…

修改mysql数据库默认编码为utf8

查看当前字符编码&#xff1a; mysql < show variables like character%;为了解决中文乱码问题&#xff0c;修改mysql默认数据库编码为utf8&#xff0c;修改/etc/my.cnf [client]default-character-setutf8[mysql]default-character-setutf8[mysqld]character-set-serverutf…

CSDN创始人蒋涛:AI定义的开发者时代

1月18日&#xff0c;由中国软件行业协会主办的2019中国软件产业年会&#xff0c;在国家会议中心举行。CSDN创始人&董事长蒋涛&#xff0c;在大会上发表了题为《AI定义的开发者时代》的主题演讲。 以下为演讲实录&#xff1a; 我们在PC互联网时代就建立了中国软件开发者社区…

numpy.ndarray的赋值操作

matzeros((3,4)) #生成一个3行4列全部元素为0的矩阵mat[1,:]111 #从第1行第0列开始&#xff0c;一直到最后一列&#xff0c;赋值为1&#xff0c;效果与mat[1,0:3]相同&#xff0c;前置0可以省略&#xff0c;最后的列数可以省略输出&#xff1a;[[ 0. 0. 0. 0.][ 111. 111. 111.…

travis-ci如何配置android

travis-ci如何配置android travis-ci 关于android部分&#xff1a;http://docs.travis-ci.com/user/languages/android/ language: android android:components:- build-tools-19.1.0 # BuildTools version- android-19 # SDK version- sy…

你的微笑,拂过我的心海

??初冬的午后&#xff0c;阳光&#xff0c;懒懒地伸展着腰肢,企业形象宣传片 &#xff0c;偶然从窗帘漏进几缕稀少的斜影。南方的冬天总是姗姗来迟&#xff0c;让人认为&#xff0c;那只不过是秋天残存的脚步&#xff0c;还没来得及捉住&#xff0c;它却已从你的眉间静静地溜…

重读Youtube深度学习推荐系统论文,字字珠玑,惊为神文

作者简介&#xff0c;王喆&#xff0c;硅谷高级机器学习工程师。 本文转载自知乎专栏 https://zhuanlan.zhihu.com/p/52169807 这里是王喆的机器学习笔记&#xff0c;每隔一到两周我会站在算法工程师的角度讲解一些计算广告、推荐系统相关的文章。选择文章必须满足一下三个条件…

Struts的select两种遍历方法

转载于:https://blog.51cto.com/9695005/2050390

nginx http 服务器搭建

下载nginx源码&#xff1a;http://nginx.org/en/download.html 安装&#xff1a; wget http://nginx.org/download/nginx-1.9.3.tar.gz cd nginx-1.9.3 ./configure --prefix/usr/local/nginx发现一个问题&#xff1a; checking for PCRE library ... not found checking for P…

加速电子化报销费控服务,易快报完成1500万美元B轮融资

2019年1月21日&#xff0c;报销费控领头羊品牌——易快报对外宣布完成1500万美元B轮系列融资&#xff0c;本轮融资由美元基金曼图资本领投&#xff0c;DCM、明势、银杏谷等投资机构跟投&#xff0c;冲盈资本为本轮独家财务顾问。国内报销费控SaaS行业是个潜力巨大的增量市场&am…

[转]C# 2.0新特性与C# 3.5新特性

C# 2.0新特性与C# 3.5新特性 一、C# 2.0 新特性&#xff1a; 1、泛型List<MyObject> obj_listnew List();obj_list.Add(new MyObject()); 2、部分类(partial)namespace xxx{public partial class Class1{private string _s1;public string S1{get { return _s1; }set { _…

你需要了解的load和initialize

NSObject类有两种初始化方式load和initialize load (void)load; 复制代码对于加入运行期系统的类及分类&#xff0c;必定会调用此方法&#xff0c;且仅调用一次。 iOS会在应用程序启动的时候调用load方法&#xff0c;在main函数之前调用 执行子类的load方法前&#xff0c;会…

iOS11、iPhone X、Xcode9 适配指南

2017.09.23 不断完善中。。。 2017.10.02 新增 iPhone X 适配官方中文文档 更新iOS11后&#xff0c;发现有些地方需要做适配&#xff0c;整理后按照优先级分为以下三类&#xff1a; 单纯升级iOS11后造成的变化&#xff1b;Xcode9 打包后造成的变化&#xff1b;iPhoneX的适配一、…

Grape和Sinatra结合使用

Grape && Sinatra Grape(https://github.com/intridea/grape) is a REST-like API micro-framework for Ruby Sinatra(http://www.sinatrarb.com/intro.html) is a DSL for quickly creating web applications in Ruby 可见&#xff0c;Grape适合构建纯Api系统&#xf…

公告三大“罪状”,无人驾驶公司Roadstar联合创始人被罢免

&#xff08;从左至右依次是为周光、佟显乔、衡量&#xff09; 整理 | Jane 出品 | AI科技大本营 1 月 21 日&#xff0c;因技术造假等违规行为&#xff0c;国内自动驾驶创业公司 Roadstar &#xff08;深圳星行科技有限公司&#xff09;官方宣布&#xff0c;罢免联合创始人周…