如何理解Axis?

9次阅读

共计 2377 个字符,预计需要花费 6 分钟才能阅读完成。

前言
只有光头才能变强。
回顾前面:

从零开始学 TensorFlow【01- 搭建环境、HelloWorld 篇】
什么是 TensorFlow?
TensorFlow 读写数据

不知道大家最开始接触到 axis 的时候是怎么样的,反正我是挺难理解的.. 我们可以发现 TensorFlow 的很多 API 都有 axis 这个参数,如果我们对 axis 不了解,压根不知道 API 是怎么搞的。
一句话总结 axis:axis 可以方便我们将数据进行不同维度的处理。
一、理解 axis
如果你像我一样,发现 API 中有 axis 这个参数,但不知道是什么意思。可能就会搜搜 axis 到底代表的什么意思。于是可能会类似搜到下面的信息:
使用 0 值表示沿着每一列或行标签索引值向下执行方法 (axis= 0 代表往跨行) 使用 1 值表示沿着每一行或者列标签模向执行对应的方法(axis= 1 代表跨列)

但我们又知道,我们的数组不单单只有二维的,还有三维、四维等等。一旦维数超过二维,就无法用简单的行和列来表示了。
所以,可以用我下面的方式进行理解:

axis= 0 将最开外头的括号去除,看成一个整体,在这个整体上进行运算

axis= 1 将第二个括号去除,看成一个整体,在这个整体上进行运算
… 依次类推

话不多说,下面以例子说明~
1.1 二维数组之 concat
首先,我们来看个 concat 的例子,concat 第一个参数接收 val,第二个参数接收的是 axis

def learn_concat():

# 二维数组
t1 = tf.constant([[1, 2, 3], [4, 5, 6]])
t2 = tf.constant([[7, 8, 9], [10, 11, 12]])

with tf.Session() as sess:

# 二维数组针对 axis 为 0 和 1 的情况
print(sess.run(tf.concat([t1, t2], 0)))
print(sess.run(tf.concat([t1, t2], 1)))

ok,下面以图示的方式来说明。现在我们有两个数组,分别是 t1 和 t2:

首先,我们先看 axis= 0 的情况,也就是 tf.concat([t1, t2], 0)。从上面的描述,我们知道,先把第一个括号去除,然后将其子内容看成一个整体,在这个整体下进行想对应的运算(这里我们就是 concat)。

所以最终的结果是:

[
[1 2 3],
[4 5 6],
[7 8 9],
[10 11 12]
]

接着,我们再看 axis= 1 的情况,也就是 tf.concat([t1, t2], 1)。从上面的描述,我们知道,先把第二个括号去除,然后将其子内容看成一个整体,在这个整体下进行想对应的运算(这里我们就是 concat)。

所以最终的结果是:

[
[1, 2, 3, 7, 8, 9]
[4, 5, 6, 10, 11, 12]
]

1.2 三维数组之 concat
接下来我们看一下三维的情况

def learn_concat():

# 三维数组
t3 = tf.constant([[[1, 2], [2, 3]], [[4, 4], [5, 3]]])
t4 = tf.constant([[[7, 4], [8, 4]], [[2, 10], [15, 11]]])

with tf.Session() as sess:

# 三维数组针对 axis 为 0 和 1 和 -1 的情况
print(sess.run(tf.concat([t3, t4], 0)))
print(sess.run(tf.concat([t3, t4], 1)))
print(sess.run(tf.concat([t3, t4], -1)))

ok,下面也以图示的方式来说明。现在我们有两个数组,分别是 t3 和 t4:

首先,我们先看 axis= 0 的情况,也就是 tf.concat([t3, t4], 0)。从上面的描述,我们知道,先把第一个括号去除,然后将其子内容看成一个整体,在这个整体下进行想对应的运算(这里我们就是 concat)。

所以最终的结果是:

[
[
[1 2]
[2 3]
]
[
[4 4]
[5 3]
]
[
[7 4]
[8 4]
]
[
[2 10]
[15 11]
]
]

接着,我们再看 axis= 1 的情况,也就是 tf.concat([t3, t4], 1)。从上面的描述,我们知道,先把第二个括号去除,然后将其子内容看成一个整体,在这个整体下进行想对应的运算(这里我们就是 concat)。

所以最终的结果是:

[
[
[1 2]
[2 3]
[7 4]
[8 4]
]
[
[4 4]
[5 3]
[2 10]
[15 11]
]
]
最后,我们来看一下 axis=- 1 这种情况,在文档也有相关的介绍:
As in Python, the axis could also be negative numbers. Negative axis are interpreted as counting from the end of the rank, i.e.,axis + rank(values)-th dimension
所以,对于我们三维的数组而言,那 axis=- 1 实际上就是 axis=2,下面我们再来看一下这种情况:

最终的结果是:

[
[
[1 2 7 4]
[2 3 8 4]
]
[
[4 4 2 10]
[5 3 15 11]
]
]
除了 concat 以外,其实很多函数都用到了 axis 这个参数,再举个例子:

>>> item = np.array([[1,4,8],[2,3,5],[2,5,1],[1,10,7]])
>>> item
array([[1, 4, 8],
[2, 3, 5],
[2, 5, 1],
[1, 10, 7]])

>>> item.sum(axis = 1)
array([13, 10, 8, 18])

>>> item.sum(axis = 0)
array([6, 22, 21])

参考资料:

有关 axis/axes 的理解
https://zhuanlan.zhihu.com/p/25761406

最后
下一篇是 TensorBoard~
乐于输出干货的 Java 技术公众号:Java3y。公众号内有 200 多篇原创技术文章、海量视频资源、精美脑图,不妨来关注一下!

觉得我的文章写得不错,不妨点一下赞!

正文完
 0