Use of dim=0/1 in PyTorch and nn.Softmax?
When using nn.Softmax(), we pass dim=1 or dim=0. Intuitively, dim=0 should mean "along a row", but it seems to operate along the column instead. Is this true? >>> x = torch.tensor([[1,2
Solution 1:
Indeed, in the 2D case: row refers to axis=0, while column refers to axis=1.

The dim option specifies along which dimension the softmax is applied, i.e. summing over that same axis afterwards yields 1s:
>>> x = torch.arange(1, 7, dtype=float).reshape(2, 3)
>>> x
tensor([[1., 2., 3.],
        [4., 5., 6.]], dtype=torch.float64)
On axis=0:
>>> F.softmax(x, dim=0).sum(0)
tensor([1.0000, 1.0000, 1.0000], dtype=torch.float64)
On axis=1:
>>> F.softmax(x, dim=1).sum(1)
tensor([1.0000, 1.0000], dtype=torch.float64)
This is the expected behavior for torch.nn.functional.softmax:

> [...] Parameters: dim (int) – A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
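To make the "slices sum to 1" point concrete, here is a minimal sketch comparing the softmax along dim=0 against the same computation written out by hand (exponentiate, then normalize over that axis); it also shows that the nn.Softmax module gives the same result as the functional form:

```python
import torch
import torch.nn.functional as F

x = torch.arange(1., 7.).reshape(2, 3)

# dim=0 normalizes each column: for every column j,
# softmax(x)[i, j] = exp(x[i, j]) / sum_k exp(x[k, j])
manual = x.exp() / x.exp().sum(dim=0, keepdim=True)
assert torch.allclose(manual, F.softmax(x, dim=0))

# Each column (slice along dim=0) now sums to 1
assert torch.allclose(manual.sum(dim=0), torch.ones(3))

# nn.Softmax is the module form of the same operation
softmax0 = torch.nn.Softmax(dim=0)
assert torch.allclose(softmax0(x), F.softmax(x, dim=0))
```

The same holds for dim=1 with the normalization taken over each row instead.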