# Source code for pytorch_extra_mhirano.nn.sum

import torch
import torch.nn as nn

__all__ = ["SumLayer"]


[docs]class SumLayer(nn.Module):
[docs] def __init__(self, dim: int) -> None: super(SumLayer, self).__init__() self.dim = dim
def forward(self, inputs: torch.Tensor) -> torch.Tensor: return torch.sum(inputs, self.dim)