当前位置: 代码迷 >> 综合 >> Pytorch probability distributions
  详细解决方案

Pytorch probability distributions

热度:52   发布时间:2023-12-19 03:20:17.0

1.OneHotCategorical

torch.distributions.one_hot_categorical.OneHotCategorical(probs=None, logits=None, validate_args=None)

根据给定的概率probs, 创建一个 one-hot 的类别分布.

m = OneHotCategorical(torch.tensor([ 0.1, 0.0, 0.9, 0.0 ]))
m.sample()  # equal probability of 0, 1, 2, 3
#tensor([0., 0., 1., 0.])

参考:

  1. pytorch distributions