curiosidade.probers.utils
Utility functions related to probing model creation.
Module Contents
Functions
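get_probing_model_feedforward – Get a ProbingModelFeedforward architecture.
get_probing_model_for_sequences – Get a ProbingModelForSequences architecture.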
- curiosidade.probers.utils.get_probing_model_feedforward(hidden_layer_dims: Sequence[int], include_batch_norm: bool = False, dropout: float = 0.0) → ProbingModelType
Get a ProbingModelFeedforward architecture.
- Parameters
hidden_layer_dims (t.Sequence[int]) – Number of units in each hidden layer.
include_batch_norm (bool, default=False) – If True, include Batch Normalization between Linear and ReLU modules.
dropout (float, default=0.0) – Dropout probability per hidden layer.
- Returns
architecture – Callable that generates the corresponding probing model.
- Return type
t.Callable[[int, …], ProbingModelFeedforward]
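The parameters above suggest a two-step pattern: the function fixes the hidden layout, and the returned callable builds the model once the input size is known. The following is a minimal sketch of that pattern, not the library's implementation. It describes layers as plain tuples instead of real modules, and several details are assumptions: the name `sketch_feedforward_layout`, the `output_dim` argument (the source elides the callable's remaining arguments), the final output `Linear` layer, and the placement of dropout after ReLU. Only the Batch Normalization position (between Linear and ReLU) comes from the documentation.

```python
from typing import Callable, List, Sequence


def sketch_feedforward_layout(
    hidden_layer_dims: Sequence[int],
    include_batch_norm: bool = False,
    dropout: float = 0.0,
) -> Callable[[int, int], List[tuple]]:
    """Hypothetical stand-in for the documented factory (illustration only)."""

    def build(input_dim: int, output_dim: int) -> List[tuple]:
        # `output_dim` is an assumed argument; the source only shows `int, …`.
        layers: List[tuple] = []
        dims = [input_dim, *hidden_layer_dims]

        for dim_in, dim_out in zip(dims[:-1], dims[1:]):
            layers.append(("Linear", dim_in, dim_out))
            if include_batch_norm:
                # Documented position: between Linear and ReLU.
                layers.append(("BatchNorm", dim_out))
            layers.append(("ReLU",))
            if dropout > 0.0:
                # Assumed position: after the activation.
                layers.append(("Dropout", dropout))

        # Assumed output layer mapping the last hidden size to the task size.
        layers.append(("Linear", dims[-1], output_dim))
        return layers

    return build


# The hidden layout is fixed first; the input size is supplied later.
architecture = sketch_feedforward_layout([128, 64], include_batch_norm=True, dropout=0.1)
layout = architecture(768, 2)

for layer in layout:
    print(layer)
```

Deferring the input dimension this way lets the same architecture be attached to pretrained layers of different widths without restating the hidden layout.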
- curiosidade.probers.utils.get_probing_model_for_sequences(hidden_layer_dims: Sequence[int], pooling_strategy: Literal['max', 'mean', 'keep_single_index'] = 'max', pooling_axis: int = 1, embedding_index_to_keep: int = 0, include_batch_norm: bool = False, dropout: float = 0.0) → ProbingModelType
Get a ProbingModelForSequences architecture.
This probing model architecture handles variable-length inputs by applying a pooling function along the variable-length axis, thereby transforming the inputs into fixed-length representations.
- Parameters
hidden_layer_dims (t.Sequence[int]) – Number of units in each hidden layer.
pooling_strategy ({'max', 'mean', 'keep_single_index'}, default='max') – Pooling strategy used to transform variable-length tensors into fixed-length tensors:
max: select element-wise maxima on elements along pooling_axis;
mean: compute element-wise averages on elements along pooling_axis; or
keep_single_index: keep a single vector along pooling_axis at the index embedding_index_to_keep (see argument below), and discard everything else.
pooling_axis (int, default=1) – Axis along which pooling is applied.
embedding_index_to_keep (int, default=0) – Embedding index to keep when pooling_strategy='keep_single_index'. This argument has no effect for other pooling strategies.
include_batch_norm (bool, default=False) – If True, include Batch Normalization between Linear and ReLU modules.
dropout (float, default=0.0) – Dropout probability per hidden layer.
- Returns
architecture – Callable that generates the corresponding probing model.
- Return type
t.Callable[[int, …], ProbingModelForSequences]
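The three pooling strategies can be illustrated with a small NumPy sketch. This is an illustration under stated assumptions rather than the library's code: the helper name `pool` is hypothetical, inputs are assumed to have shape (batch, sequence length, embedding size), and padding is ignored, so padded positions would take part in the pooling.

```python
import numpy as np


def pool(
    x: np.ndarray,
    pooling_strategy: str = "max",
    pooling_axis: int = 1,
    embedding_index_to_keep: int = 0,
) -> np.ndarray:
    """Hypothetical helper: collapse `pooling_axis` to get a fixed-length output."""
    if pooling_strategy == "max":
        # Element-wise maxima along the pooled axis.
        return x.max(axis=pooling_axis)
    if pooling_strategy == "mean":
        # Element-wise averages along the pooled axis.
        return x.mean(axis=pooling_axis)
    if pooling_strategy == "keep_single_index":
        # Keep one vector along the pooled axis and discard the rest.
        return np.take(x, embedding_index_to_keep, axis=pooling_axis)
    raise ValueError(f"Unknown pooling strategy: {pooling_strategy!r}")


# Two batches with different sequence lengths (3 and 7) and embedding size 4.
short = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
long = np.arange(2 * 7 * 4, dtype=float).reshape(2, 7, 4)

for strategy in ("max", "mean", "keep_single_index"):
    pooled_short = pool(short, strategy)
    pooled_long = pool(long, strategy)
    # Both results share the same shape regardless of sequence length.
    print(strategy, pooled_short.shape, pooled_long.shape)
```

With the default embedding_index_to_keep=0, the keep_single_index strategy selects the first position of each sequence, which in many Transformer encoders holds a classification token.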