Constructor.
def __init__(self, input_dim, num_heads, dim_feedforward, dropout, which_linear):
    """
    Inputs:
        input_dim - Dimensionality of the input
        num_heads - Number of heads to use in the attention block
        dim_feedforward - Dimensionality of the hidden layer in the MLP
        dropout - Dropout probability to use in the dropout layers
        which_linear - Linear layer constructor used for all projections
    """
    super().__init__()

    self.which_linear = which_linear

    # Self-attention sub-layer
    self.self_attn = MultiheadAttention(
        input_dim, input_dim, num_heads, which_linear
    )

    # Two-layer MLP (position-wise feed-forward network)
    self.linear_net = nn.Sequential(
        self.which_linear(input_dim, dim_feedforward),
        nn.Dropout(dropout),
        nn.ReLU(inplace=True),
        self.which_linear(dim_feedforward, input_dim),
    )

    # Layer norms and dropout applied around the two sub-layers
    self.norm1 = nn.LayerNorm(input_dim)
    self.norm2 = nn.LayerNorm(input_dim)
    self.dropout = nn.Dropout(dropout)
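For reference, a minimal sketch of how these sub-modules would typically be combined in the block's forward pass. This assumes the standard post-norm transformer layout and that self_attn takes the input tensor directly; the actual forward method and MultiheadAttention call signature are defined elsewhere in the code, not in this section.

def forward(self, x):
    # Attention sub-layer with residual connection and layer norm (assumed post-norm)
    attn_out = self.self_attn(x)
    x = x + self.dropout(attn_out)
    x = self.norm1(x)

    # MLP sub-layer with residual connection and layer norm
    linear_out = self.linear_net(x)
    x = x + self.dropout(linear_out)
    x = self.norm2(x)
    return x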