1. OneHotCategorical
torch.distributions.one_hot_categorical.OneHotCategorical(probs=None, logits=None, validate_args=None)
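Creates a one-hot categorical distribution parameterized by probs (or by logits instead; exactly one of the two must be given). Each sample is a one-hot vector whose length equals the number of categories, with a 1 at the index of the sampled category.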
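The sketch below relates OneHotCategorical to the index-based Categorical distribution. The probabilities [0.2, 0.3, 0.5] are made up for illustration. It checks two things: passing logits=log(probs) yields the same probabilities as passing probs, and log_prob of a one-hot vector equals Categorical.log_prob of the corresponding index.

```python
import torch
from torch.distributions import Categorical, OneHotCategorical

# Illustrative probabilities (not from the original note)
probs = torch.tensor([0.2, 0.3, 0.5])

oh = OneHotCategorical(probs=probs)
cat = Categorical(probs=probs)

# Same distribution given via logits: softmax(log(p)) recovers p
oh_logits = OneHotCategorical(logits=torch.log(probs))

# OneHotCategorical.log_prob takes a one-hot vector,
# Categorical.log_prob takes the integer index of the category
x = torch.tensor([0., 0., 1.])
lp_onehot = oh.log_prob(x)
lp_index = cat.log_prob(torch.tensor(2))
```

Both log-probabilities should equal log(0.5). The event shape of oh is (3,), so a single sample is a length-3 vector rather than a scalar index.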
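import torch
from torch.distributions.one_hot_categorical import OneHotCategorical
m = OneHotCategorical(torch.tensor([0.1, 0.0, 0.9, 0.0]))
m.sample()  # index 2 with probability 0.9, index 0 with probability 0.1, never 1 or 3
# e.g. tensor([0., 0., 1., 0.])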
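To check that samples follow the given probabilities, you can draw many samples at once and average them. Averaging one-hot vectors gives the empirical frequency of each category. The sketch below does this; the sample count of 10000 and the seed are arbitrary choices.

```python
import torch
from torch.distributions.one_hot_categorical import OneHotCategorical

torch.manual_seed(0)  # fixed seed so the run is reproducible
m = OneHotCategorical(torch.tensor([0.1, 0.0, 0.9, 0.0]))

# sample_shape (10000,) plus event_shape (4,) gives a (10000, 4) tensor
samples = m.sample((10000,))

# Column means = fraction of samples that picked each category
freq = samples.mean(dim=0)
print(freq)  # expected to be close to m.probs
```

Categories with probability 0.0 (indices 1 and 3) should never appear, so their frequencies are exactly 0.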
References:
- PyTorch distributions documentation