curiosidade.probers.probers

Probing model wrappers.

Module Contents

Classes

ProbingModelWrapper

Probing model wrapper.

ProbingModelFactory

Factory to create multiple probing models from a single configuration.

class curiosidade.probers.probers.ProbingModelWrapper(probing_model: torch.nn.Module, task: curiosidade.probers.tasks.base.BaseProbingTask, optim: torch.optim.Optimizer, lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None)

Probing model wrapper.

Parameters
  • probing_model (torch.nn.Module) – Probing model.

  • task (curiosidade.probers.tasks.base.BaseProbingTask) – Probing task related to the probing model.

  • optim (torch.optim.Optimizer) – Optimizer used to train probing_model parameters.

  • lr_scheduler (torch.optim.lr_scheduler._LRScheduler or None, default=None) – Optional learning rate scheduler, coupled with optim.

__repr__(self) -> str

Return repr(self).

property is_attached(self) -> bool

Check whether the probing model is attached to a pretrained module.

property has_lr_scheduler(self) -> bool

Check whether the probing model has a learning rate scheduler.

attach(self, module: torch.nn.Module) -> ProbingModelWrapper

Attach probing model to module.

detach(self) -> ProbingModelWrapper

Detach attached prober, if any.

remove(self) -> ProbingModelWrapper

Alias for ‘detach’.

to(self, device: Union[torch.device, str]) -> ProbingModelWrapper

Move probing model to device.

step(self, input_labels: torch.Tensor, accumulate_grad: bool = False, is_test: bool = False, compute_metrics: bool = True) -> dict[str, float]

Perform a single optimization step with input_labels as target reference.

Parameters
  • input_labels (torch.Tensor) – Ground truth labels for current batch.

  • accumulate_grad (bool, default=False) – If True, will not perform gradient cleaning, adding the current backward computation to the pre-existing gradient. This also prevents updates to model weights.

  • is_test (bool, default=False) – If True, does not compute backward gradients in this run. Also prevents weight updates, gradient accumulation, and gradient cleaning.

  • compute_metrics (bool, default=True) – If True, compute metrics related to the task for the current batch.

Returns

metrics – Metrics related to the current batch.

Return type

dict[str, float]
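
The interaction between accumulate_grad and is_test can be summarized as a small decision table. The sketch below is not the library's implementation; it only encodes the flag semantics described above in plain Python. The names StepPlan and plan_step are hypothetical, and the precedence of is_test over accumulate_grad is an assumption drawn from the parameter descriptions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StepPlan:
    """Hypothetical summary of what a single call to `step` would do."""

    run_backward: bool    # compute gradients for the current batch
    update_weights: bool  # apply the optimizer update
    clear_grads: bool     # reset gradients ("gradient cleaning")


def plan_step(accumulate_grad: bool = False, is_test: bool = False) -> StepPlan:
    """Encode the documented flag semantics (an interpretation, not library code)."""
    if is_test:
        # Test mode: no backward pass, no weight update, no gradient cleaning.
        return StepPlan(run_backward=False, update_weights=False, clear_grads=False)
    if accumulate_grad:
        # Gradients are added to the pre-existing ones; weights stay untouched.
        return StepPlan(run_backward=True, update_weights=False, clear_grads=False)
    # Regular training step: backward pass, weight update, gradient cleaning.
    return StepPlan(run_backward=True, update_weights=True, clear_grads=True)
```

Under this reading, only a call with both flags left at False changes the probing model weights.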

step_lr_scheduler(self, *args: Any, **kwargs: Any) -> ProbingModelWrapper

Apply one step of learning rate scheduler.

train(self) -> ProbingModelWrapper

Set model to train mode.

eval(self) -> ProbingModelWrapper

Set model to evaluation mode.

class curiosidade.probers.probers.ProbingModelFactory(probing_model_fn: Callable[..., torch.nn.Module], task: curiosidade.probers.tasks.base.BaseProbingTask, optim_fn: Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer] = torch.optim.Adam, lr_scheduler_fn: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None, extra_kwargs: Optional[dict[str, Any]] = None)

Factory to create multiple probing models from a single configuration.

Parameters
  • probing_model_fn (t.Callable[..., torch.nn.Module]) – Probing model factory function, or class derived from torch.nn.Module. Must receive its input dimension (an integer) as its first positional argument, and its output dimension (also an integer) as its second positional argument. Extra arguments can be handled via extra_kwargs parameter.

  • task (curiosidade.probers.tasks.base.BaseProbingTask) – Probing task related to the probing models.

  • optim_fn (t.Callable[[t.Iterator[torch.nn.Parameter]], torch.optim.Optimizer], default=torch.optim.Adam) – Optimizer factory function. It receives the probing model parameters and returns an optimizer.

  • lr_scheduler_fn (t.Type[torch.optim.lr_scheduler._LRScheduler] or None, default=None) – If provided, will set up a learning rate scheduler coupled to the optimizer.

  • extra_kwargs (dict[str, t.Any] or None, default=None) – Extra arguments to provide to probing_model_fn.
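
The calling convention expected from probing_model_fn can be illustrated without any Torch dependency. In the sketch below, build_probing_model and LinearProberSpec are hypothetical stand-ins (a real probing model would be a torch.nn.Module); the only point is the argument order: input dimension first, output dimension second, everything else through extra_kwargs.

```python
from typing import Any, Callable, Optional


class LinearProberSpec:
    """Hypothetical stand-in for a probing model class (not a torch.nn.Module)."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 20):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim


def build_probing_model(
    probing_model_fn: Callable[..., Any],
    input_dim: int,
    output_dim: int,
    extra_kwargs: Optional[dict[str, Any]] = None,
) -> Any:
    """Call the factory the way the parameter description suggests (assumption)."""
    # Dimensions go in positionally; optional settings are forwarded as keywords.
    return probing_model_fn(input_dim, output_dim, **(extra_kwargs or {}))


# The class itself is passed, not an instance.
prober = build_probing_model(LinearProberSpec, 768, 3, extra_kwargs={"hidden_dim": 64})
```

Passing the class (or a function) rather than an instance is what lets a single configuration produce several independent probing models.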

Examples

>>> import curiosidade
>>> import functools
>>> import torch
>>> from curiosidade.probers.probers import ProbingModelFactory
...
>>> class ProbingModel(torch.nn.Module):
...     def __init__(self, input_dim: int, output_dim: int):
...         super().__init__()
...         self.params = torch.nn.Sequential(
...             torch.nn.Linear(input_dim, 20),
...             torch.nn.ReLU(inplace=True),
...             torch.nn.Linear(20, output_dim),
...         )
...
...     def forward(self, X):
...         return self.params(X)
...
>>> task = curiosidade.probers.base.DummyProbingTask()
>>> ProbingModelFactory(
...     probing_model_fn=ProbingModel,  # Note: do not instantiate.
...     optim_fn=functools.partial(torch.optim.Adam, lr=0.01),  # Note: do not instantiate.
...     task=task,
... )
ProbingModelFactory
  (a): probing model generator : <class 'curiosidade.probers.probers.ProbingModel'>
  (b): optimizer generator : functools.partial(<class 'torch.optim.adam.Adam'>, lr=0.01)
  (c): task : 'unnamed_task' (classification)

__repr__(self) -> str

Return repr(self).

create_and_attach(self, module: torch.nn.Module, probing_input_dim: Union[int, tuple[int, ...]], random_seed: Optional[int] = None) -> ProbingModelWrapper

Create a brand-new probing model and attach it to module.

Parameters
  • module (torch.nn.Module) – Module to attach the probing model.

  • probing_input_dim (int or tuple[int, ...]) – Input dimension of probing model. It should match the output dimension of module. If module has more than one output, must be a tuple with the dimensions of each output, maintaining the order. This tuple will be unpacked before being provided to the probing model.

  • random_seed (int or None, default=None) – Random seed set while creating the probing model, mainly to control for random weight initialization, and any other non-deterministic behaviours. Note that this only takes into account Torch-related pseudo-randomness. If your model depends on other independent pseudo-random generators (such as random or numpy.random), you must control their behaviour separately within the probing model code (for instance, providing a random seed via ProbingModelFactory.extra_kwargs).

Returns

probing_model – Probing model created and attached to module.

Return type

ProbingModelWrapper
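
Two details of create_and_attach are easy to miss: a tuple probing_input_dim is unpacked into separate positional arguments, and random_seed does not cover generators outside Torch. The sketch below illustrates both in plain Python. The function make_prober, its seed keyword, and the assumption that the output dimension follows the unpacked input dimensions are all hypothetical; they are not taken from the library.

```python
import random
from typing import Any, Optional


def make_prober(
    dim_a: int, dim_b: int, output_dim: int, seed: Optional[int] = None
) -> dict[str, Any]:
    """Hypothetical probing model factory for a module with two outputs."""
    # A private generator keeps non-Torch randomness reproducible on its own.
    rng = random.Random(seed)
    return {
        "input_dims": (dim_a, dim_b),
        "output_dim": output_dim,
        "noise": [rng.random() for _ in range(3)],
    }


probing_input_dim = (768, 128)  # one entry per output of the probed module
extra_kwargs = {"seed": 42}     # forwarded to the factory function

# Assumed call shape: the tuple is unpacked, then the output dimension follows.
first = make_prober(*probing_input_dim, 3, **extra_kwargs)
second = make_prober(*probing_input_dim, 3, **extra_kwargs)
```

Because the seed travels through extra_kwargs, both calls draw the same values from the standard library generator, independently of whatever random_seed does for Torch.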

__call__(self, *args: Any, **kwargs: Any) -> ProbingModelWrapper

Call create_and_attach.