Source code for stressnet.net.stressnet

__all__ = ['get_model', 'build_stressnet']

import tensorflow as tf
from spektral.layers import GlobalAttnSumPool
from tensorflow.keras.layers import Activation, Dense, Input, LayerNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import L2

from .custom_layers import ConcatBroadcast, NodeRescale, StressConv

DEFAULT_DEVICE = '/GPU:0' if tf.config.list_physical_devices('GPU') else '/CPU:0'


[docs] def get_model(weights_path, points_per_edge: int = 9, disjoint_mode: bool = False, device: str | None = None, **net_kwargs ) -> tf.keras.Model: """Build StressNET and optionally load weights on the chosen device. Parameters ---------- weights_path Path to an ``.h5`` weights file (may be ``None`` only for random init / debugging). points_per_edge Number of sampled points per edge (must match training / checkpoint). disjoint_mode If ``True``, build the disjoint-batched variant (Spektral disjoint mode). device TensorFlow device string; defaults to GPU when available. **net_kwargs Forwarded to :func:`build_stressnet`. Returns ------- tensorflow.keras.Model Compiled Keras model. """ tf.keras.backend.clear_session() with tf.device(device or DEFAULT_DEVICE): model = build_stressnet( load_weights_path=weights_path, edge_n_vertices=points_per_edge, disjoint_mode=disjoint_mode, **net_kwargs ) return model
[docs] def build_stressnet(load_weights_path: str | None = None, edge_n_vertices: int = 9, disjoint_mode: bool = False, **kwargs ) -> Model: """Construct the StressNET graph neural network architecture. Parameters ---------- load_weights_path Optional path to load trained weights after building. edge_n_vertices Edge discretization count (controls input edge feature dimension). disjoint_mode Whether to include disjoint-mode graph index input. **kwargs Architecture hyperparameters (layer sizes, regularization, ``fine_tune_layers``, etc.). Returns ------- tensorflow.keras.Model Uncompiled or compiled model per internal ``Model`` factory defaults. """ # = INPUT DIMENSIONS = F = kwargs.get('n_node_features', 6) # number of features per node E = 2 * 2 * (edge_n_vertices - 1) # number of points in edge features times 2 (x and y coordinates) # = MODEL HIPERPARAMETERS = NON_LINEARITY = kwargs.get('non_linearity', 'elu') USE_LAYER_NORM = kwargs.get('use_layer_norm', True) N_BLOCKS = kwargs.get('n_blocks', 3) EDGE_FEATURES_EXPAND = kwargs.get('edge_features_expand', 512) EDGE_FEATURES_SQUEEZE = kwargs.get('edge_features_squeeze', 128) X_DIM = kwargs.get('x_dim', (32, 48, 64, 96)) HIDDEN_DIM_FACTOR = kwargs.get('hidden_dim_factor', 2) HEAD_MLP_UNITS = kwargs.get('head_mlp_units', (256, 64)) EDGES_MLP_REG = L2(kwargs['edges_mlp_l2_reg']) if 'edges_mlp_l2_reg' in kwargs else L2(1e-6) CONV_REG = L2(kwargs['conv_l2_reg']) if 'conv_l2_reg' in kwargs else None POOL_ATTN_REG = L2(kwargs['global_pool_l2_reg']) if 'global_pool_l2_reg' in kwargs else None HEAD_MLP_REG = L2(kwargs['head_mlp_l2_reg']) if 'head_mlp_l2_reg' in kwargs else L2(1e-6) OUTPUT_LAYER_REG = L2(kwargs['reg_head_l2_reg']) if 'reg_head_l2_reg' in kwargs else None OUTPUT_NORM_MODE = kwargs.get('output_norm_mode', 'mean') FINE_TUNE_LAYERS = set(kwargs.get('fine_tune_layers', [])) assert (not FINE_TUNE_LAYERS) or load_weights_path, 'Freezing layers with randomly initialized weights ??' # Model inputs x0 = Input(shape=(F,), name='node_feats') a = Input((None,), sparse=True, name='adj_mat') e0 = Input(shape=(E,), name='edge_feats') gi = Input(shape=(), dtype=tf.int32, name='graph_indices') if disjoint_mode else None # obtain embeddings from edge geometries # expand e = Dense(EDGE_FEATURES_EXPAND, kernel_regularizer=EDGES_MLP_REG, use_bias=not USE_LAYER_NORM, name='edge_feats_fc_1')(e0) if USE_LAYER_NORM: e = LayerNormalization(name='edge_feats_ln_1')(e) e = Activation(NON_LINEARITY, name=f'edge_feats_{NON_LINEARITY}_1')(e) # squeeze e = Dense(EDGE_FEATURES_SQUEEZE, kernel_regularizer=EDGES_MLP_REG, use_bias=not USE_LAYER_NORM, name='edge_feats_fc_2')(e) if USE_LAYER_NORM: e = LayerNormalization(name='edge_feats_ln_2')(e) # graph convolutions x = x0 for i, dim in enumerate(X_DIM, start=1): for j in range(1, N_BLOCKS + 1): x = StressConv(channels=dim, use_layer_norm=USE_LAYER_NORM, activation=NON_LINEARITY, kernel_regularizer=CONV_REG, hidden_dim_factor=HIDDEN_DIM_FACTOR, name=f'graph_conv_{i}.{j}')([x, a, e]) x = Activation(NON_LINEARITY, name=f'graph_conv_out_{NON_LINEARITY}')(x) # concatenate the pooled representation of all nodes in the grpah to each individual representation pool_inputs = [x, gi] if disjoint_mode else x pool = GlobalAttnSumPool(attn_kernel_regularizer=POOL_ATTN_REG, name='global_pool')(pool_inputs) concat_inputs = [x, pool, gi] if disjoint_mode else [x, pool] x = ConcatBroadcast(name='concat_global_feats')(concat_inputs) # "head" MLP for i, neurons in enumerate(HEAD_MLP_UNITS, start=1): x = Dense(neurons, kernel_regularizer=HEAD_MLP_REG, use_bias=not USE_LAYER_NORM, name=f'reg_head_fc_{i}')(x) if USE_LAYER_NORM: x = LayerNormalization(name=f'reg_head_ln_{i}')(x) x = Activation(NON_LINEARITY, name=f'reg_head_{NON_LINEARITY}_{i}')(x) # linear regression x = Dense(1, name='reg_head_out', kernel_regularizer=OUTPUT_LAYER_REG)(x) # rescaling layer (normalize output to have mean equal to 1 and squeeze the last dimension) mean_norm_inputs = [x, gi] if disjoint_mode else x x = NodeRescale(agg_mode=OUTPUT_NORM_MODE, name=f'{OUTPUT_NORM_MODE}_norm')(mean_norm_inputs) # build model object model_inputs = [x0, a, e0, gi] if disjoint_mode else [x0, a, e0] model = Model(inputs=model_inputs, outputs=x, name='StressNetV0.2') # load weights if load_weights_path: model.load_weights(load_weights_path) if FINE_TUNE_LAYERS: available_layer_names = {layer.name for layer in model.layers} missing_layers = sorted(FINE_TUNE_LAYERS - available_layer_names) if missing_layers: missing = list(missing_layers) raise ValueError( f'Values in fine_tune_layers are not model layer names: {missing}. ' ) # freeze all layers except for the ones that we want to train for layer in model.layers: layer.trainable = (layer.name in FINE_TUNE_LAYERS) return model