Constructor.
def __init__(self, input_dim, num_heads, dim_feedforward, dropout, which_linear):
    """
    Inputs:
        input_dim - Dimensionality of the input
        num_heads - Number of heads to use in the attention block
        dim_feedforward - Dimensionality of the hidden layer in the MLP
        dropout - Dropout probability to use in the dropout layers
        which_linear - Constructor for the linear layers (e.g. nn.Linear)
    """
    super().__init__()

    self.which_linear = which_linear

    # Self-attention block
    self.self_attn = MultiheadAttention(
        input_dim, input_dim, num_heads, which_linear
    )

    # Two-layer position-wise MLP
    self.linear_net = nn.Sequential(
        self.which_linear(input_dim, dim_feedforward),
        nn.Dropout(dropout),
        nn.ReLU(inplace=True),
        self.which_linear(dim_feedforward, input_dim),
    )

    # Normalization and dropout applied around the two sub-blocks
    self.norm1 = nn.LayerNorm(input_dim)
    self.norm2 = nn.LayerNorm(input_dim)
    self.dropout = nn.Dropout(dropout)
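For context, the sketch below shows one way these submodules are typically combined in the block's forward pass. This is an assumption for illustration, not the original forward method: it uses post-norm residual connections and assumes self_attn accepts an optional mask and returns only the attended values.

def forward(self, x, mask=None):
    # Self-attention sub-block: residual connection, dropout, then LayerNorm.
    attn_out = self.self_attn(x, mask=mask)  # assumed signature and return value
    x = self.norm1(x + self.dropout(attn_out))

    # Position-wise MLP sub-block: residual connection, dropout, then LayerNorm.
    mlp_out = self.linear_net(x)
    x = self.norm2(x + self.dropout(mlp_out))
    return x

In this layout, dropout is shared between the two residual branches, and which_linear only affects the projection layers inside self_attn and linear_net; swapping in a spectrally normalized linear constructor would leave the rest of the block unchanged.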