到上一步为止,基本的步骤已经完成了 ,但既然开头说了,要达到能实际使用的程度,那就需要再进行一些完善了。先得试试实时性怎么样,因此采用摄像头实时获取图片,实时辨别来试试效果。这部分和机器学习关系不大,因此略过详细过程,核心方法为- (void)captureOutput:(AVCaptureOutput *)output didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer fromConnection:(AVCaptureConnection *)connection
,
使用的queue要注意下,更新的时候也要注意切回主线程
在这个回调里,先把图片摆正,这样能提高准确率。[connection setVideoOrientation:AVCaptureVideoOrientationPortrait];
然后根据sampleBuffer
获取图片即可。
实际效果下图:
图不动的话,戳这里。 这里在末尾也演示了前面提到的other
这个分类,其实这个分类存在的目的,就是为了增加分类器的健壮性,可参加这个issue。如果只训练两个分类A和B,那么分类器A和B的概率相加为1,假设新物体非常不像A,那么有可能显示的B的概率为1,造成误判,other
这个分类的意义就在于摊平这里的概率,当然对于other
里训练的图片选择,感觉是个大学问,目前我只是随意的放了些非目标分类的图片。
做了那么多的步骤,直接看图就知道效果了:
图不动的话,戳这里。
总体来说,目标物对的稍微准点,95+%的识别率还是有的,超过了我的预期,可应用到实际中。
前文提到过,导出的包为94.2MB,这对于一个iOS App来说,实在是有点太大了。贴心的Apple当然也给了解决方法,那就是替换卷积神经网络CNN,CNN的主要目的是 提取图片的特征值。替换的地方在turi_train.py
的第三步:
# 3. 生成模型
model = tc.image_classifier.create(train_data, target='name')
这里还有一个参数,model
,改成
model = tc.image_classifier.create(train_data, target='name',model='squeezenet_v1.1')
也就是把model
的CNN指定为squeezenet_v1.1
(默认的为resnet-50)。当然这里还可以设置其他的参数,比如最大迭代次数等。这样导出的mlmodel
一下子就变成了5MB左右,小了非常的多!当然,这也牺牲了一定的精度。具体对比,Apple已经列了对比:
而Apple官网提供的“从1000种类别的对象中检测出图像中的主体”的训练集当中,从大到小依次为
至于如何在精度和包大小取舍就看自己的选择了。
一个包离线打在项目里,既更新不了,又导致每个用户的包都变大了,这显然不是一个好的实践。Apple提供了一个新的API,+ (NSURL *)compileModelAtURL:(NSURL *)modelURL error:(NSError * _Nullable *)error;
,使用方法也很简单,下载数据,放到沙盒里,然后compile
即可。需要注意的是,这个方法较为耗时,不要放在主线程。
这样包大小的问题也算一定程度上解决了。
读文章最怕介绍的都是各种优点的文章,显然,作为这么个工具,还是需要提出我在这整个过程中遇到的问题:
Turi Create
这个工具能做到的远不止图像分类,还有目标追踪,推荐系统,相似图片,文字识别等等。其中目标跟踪跟本实践较为接近,这个可以继续叠加训练数据的维度来实现。需要增加的工作为,需要标记每一张训练图的目标物方框坐标,数据格式为:
[{'coordinates': {'height': 104, 'width': 110, 'x': 115, 'y': 216}, 'label': 'ball'}, {'coordinates': {'height': 106, 'width': 110, 'x': 188, 'y': 254}, 'label': 'ball'}, {'coordinates': {'height': 164, 'width': 131, 'x': 374, 'y': 169}, 'label': 'cup'}]
其他步骤跟前文提到的基本一致。具体可以大家自己尝试。介绍在官网github上。
本文中通过TuriCreate
生成的数据为mlmodel
,仅供iOS使用,可通过开源工具MMdnn
来转换为Caffe, Keras, MXNet, Tensorflow, CNTK, PyTorch Onnx这些模型,从而供其他方来使用。
@张云龙:
也可以只用训练素材图片,然后用 tensorflow-for-poets-2 来训练,得到 retrained_graph.pb 和 retrained_labels.txt 集成到Android中。
执行脚本
本文是一篇应用型的文章,基本没有介绍真正的机器学习的知识。这部分还是很有必要深入了解下的,这两个感觉介绍的不错,可推荐:
本文来自网易实践者社区,经作者陈蒙奇授权发布。