
PyTorch Learning Series: The scatter() Function Explained, with One-Hot Encoding

Published: 2024-01-19 11:15:50.0

torch.Tensor.scatter_

scatter() and scatter_() do the same thing. The only difference is that scatter() leaves the original Tensor unchanged and returns a new one, while scatter_() modifies the original Tensor in place.

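torch.Tensor.scatter_() is the reverse of torch.gather(), and the two can be seen as a pair of sibling functions. scatter_ writes values into a tensor at the positions given by index, while gather reads values out of a tensor at those positions. This is why scatter_ is used to build one-hot encodings, and gather is used on the reading side of such an encoding.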
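To make the "sibling" relationship concrete, here is a small sketch: scatter_ places the values of src into a zero tensor, and gather with the same index and dim reads exactly those values back. The tensors are made up for illustration, and the index deliberately uses distinct columns within each row so that no value is overwritten.

```python
import torch

src = torch.tensor([[10., 20.], [30., 40.]])
index = torch.tensor([[3, 0], [1, 2]])  # distinct columns within each row

# scatter_ (dim=1): out[i][index[i][j]] = src[i][j]
out = torch.zeros(2, 4).scatter_(1, index, src)
print(out)
# tensor([[20.,  0.,  0., 10.],
#         [ 0., 30., 40.,  0.]])

# gather (dim=1): back[i][j] = out[i][index[i][j]]
back = out.gather(1, index)
print(torch.equal(back, src))  # True
```

If two positions in a row of index pointed at the same column, scatter_ would keep only one of the written values, and gather could not recover src.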

In PyTorch, a trailing underscore on a function name generally means the function modifies the original Tensor in place.
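A quick way to see the underscore convention is to call both variants on the same tensor. In this sketch, scatter() returns a new tensor and leaves base untouched, while scatter_() changes base and returns that same object.

```python
import torch

index = torch.tensor([[0, 1, 2]])
base = torch.zeros(3, 3)

new = base.scatter(0, index, 1.0)   # out-of-place: returns a new tensor
print(base.sum().item())            # 0.0, base is unchanged

ret = base.scatter_(0, index, 1.0)  # in-place: modifies base and returns it
print(base.sum().item())            # 3.0
print(ret is base)                  # True
```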

scatter_(dim, index, src) → Tensor

Parameters:

  • dim: the axis along which to index
  • index: the indices of the elements to scatter
  • src: the source element(s) to scatter; can be a scalar or a tensor

The word "scatter" can be read as placing elements into a tensor, or modifying elements of it.

Put simply, one tensor (src) is used to modify another tensor (self). Which elements of self are modified, and which element of src is written into each of them, is determined by dim and index.
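To see how dim changes the direction of the write, the sketch below applies the same index and src with dim=0 and with dim=1. With dim=0 the index values choose the row, and with dim=1 they choose the column. The values are arbitrary, picked for illustration.

```python
import torch

index = torch.tensor([[1, 2, 0]])
src = torch.tensor([[1., 2., 3.]])

# dim=0: index picks the ROW, the column stays j -> self[index[0][j]][j] = src[0][j]
rows = torch.zeros(3, 3).scatter_(0, index, src)
print(rows)
# tensor([[0., 0., 3.],
#         [1., 0., 0.],
#         [0., 2., 0.]])

# dim=1: index picks the COLUMN, the row stays 0 -> self[0][index[0][j]] = src[0][j]
cols = torch.zeros(3, 3).scatter_(1, index, src)
print(cols)
# tensor([[3., 1., 2.],
#         [0., 0., 0.],
#         [0., 0., 0.]])
```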

The official documentation spells out the operation for a 3-D tensor as follows:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
import torch
x = torch.rand(2, 5)  # random values; the output shown is one sample run
# tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
#         [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
# tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
#         [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
#         [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])
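The three documented formulas can be checked against scatter_ with a plain Python loop. scatter_reference is a hypothetical helper written only for this illustration and is not part of PyTorch. It replaces the coordinate along dim with the index value and copies the source element there. The index is built with argsort so that each slice along dim is a permutation. This guarantees that no two source elements target the same position, which keeps the comparison deterministic.

```python
import itertools
import torch

def scatter_reference(self_t, dim, index, src):
    """Hypothetical loop version of the documented 3-D scatter rule (illustration only)."""
    out = self_t.clone()
    I, J, K = index.shape
    for i, j, k in itertools.product(range(I), range(J), range(K)):
        pos = [i, j, k]
        pos[dim] = index[i, j, k].item()  # replace the coordinate along dim with the index value
        out[tuple(pos)] = src[i, j, k]
    return out

torch.manual_seed(0)
src = torch.rand(2, 3, 4)
results = {}
for dim in range(3):
    # argsort of random numbers is a permutation along `dim`, so every target is unique
    index = torch.rand(2, 3, 4).argsort(dim=dim)
    expected = torch.zeros(2, 3, 4).scatter_(dim, index, src)
    got = scatter_reference(torch.zeros(2, 3, 4), dim, index, src)
    results[dim] = torch.equal(got, expected)
print(results)  # {0: True, 1: True, 2: True}
```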

Concretely, the index here is torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), a 2-D tensor. Let's walk through it step by step.

Since this is a 2-D tensor, the first step is self[index[0][0]][0]. index[0][0] is 0, so this executes self[0][0] = x[0][0] = 0.1940.

self[index[i][j]][j]=src[i][j]

 

Another example: for self[index[1][0]][0], index[1][0] is 2, so this executes self[2][0] = x[1][0] = 0.2078.

Step by step (dim = 0, so index[i,j] picks the row and j stays the column): index[0,0]=0 → self[0,0] ← src[0,0] = 0.1940

index[0,1]=1 → self[1,1] ← src[0,1] = 0.3340

index[0,2]=2 → self[2,2] ← src[0,2] = 0.8184
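The same trace can be printed for every element of index, not just the first three. This is a small sketch that only formats strings. It does not call scatter_, and it mirrors the dim = 0 rule self[index[i][j]][j] = src[i][j].

```python
import torch

index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
steps = []
for i in range(index.shape[0]):
    for j in range(index.shape[1]):
        row = index[i, j].item()  # dim=0: the index value chooses the row, j keeps the column
        steps.append(f"index[{i},{j}]={row} -> self[{row},{j}] <- src[{i},{j}]")
print("\n".join(steps))
```

The sixth line it prints, `index[1,0]=2 -> self[2,0] <- src[1,0]`, is the self[2][0] = x[1][0] step described above.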

 

Example:

import torch
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)
# tensor([[7., 7., 7., 7., 7.],
#         [0., 7., 0., 7., 0.],
#         [7., 0., 7., 0., 7.]])

Step by step (src is the scalar 7, so every selected position receives 7): index[0,0]=0 → self[0,0] ← 7

index[0,1]=1 → self[1,1] ← 7

index[0,2]=2 → self[2,2] ← 7
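One way to think about a scalar src is as a tensor of the same shape as index, filled with that scalar. This sketch checks that reading for the example above. The explicit torch.full tensor is only an illustration of the equivalence.

```python
import torch

index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])

with_scalar = torch.zeros(3, 5).scatter_(0, index, 7)
# the same write with an explicit src tensor of 7s, shaped like index
with_tensor = torch.zeros(3, 5).scatter_(0, index, torch.full((2, 5), 7.0))

print(torch.equal(with_scalar, with_tensor))  # True
```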

scatter() is commonly used to one-hot encode labels, which is a typical example of using a scalar to modify a tensor.

Producing one-hot encoded vectors

Example:

import torch

class_num = 10
batch_size = 4
# random integer labels in [0, class_num), shape (batch_size, 1)
label = torch.randint(0, class_num, (batch_size, 1))
# one sample run:
# tensor([[6],
#         [0],
#         [3],
#         [2]])
# dim=1: row i gets a 1 in column label[i]
torch.zeros(batch_size, class_num).scatter_(1, label, 1)
# tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
#         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
#         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])

# labels 0..4 as a column vector of shape (5, 1)
indices = torch.arange(5).view(5, 1)
result = torch.zeros(5, 5)
result.scatter_(1, indices, 1)  # result is now the 5x5 identity matrix
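The one-hot pattern above can be wrapped in a small function and checked against PyTorch's built-in torch.nn.functional.one_hot. one_hot_scatter is a hypothetical helper name used only for this sketch. Note that F.one_hot returns an int64 tensor, so it is converted to float before the comparison.

```python
import torch
import torch.nn.functional as F

def one_hot_scatter(labels, num_classes):
    """Hypothetical helper: one-hot encode a 1-D LongTensor of labels with scatter_."""
    out = torch.zeros(labels.numel(), num_classes)
    # unsqueeze(1) turns shape (N,) into (N, 1) so index has the same ndim as out
    return out.scatter_(1, labels.unsqueeze(1), 1.0)

labels = torch.tensor([6, 0, 3, 2])
mine = one_hot_scatter(labels, 10)
builtin = F.one_hot(labels, num_classes=10).float()  # F.one_hot returns int64
print(torch.equal(mine, builtin))  # True
```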

 

When no src tensor is given, every position selected by index is filled with the scalar value.

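Note that index must have the same number of dimensions as result. For every dimension d other than dim, index.size(d) must not exceed result.size(d), and every value in index must lie in [0, result.size(dim)). Otherwise an error is raised.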
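These constraints can be probed directly. In this sketch, an index with fewer columns than result is accepted, while an index with too many rows, or an index value equal to the size of dim, is rejected. Current PyTorch reports both errors as RuntimeError. IndexError is also caught here as a precaution, in case other versions use a different exception type.

```python
import torch

result = torch.zeros(3, 5)

# index may be narrower than result along dim: only 2 columns per row here
ok = result.clone().scatter_(1, torch.tensor([[4, 0], [1, 1], [2, 3]]), 1.0)
print(ok)

errors = []
# index larger than result in a dimension other than dim (4 rows > 3 rows)
try:
    result.clone().scatter_(1, torch.zeros(4, 5, dtype=torch.long), 1.0)
except (RuntimeError, IndexError):
    errors.append("shape")
# index value out of range along dim (5 is not < 5)
try:
    result.clone().scatter_(1, torch.tensor([[5]]), 1.0)
except (RuntimeError, IndexError):
    errors.append("range")
print(errors)  # ['shape', 'range']
```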

import torch
result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 3, 1, 2], [2, 1, 3, 1, 4]])
result.scatter_(1, indices, value=1)  # dim=1: result[i][indices[i][j]] = 1

Output:

tensor([[1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1.]])

References

https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_

https://www.cnblogs.com/dogecheng/p/11938009.html
