本文分了两部分,第一部分讲讲axis参数的理解,第二部分从np.stack函数的应用来反向理解该函数究竟干了些什么。
<>1.np.stack()中axis参数的深入理解
看了一下大家关于np.stack()的理解,我感觉自己还是一知半解,有点蒙。自己又想把这个函数搞明白
,于是花了一点时间终于对这个函数有了自己的理解,决定把自己的想法写下来与大家分享,希望对大家有帮助。
stack为堆叠的意思,这个函数主要有两个参数,第一个是需要堆叠的多个数组,采用列表的形式输入,例如:np.stack([arrays1,array2,array3],axis=0)。第二个参数是axis,这个参数表示从哪一个维度进行堆叠以及堆叠的内容,
这个维度是相对于堆叠的数组来说的。整个函数的输出为一个新数组。
下面以三维输出为例来谈谈我的理解。
先定义3个3*4的数组用来进行堆叠,注意进行堆叠的数组形式必须一致,在这里全为3×4:
a=np.array([i for i in range(12)]).reshape(3,4) b=np.array([i for i in range(12
,24)]).reshape(3,4) c=np.array([i for i in range(24,36)]).reshape(3,4)
其结果为:
a= [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11]] b= [[12 13 14 15] [16 17 18 19] [20 21
22 23]] c= [[24 25 26 27] [28 29 30 31] [32 33 34 35]]
来看看axis=0时,它是如何进行堆叠的:
new_array=np.stack([a,b,c],axis=0) print(new_array) #输出结果为: [[[ 0 1 2 3] [ 4 5
6 7] [ 8 9 10 11]] [[12 13 14 15] [16 17 18 19] [20 21 22 23]] [[24 25 26 27] [
28 29 30 31] [32 33 34 35]]]
axis为0,表示它堆叠方向为第0维,堆叠的内容为数组第0维的数据。前面说了第0维是相对于堆叠的数组而言的,而这里数组的第0维其实就是整个3×4的数组(其中第1维为行,第2维为某一行中的一个值,这里有一个层层深入的感觉),所以就是以整个3×4的数组为堆叠内容在第0维上进行堆叠,等到的结果就是一个3×3×4的新数组。再通俗一点,就是将a,b,c分别作为堆叠内容进行堆叠得到3×3×4的输出。
再来看看axis=1的时候:
new_array=np.stack([a,b,c],axis=1) print(new_array) #输出结果为: [[[ 0 1 2 3] [12 13
14 15] [24 25 26 27]] [[ 4 5 6 7] [16 17 18 19] [28 29 30 31]] [[ 8 9 10 11] [20
21 22 23] [32 33 34 35]]]
和刚才的解释一样,axis为1表示堆叠的方向为3×4数组的第1维(行),堆叠内容也为3×4数组的第1维的数据。而3×4的数组的第1维就是它的行,以数组a为例,它的堆叠数据分别是[0
1 2 3],[ 4 5 6 7],[ 8 9 10 11]。所以a,b,c三个数组在第1维上堆叠后的结果就是上面的输出结果。
当axis=2时,表示堆叠内容是3×4数组的第二维的数据(数组中某一行的某个值),堆叠方向为第二维(也就是先补全某一行)。对于数组a,第2维的第一个值为0,b为12,c为24,所这三个值组成堆叠后的新数组的一行,以此类推,最终可以得到下面的输出结果。
new_array=np.stack([a,b,c],axis=2) print(new_array) [[[ 0 12 24] [ 1 13 25] [ 2
14 26] [ 3 15 27]] [[ 4 16 28] [ 5 17 29] [ 6 18 30] [ 7 19 31]] [[ 8 20 32] [ 9
21 33] [10 22 34] [11 23 35]]]
<>2.np.stack()的应用
现在假设我有一个3×3×4的数组a:
[[[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11]] [[12 13 14 15] [16 17 18 19] [20 21 22 23
]] [[24 25 26 27] [28 29 30 31] [32 33 34 35]]]
另外假设:
x=a[0,:,:] y=a[1,:,:] z=a[2,:,:]
如果我们想用x,y,z来重新拼成数组a,就可以利用stack函数来实现:
s=np.stack([x,y,z],axis=0) print(s) #结果等于数组a [[[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10
11]] [[12 13 14 15] [16 17 18 19] [20 21 22 23]] [[24 25 26 27] [28 29 30 31] [
32 33 34 35]]]
再假设:
x=a[:,:,0] y=a[:,:,1] z=a[:,:,2] r=a[:,:,3]
同样的,这里的数组a被分成了四个低维的元数组x,y,z,r。如果想用这四个元数组拼成原数组a我们可以这样做:
s=np.stack([x,y,z,r],axis=2) print(s) [[[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11]] [[
12 13 14 15] [16 17 18 19] [20 21 22 23]] [[24 25 26 27] [28 29 30 31] [32 33 34
35]]]
希望这个思路能有助于大家更好理解。