
Use of dim=0/1 in PyTorch and nn.Softmax?

When using nn.Softmax(), we pass dim=1 or dim=0. Intuitively, dim=0 should mean "row", but it seems to mean "along the column". Is this true?

>>> x = torch.tensor([[1,2

Solution 1:

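Indeed. In the 2D case, axis=0 indexes the rows and axis=1 indexes the columns. So dim=0 applies the softmax across the rows, i.e. down each column, while dim=1 applies it along each row.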
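As a worked formula (the notation is ours, not from the original answer): for a 2D tensor x with entries x_{ij}, where i is the row index (dim 0) and j the column index (dim 1), the normalizing sum runs over whichever index `dim` names:

```latex
\mathrm{softmax}(x,\ \mathrm{dim}=0)_{ij} = \frac{e^{x_{ij}}}{\sum_{k} e^{x_{kj}}}
\qquad
\mathrm{softmax}(x,\ \mathrm{dim}=1)_{ij} = \frac{e^{x_{ij}}}{\sum_{k} e^{x_{ik}}}
```

With dim=0 the sum runs over the row index, so each column j is normalized on its own; with dim=1 the sum runs over the column index, so each row i is.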

The dim argument specifies the dimension along which the softmax is applied, i.e. summing the result back over that same axis yields 1s:

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.arange(1, 7, dtype=float).reshape(2, 3)
>>> x
tensor([[1., 2., 3.],
        [4., 5., 6.]], dtype=torch.float64)
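A quick check of which axis is which, using the same x (redefined here so the snippet runs on its own): indexing with x[0] returns the first row, and reducing over dim=0 collapses the rows, leaving one value per column. This is the same convention softmax follows.

```python
import torch

x = torch.arange(1, 7, dtype=torch.float64).reshape(2, 3)

first_row = x[0]      # values 1, 2, 3: an index on dim 0 selects a row
first_col = x[:, 0]   # values 1, 4: an index on dim 1 selects a column

# Reducing over dim=0 walks down each column and collapses the rows,
# so the result has one entry per column (shape [3]): 5, 7, 9.
col_sums = x.sum(dim=0)

# Reducing over dim=1 walks along each row (shape [2]): 6, 15.
row_sums = x.sum(dim=1)

print(col_sums.shape, row_sums.shape)
```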

On axis=0:

>>> F.softmax(x, dim=0).sum(0)
tensor([1.0000, 1.0000, 1.0000], dtype=torch.float64)
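To confirm that dim=0 normalizes each column independently, a minimal sketch compares F.softmax(x, dim=0) with a hand-written version that exponentiates and divides by the sum along the same dim. The helper name softmax_manual is hypothetical, for illustration only; subtracting the max first is a common numerical-stability trick and does not change the result.

```python
import torch
import torch.nn.functional as F

x = torch.arange(1, 7, dtype=torch.float64).reshape(2, 3)

def softmax_manual(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical helper: shift by the max along `dim` for stability,
    # exponentiate, then divide by the sum along that same dim.
    shifted = t - t.max(dim=dim, keepdim=True).values
    e = shifted.exp()
    return e / e.sum(dim=dim, keepdim=True)

by_column = F.softmax(x, dim=0)
print(torch.allclose(by_column, softmax_manual(x, dim=0)))  # True

# Column 0 is simply the softmax of that column's values [1., 4.],
# independent of the other columns.
print(torch.allclose(by_column[:, 0], F.softmax(x[:, 0], dim=0)))  # True
```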

On axis=1:

>>> F.softmax(x, dim=1).sum(1)
tensor([1.0000, 1.0000], dtype=torch.float64)
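The question mentions the module form: torch.nn.Softmax takes the same dim argument in its constructor and should agree with the functional call. For a 2D tensor, dim=-1 refers to the last axis and is therefore equivalent to dim=1. A small sketch checking both; it also shows that each row is normalized on its own, since row 1 is row 0 shifted by 3 and softmax is invariant to such a shift:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.arange(1, 7, dtype=torch.float64).reshape(2, 3)

softmax_rows = nn.Softmax(dim=1)  # module form: dim is fixed at construction
by_row = softmax_rows(x)

# Module and functional forms compute the same thing.
print(torch.allclose(by_row, F.softmax(x, dim=1)))   # True

# For a 2D tensor, dim=-1 is the last axis, i.e. the same as dim=1.
print(torch.allclose(by_row, F.softmax(x, dim=-1)))  # True

# Rows [1, 2, 3] and [4, 5, 6] differ by a constant, so each row,
# normalized independently, gets the same probabilities.
print(torch.allclose(by_row[0], by_row[1]))          # True
```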

This is the expected behavior, as documented for torch.nn.functional.softmax:

[...] Parameters:

  • dim (int) – A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
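The documented rule is not specific to 2D. A minimal sketch on a 3D tensor of random values (shape and seed chosen arbitrarily): for each dim, summing the softmax output over that same dim gives a tensor of ones whose shape is the input shape with that axis removed.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 4, dtype=torch.float64)

for d in range(x.dim()):
    sums = F.softmax(x, dim=d).sum(dim=d)
    # The summed-over axis disappears; every remaining entry is ~1.
    print(d, tuple(sums.shape), torch.allclose(sums, torch.ones_like(sums)))
```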
