当前位置: 代码迷 >> python >> numpy 的 argpartition 如何在文档示例中工作?
  详细解决方案

numpy 的 argpartition 如何在文档示例中工作?

热度:92   发布时间:2023-07-14 09:48:58.0

我想了解 numpy 的 argpartition 函数。 我已使的示例尽可能基本。

import numpy as np

x = np.array([3, 4, 2, 1])
print("x: ", x)

a=np.argpartition(x, 3)
print("a: ", a)

print("x[a]:", x[a])

这是输出...

('x: ', array([3, 4, 2, 1]))
('a: ', array([2, 3, 0, 1]))
('x[a]:', array([2, 1, 3, 4]))

在 a=np.argpartition(x, 3) 行中,第 k 个元素不是最后一个元素(数字 1)吗? 如果是数字 1,当 x 排序时 1 不应该成为第一个元素(元素 0)吗?

在 x[a] 中,为什么 2 是 1 的“前面”的第一个元素?

我缺少什么基本的东西?

对所做的更完整的答案是在的文档中,其中一个说:

创建一个数组的副本,其元素以这样一种方式重新排列,即第 k 个位置的元素的值位于它在排序数组中的位置。 所有小于第 k 个元素的元素都移动到这个元素之前,所有相等或更大的元素都移动到它后面。 两个分区中元素的顺序未定义。

因此,对于输入数组3, 4, 2, 1 ,排序后的数组将为1, 2, 3, 4

np.partition([3, 4, 2, 1], 3)将在第三个(即最后一个)元素中具有正确的值(即与排序数组相同)。 第三个元素的正确值是4

让我为k所有值展示这一点,以使其清楚:

  • np.partition([3, 4, 2, 1], 0) - [ 1 , 4, 2, 3]
  • np.partition([3, 4, 2, 1], 1) - [1, 2 , 4, 3]
  • np.partition([3, 4, 2, 1], 2) - [1, 2, 3 , 4]
  • np.partition([3, 4, 2, 1], 3) - [2, 1, 3, 4 ]

换句话说:结果的第 k 个元素与排序数组的第 k 个元素相同。 k 之前的所有元素都小于或等于该元素。 它之后的所有元素都大于或等于它。

argpartition发生同样的情况,除了argpartition返回索引,然后可以将其用于形成相同的结果。

我记得我也很难弄清楚,也许文档写得不好,但这就是它的意思

当你做a=np.argpartition(x, 3)然后 x 以这样的方式排序,只有在第 k 个索引处的元素才会被排序(在我们的例子中 k=3)

因此,当您运行这段代码时,您基本上是在询问排序数组中第三个索引的值是多少。 因此输出是('x[a]:', array([2, 1, 3, 4]))其中只有元素 3 被排序。

正如文档所建议的,所有小于第 k 个元素的数字都在它之前(没有特定的顺序),因此你在 1 之前得到 2,因为它没有特定的顺序。

我希望这可以澄清它,如果您仍然感到困惑,请随时发表评论:)

与@Imtinan 类似,我为此苦苦挣扎。 我发现将函数分解为 arg 和分区很有用。

取以下数组:

array = np.array([9, 2, 7, 4, 6, 3, 8, 1, 5])

the corresponding indices are: [0,1,2,3,4,5,6,7,8] where 8th index = 5 and 0th = 9

如果我们执行np.partition(array, k=5) ,代码将采用第 5 个元素(不是索引),然后将其放入一个新数组中。 然后它将把那些元素 < 5th 元素放在它之前,然后把那些 > 5th 元素放在之后,就像这样:

pseudo output: [lower value elements, 5th element, higher value elements]

如果我们计算这个,我们会得到:

array([3, 5, 1, 4, 2, 6, 8, 7, 9])

这是有道理的,因为原始数组中的第 5 个元素 = 6,[1,2,3,4,5] 都小于 6,而 [7,8,9] 大于 6。注意元素是无序的.

然后np.argpartition()的 arg 部分更进一步,将元素交换为它们在原始数组中的相应索引。 所以如果我们这样做:

np.argpartition(array, 5)我们将得到:

array([5, 8, 7, 3, 1, 4, 6, 2, 0])

从上面看,原始数组有这个结构 [index=value] [0=9, 1=2, 2=7, 3=4, 4=6, 5=3, 6=8, 7=1, 8=5 ]

您可以将索引的值映射到输出,并且满足条件:

argpartition() = partition() ,像这样:

[索引形式] array([5, 8, 7, 3, 1, 4, 6, 2, 0]) 变为

[3, 5, 1, 4, 2, 6, 8, 7, 9]

这与np.partition(array)的输出相同,

array([3, 5, 1, 4, 2, 6, 8, 7, 9])

希望这是有道理的,这是我了解函数的 arg 部分的唯一方法。