curiosidade.probers.utils

Utility functions related to probing model creation.

Module Contents

Functions

get_probing_model_feedforward(hidden_layer_dims: Sequence[int], include_batch_norm: bool = False, dropout: float = 0.0) → ProbingModelType

Get a ProbingModelFeedforward architecture.

get_probing_model_for_sequences(hidden_layer_dims: Sequence[int], pooling_strategy: Literal['max', 'mean', 'keep_single_index'] = 'max', pooling_axis: int = 1, embedding_index_to_keep: int = 0, include_batch_norm: bool = False, dropout: float = 0.0) → ProbingModelType

Get a ProbingModelForSequences architecture.

curiosidade.probers.utils.get_probing_model_feedforward(hidden_layer_dims: Sequence[int], include_batch_norm: bool = False, dropout: float = 0.0) → ProbingModelType

Get a ProbingModelFeedforward architecture.

Parameters
  • hidden_layer_dims (t.Sequence[int]) – Number of units in each hidden layer.

  • include_batch_norm (bool, default=False) – If True, include Batch Normalization between Linear and ReLU modules.

  • dropout (float, default=0.0) – Dropout probability per hidden layer.

Returns

architecture – Callable that generates the corresponding probing model.

Return type

t.Callable[[int, …], ProbingModelFeedforward]
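
The two-step pattern described above (configure the hidden layers first, receive a callable that builds the model later) can be sketched in plain Python. This is an illustration only, not the library's implementation: build_layer_plan and get_layer_plan_factory are hypothetical names, the module names are just strings, and the placement of dropout after the activation and the final output layer are assumptions the reference does not confirm.

```python
import functools
from typing import Callable, List, Sequence


def build_layer_plan(
    input_dim: int,
    output_dim: int,
    hidden_layer_dims: Sequence[int],
    include_batch_norm: bool = False,
    dropout: float = 0.0,
) -> List[str]:
    """List the modules of a feedforward prober, in order (hypothetical helper)."""
    plan: List[str] = []
    previous_dim = input_dim
    for hidden_dim in hidden_layer_dims:
        plan.append(f"Linear({previous_dim}, {hidden_dim})")
        if include_batch_norm:
            # Documented behavior: batch norm sits between Linear and ReLU.
            plan.append(f"BatchNorm1d({hidden_dim})")
        plan.append("ReLU()")
        if dropout > 0.0:
            # Assumption: dropout follows the activation.
            plan.append(f"Dropout(p={dropout})")
        previous_dim = hidden_dim
    # Assumption: a final linear layer maps to the output dimension.
    plan.append(f"Linear({previous_dim}, {output_dim})")
    return plan


def get_layer_plan_factory(
    hidden_layer_dims: Sequence[int],
    include_batch_norm: bool = False,
    dropout: float = 0.0,
) -> Callable[..., List[str]]:
    """Fix the hyperparameters now; supply the dimensions later."""
    return functools.partial(
        build_layer_plan,
        hidden_layer_dims=tuple(hidden_layer_dims),
        include_batch_norm=include_batch_norm,
        dropout=dropout,
    )


# Configure once, then build when the input dimension is known.
architecture = get_layer_plan_factory([128, 64], include_batch_norm=True, dropout=0.1)
plan = architecture(768, 3)
for module in plan:
    print(module)
```

Deferring the input dimension this way lets the same architecture be reused for probed layers with different output sizes.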

curiosidade.probers.utils.get_probing_model_for_sequences(hidden_layer_dims: Sequence[int], pooling_strategy: Literal['max', 'mean', 'keep_single_index'] = 'max', pooling_axis: int = 1, embedding_index_to_keep: int = 0, include_batch_norm: bool = False, dropout: float = 0.0) → ProbingModelType

Get a ProbingModelForSequences architecture.

This probing model architecture handles variable-length inputs by applying a pooling function along the variable-length axis, which transforms the inputs into fixed-length representations.

Parameters
  • hidden_layer_dims (t.Sequence[int]) – Number of units in each hidden layer.

  • pooling_strategy ({'max', 'mean', 'keep_single_index'}, default='max') –

    Pooling strategy used to transform variable-length tensors into fixed-length tensors.

    • max: take the element-wise maximum along pooling_axis;

    • mean: take the element-wise average along pooling_axis; or

    • keep_single_index: keep the single vector at index embedding_index_to_keep along pooling_axis (see argument below), and discard everything else.

  • pooling_axis (int, default=1) – Axis along which pooling is applied.

  • embedding_index_to_keep (int, default=0) – Embedding index to keep when pooling_strategy='keep_single_index'. This argument has no effect for other pooling strategies.

  • include_batch_norm (bool, default=False) – If True, include Batch Normalization between Linear and ReLU modules.

  • dropout (float, default=0.0) – Dropout probability per hidden layer.

Returns

architecture – Callable that generates the corresponding probing model.

Return type

t.Callable[[int, …], ProbingModelForSequences]
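
The three pooling strategies can be illustrated on plain nested lists. The sketch below is not the library's code: pool_sequence is a hypothetical helper, and it assumes each input is a list of embedding vectors, so the list axis plays the role of pooling_axis=1 in a (batch, sequence, embedding) tensor.

```python
from typing import List, Sequence

Vector = List[float]


def pool_sequence(
    sequence: Sequence[Vector],
    pooling_strategy: str = "max",
    embedding_index_to_keep: int = 0,
) -> Vector:
    """Reduce a variable-length list of vectors to one fixed-length vector."""
    if pooling_strategy == "max":
        # zip(*sequence) iterates over embedding dimensions (columns).
        return [max(column) for column in zip(*sequence)]
    if pooling_strategy == "mean":
        return [sum(column) / len(sequence) for column in zip(*sequence)]
    if pooling_strategy == "keep_single_index":
        # Keep one vector and discard the rest of the sequence.
        return list(sequence[embedding_index_to_keep])
    raise ValueError(f"Unknown pooling strategy: {pooling_strategy!r}")


# Two sequences of different lengths (2 and 3), embedding size 2.
batch = [
    [[1.0, 4.0], [3.0, 2.0]],
    [[0.0, 6.0], [2.0, 0.0], [4.0, 3.0]],
]

pooled_max = [pool_sequence(seq, "max") for seq in batch]
pooled_mean = [pool_sequence(seq, "mean") for seq in batch]
pooled_first = [pool_sequence(seq, "keep_single_index", 0) for seq in batch]

print(pooled_max)    # [[3.0, 4.0], [4.0, 6.0]]
print(pooled_mean)   # [[2.0, 3.0], [2.0, 3.0]]
print(pooled_first)  # [[1.0, 4.0], [0.0, 6.0]]
```

Whatever the sequence length, every strategy returns a vector of the embedding size, which is what lets the hidden layers that follow have a fixed input dimension.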