本文是应用TensorFlow进行多任务&多标签分类的一个示例(TensorFlow Multi-Task Multi-Label Example)。
以Mnist数据集为燃料,以简单的MLP为引擎,代码走起。
本文代码见:https://github.com/xylcbd/blogs_code/blob/master/tensorflow-multitask-multilabel/main.py
任务概述
任务共有2个,分别如下;
- 第1个任务是识别数字(即10类分类问题)
- 第2个任务是识别属性(图片中的数字是否是奇数、图片中的数字是否大于5)
可以看到第2个任务是有多个属性,即多标签任务。
- [0, 0]代表[不是奇数, 不大于5]
- [1, 0]代表[是奇数, 不大于5]
- [0, 1]代表[不是奇数, 大于5]
- [1, 1]代表[是奇数, 大于5]
属性可以再增加,不是固定为2个属性。
模型概述
简单的解释一下。
这里模型采用的是一个3层MLP模型,前2层共用,最后一层为2个任务各自所有。
- 对于预测数字的任务,采用的是tf.losses.softmax_cross_entropy
- 对于预测属性的任务,采用的是tf.losses.sigmoid_cross_entropy
示例代码
1 | #coding: utf-8 |
模型效果
1 | [17:46:51.972] Epoch[10/10] Step[100/859] Train Minibatch Loss= 0.1084, Class Accuracy= 0.9791, Attrs Accuracy= 0.9848 |