import keras


@keras.saving.register_keras_serializable(package="MyLayers")
class MyConcatenate(keras.layers.Layer):
    """Concatenate the 3D input tensors and their 2D masks along the axis=1 dimension."""

    def call(self, inputs):
        """Expect the input to be a list of 3D tensors."""
        return keras.ops.concatenate(inputs, axis=1)

    def compute_mask(self, inputs, mask=None):
        """Expect the mask to be a list of 2D mask tensors."""
        if mask is None:
            # defensive: nothing to merge if no upstream mask is attached
            return None
        return keras.ops.concatenate(mask, axis=1)

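# Shape sketch (illustrative): three embedded sequences of shapes
# (batch, num_trk, d), (batch, num_ecl, d) and (batch, num_roe, d) concatenate
# to (batch, num_trk + num_ecl + num_roe, d); the 2D boolean masks produced by
# the upstream Masking layers are concatenated the same way, so padded objects
# stay masked in the attention and pooling layers downstream.
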
def create_mlp(hidden_units, dropout_rate, activation, normalization_layer, name=None):
    """Build a sequential MLP: one (normalization, dense, dropout) group per entry in hidden_units."""
    mlp_layers = []
    for units in hidden_units:
        mlp_layers.append(normalization_layer())
        mlp_layers.append(keras.layers.Dense(units, activation=activation))
        mlp_layers.append(keras.layers.Dropout(dropout_rate))

    return keras.Sequential(mlp_layers, name=name)

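# For example (illustrative values), create_mlp([64, 32], 0.1,
# keras.activations.selu, keras.layers.BatchNormalization) yields the stack
# BatchNormalization -> Dense(64, selu) -> Dropout(0.1)
# -> BatchNormalization -> Dense(32, selu) -> Dropout(0.1).
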
def get_tflat_model(parameters, number_of_features):
    """
    Configure the tflat model from parameters.
    """
    clip_value = parameters.get("clip_value")
    mask_value = parameters.get("mask_value")
    num_trk = parameters.get("num_trk")
    num_trk_features = parameters.get("num_trk_features")
    num_ecl = parameters.get("num_ecl")
    num_ecl_features = parameters.get("num_ecl_features")
    num_roe = parameters.get("num_roe")
    num_roe_features = parameters.get("num_roe_features")
    num_transformer_blocks = parameters.get("num_transformer_blocks")
    num_heads = parameters.get("num_heads")
    embedding_dims = parameters.get("embedding_dims")
    mlp_hidden_units_factors = parameters.get("mlp_hidden_units_factors")
    dropout_rate = parameters.get("dropout_rate")

    # The flat input vector is laid out as consecutive [trk | ecl | roe] blocks.
    trk_start = 0
    ecl_start = num_trk * num_trk_features
    roe_start = ecl_start + num_ecl * num_ecl_features

    inputs = keras.layers.Input((number_of_features,))

    # Replace NaNs with the mask value so the Masking layers drop them.
    raw_features = keras.ops.nan_to_num(inputs, nan=mask_value)

    # Clip extreme feature values.
    raw_features = keras.ops.clip(raw_features, x_min=-clip_value, x_max=clip_value)
    # Track (trk) block: slice, reshape to (num_trk, num_trk_features), mask and embed.
    raw_trk_features = raw_features[:, trk_start:trk_start + num_trk * num_trk_features]
    reshaped_trk_features = keras.layers.Reshape((num_trk, num_trk_features))(raw_trk_features)
    masked_trk_features = keras.layers.Masking(mask_value=mask_value)(reshaped_trk_features)
    normed_trk_features = keras.layers.BatchNormalization()(masked_trk_features)
    encoded_trk_features = keras.layers.Dense(
        units=embedding_dims,
        activation=keras.activations.selu,
        name="Embedding_trk_dense_1")(normed_trk_features)
    encoded_trk_features = keras.layers.Dropout(dropout_rate, name="Embedding_trk_dropout_1")(encoded_trk_features)
    encoded_trk_features = keras.layers.BatchNormalization(name="Embedding_trk_batchnorm")(encoded_trk_features)
    encoded_trk_features = keras.layers.Dense(
        units=embedding_dims,
        activation=keras.activations.selu,
        name="Embedding_trk_dense_2")(encoded_trk_features)
    encoded_trk_features = keras.layers.Dropout(dropout_rate, name="Embedding_trk_dropout_2")(encoded_trk_features)
    # ECL cluster (ecl) block: same embedding stack as the tracks.
    raw_ecl_features = raw_features[:, ecl_start:ecl_start + num_ecl * num_ecl_features]
    reshaped_ecl_features = keras.layers.Reshape((num_ecl, num_ecl_features))(raw_ecl_features)
    masked_ecl_features = keras.layers.Masking(mask_value=mask_value)(reshaped_ecl_features)
    normed_ecl_features = keras.layers.BatchNormalization()(masked_ecl_features)
    encoded_ecl_features = keras.layers.Dense(
        units=embedding_dims,
        activation=keras.activations.selu,
        name="Embedding_ecl_dense_1")(normed_ecl_features)
    encoded_ecl_features = keras.layers.Dropout(dropout_rate, name="Embedding_ecl_dropout_1")(encoded_ecl_features)
    encoded_ecl_features = keras.layers.BatchNormalization(name="Embedding_ecl_batchnorm")(encoded_ecl_features)
    encoded_ecl_features = keras.layers.Dense(
        units=embedding_dims,
        activation=keras.activations.selu,
        name="Embedding_ecl_dense_2")(encoded_ecl_features)
    encoded_ecl_features = keras.layers.Dropout(dropout_rate, name="Embedding_ecl_dropout_2")(encoded_ecl_features)
    # Rest-of-event (roe) block: same embedding stack again.
    raw_roe_features = raw_features[:, roe_start:roe_start + num_roe * num_roe_features]
    reshaped_roe_features = keras.layers.Reshape((num_roe, num_roe_features))(raw_roe_features)
    masked_roe_features = keras.layers.Masking(mask_value=mask_value)(reshaped_roe_features)
    normed_roe_features = keras.layers.BatchNormalization()(masked_roe_features)
    encoded_roe_features = keras.layers.Dense(
        units=embedding_dims,
        activation=keras.activations.selu,
        name="Embedding_roe_dense_1")(normed_roe_features)
    encoded_roe_features = keras.layers.Dropout(dropout_rate, name="Embedding_roe_dropout_1")(encoded_roe_features)
    encoded_roe_features = keras.layers.BatchNormalization(name="Embedding_roe_batchnorm")(encoded_roe_features)
    encoded_roe_features = keras.layers.Dense(
        units=embedding_dims,
        activation=keras.activations.selu,
        name="Embedding_roe_dense_2")(encoded_roe_features)
    encoded_roe_features = keras.layers.Dropout(dropout_rate, name="Embedding_roe_dropout_2")(encoded_roe_features)
    # Concatenate the three object sequences (and their masks) along the object axis.
    encoded_features = MyConcatenate()([encoded_trk_features, encoded_ecl_features, encoded_roe_features])
    # Transformer encoder blocks: self-attention and feed-forward, each with a skip connection.
    for block_idx in range(num_transformer_blocks):
        attention_output = keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embedding_dims,
            dropout=dropout_rate,
            name=f"multihead_attention_{block_idx}",
        )(encoded_features, encoded_features)
        x = keras.layers.Add(name=f"skip_connection1_{block_idx}")(
            [attention_output, encoded_features]
        )
        x = keras.layers.LayerNormalization(name=f"layer_norm1_{block_idx}", epsilon=1e-6)(x)
        feedforward_output = keras.layers.Dense(units=3 * embedding_dims, activation='relu',
                                                name=f"feedforward_{block_idx}_dense_1")(x)
        feedforward_output = keras.layers.Dense(units=embedding_dims,
                                                name=f"feedforward_{block_idx}_dense_2")(feedforward_output)
        feedforward_output = keras.layers.Dropout(dropout_rate,
                                                  name=f"feedforward_{block_idx}_dropout")(feedforward_output)
        x = keras.layers.Add(name=f"skip_connection2_{block_idx}")([feedforward_output, x])
        encoded_features = keras.layers.LayerNormalization(
            name=f"layer_norm2_{block_idx}", epsilon=1e-6
        )(x)

    # Pool over the object axis (mask-aware) to get one vector per event.
    features = keras.layers.GlobalAveragePooling1D()(encoded_features)

    # MLP hidden sizes as multiples of the pooled feature dimension.
    mlp_hidden_units = [
        factor * features.shape[-1]
        for factor in mlp_hidden_units_factors
    ]
    features = create_mlp(
        hidden_units=mlp_hidden_units,
        dropout_rate=dropout_rate,
        activation=keras.activations.selu,
        normalization_layer=keras.layers.BatchNormalization,
        name="ClassifierMLP",
    )(features)

    outputs = keras.layers.Dense(units=1, activation="sigmoid", name="sigmoid")(features)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

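# --- Usage sketch (not part of the original module): builds the model with
# hypothetical parameter values. The keys match what get_tflat_model() reads
# above; the numbers are illustrative assumptions. number_of_features must
# equal the total length of the flat [trk | ecl | roe] feature vector, and
# mask_value should lie inside [-clip_value, clip_value] so masked entries
# survive the clipping step.
if __name__ == "__main__":
    example_parameters = {
        "clip_value": 5.0,
        "mask_value": 0.0,
        "num_trk": 10, "num_trk_features": 6,
        "num_ecl": 10, "num_ecl_features": 4,
        "num_roe": 5, "num_roe_features": 4,
        "num_transformer_blocks": 2,
        "num_heads": 4,
        "embedding_dims": 32,
        "mlp_hidden_units_factors": [2, 1],
        "dropout_rate": 0.1,
    }
    n_features = (
        example_parameters["num_trk"] * example_parameters["num_trk_features"]
        + example_parameters["num_ecl"] * example_parameters["num_ecl_features"]
        + example_parameters["num_roe"] * example_parameters["num_roe_features"]
    )
    model = get_tflat_model(example_parameters, n_features)
    model.summary()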