curiosidade.probers.probers
Probing model wrappers.
Module Contents
Classes
ProbingModelWrapper | Probing model wrapper.
ProbingModelFactory | Factory to create multiple probing models from a single configuration.
- class curiosidade.probers.probers.ProbingModelWrapper(probing_model: torch.nn.Module, task: curiosidade.probers.tasks.base.BaseProbingTask, optim: torch.optim.Optimizer, lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None)
Probing model wrapper.
- Parameters
probing_model (torch.nn.Module) – Probing model.
task (task.base.BaseProbingTask) – Probing task related to the probing model.
optim (torch.optim.Optimizer) – Optimizer used to train probing_model parameters.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler or None, default=None) – Optional learning rate scheduler, coupled with optim.
- __repr__(self) → str
Return repr(self).
- property is_attached(self) → bool
Check whether the probing model is attached to a pretrained module.
- property has_lr_scheduler(self) → bool
Check whether the probing model has a learning rate scheduler.
- attach(self, module: torch.nn.Module) → ProbingModelWrapper
Attach probing model to module.
- detach(self) → ProbingModelWrapper
Detach attached prober, if any.
- remove(self) → ProbingModelWrapper
Alias for ‘detach’.
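The attach/detach cycle can be illustrated with plain PyTorch forward hooks. This is only a sketch under the assumption that attachment behaves like a forward hook (the `remove` alias hints at a hook handle, but the source does not confirm the mechanism); `pretrained`, `prober`, `captured` and `hook` are hypothetical names, not part of curiosidade.

```python
import torch

# Hypothetical stand-ins: a "pretrained" module and a small probing model.
pretrained = torch.nn.Linear(8, 4)
prober = torch.nn.Linear(4, 2)

captured = {}

def hook(module, inputs, output):
    # Feed the module's output to the prober. detach() keeps gradients
    # from flowing back into the pretrained weights (an assumption here).
    captured["logits"] = prober(output.detach())

# "Attach": the hook now runs on every forward pass of `pretrained`.
handle = pretrained.register_forward_hook(hook)
pretrained(torch.randn(3, 8))
attached_shape = tuple(captured["logits"].shape)

# "Detach": removing the handle stops the prober from being fed.
handle.remove()
captured.clear()
pretrained(torch.randn(3, 8))
still_captured = "logits" in captured
```

With this design the pretrained module is never modified; the prober only observes its outputs while the hook is registered.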
- to(self, device: Union[torch.device, str]) → ProbingModelWrapper
Move probing model to device.
- step(self, input_labels: torch.Tensor, accumulate_grad: bool = False, is_test: bool = False, compute_metrics: bool = True) → dict[str, float]
Perform a single optimization step with input_labels as target reference.
- Parameters
input_labels (torch.Tensor) – Ground truth labels for current batch.
accumulate_grad (bool, default=False) – If True, does not clean gradients, so the current backward computation is added to the pre-existing gradient. This also prevents updates to the model weights.
is_test (bool, default=False) – If True, does not compute backward gradients in this run. Also prevents weight updates, gradient accumulation, and gradient cleaning.
compute_metrics (bool, default=True) – If True, compute metrics related to the task for the current batch.
- Returns
metrics – Metrics related to the current batch.
- Return type
dict[str, float]
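The interaction between accumulate_grad and is_test can be sketched in plain PyTorch. The function below is a hypothetical stand-in, not the library's implementation: `sketch_step`, `loss_fn` and the explicit `features` argument are assumptions (the real wrapper presumably receives its input from the attached module), and the order of the weight update and gradient cleaning is also assumed.

```python
import torch

def sketch_step(model, optim, loss_fn, features, input_labels,
                accumulate_grad=False, is_test=False):
    """Hypothetical step following the documented flag semantics."""
    if is_test:
        # No backward pass, no weight update, stored gradients untouched.
        with torch.no_grad():
            loss = loss_fn(model(features), input_labels)
        return {"loss": float(loss)}

    loss = loss_fn(model(features), input_labels)
    loss.backward()  # adds to any gradient already stored in .grad

    if not accumulate_grad:
        optim.step()       # update weights with the accumulated gradient
        optim.zero_grad()  # clean gradients for the next batch
    return {"loss": float(loss)}

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
features, labels = torch.randn(6, 4), torch.randint(0, 3, (6,))

# First batch only accumulates: weights stay put, gradient is kept.
weights_before = model.weight.detach().clone()
sketch_step(model, optim, loss_fn, features, labels, accumulate_grad=True)
unchanged_after_accumulate = torch.equal(model.weight, weights_before)
has_pending_grad = model.weight.grad is not None

# Second batch updates the weights using both accumulated gradients.
sketch_step(model, optim, loss_fn, features, labels)
changed_after_update = not torch.equal(model.weight, weights_before)
```

Calling the step with accumulate_grad=True for the first k-1 batches and with the default on the k-th batch emulates a batch k times larger.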
- step_lr_scheduler(self, *args: Any, **kwargs: Any) → ProbingModelWrapper
Apply one step of learning rate scheduler.
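As a rough illustration of what one scheduler step does to the learning rate, the sketch below uses a plain PyTorch optimizer and `StepLR` scheduler directly. The once-per-epoch cadence and the choice of `StepLR` are assumptions made for the example; the source does not state when step_lr_scheduler is expected to be called.

```python
import torch

model = torch.nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
# Halve the learning rate at every scheduler step.
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=0.5)

lrs = []
for epoch in range(3):
    optim.step()      # stands in for the per-batch optimization steps
    scheduler.step()  # one scheduler step per epoch (assumed cadence)
    lrs.append(optim.param_groups[0]["lr"])
```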
- train(self) → ProbingModelWrapper
Set model to train mode.
- eval(self) → ProbingModelWrapper
Set model to evaluation mode.
- class curiosidade.probers.probers.ProbingModelFactory(probing_model_fn: Callable[..., torch.nn.Module], task: curiosidade.probers.tasks.base.BaseProbingTask, optim_fn: Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer] = torch.optim.Adam, lr_scheduler_fn: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None, extra_kwargs: Optional[dict[str, Any]] = None)
Factory to create multiple probing models from a single configuration.
- Parameters
probing_model_fn (t.Callable[..., torch.nn.Module]) – Probing model factory function, or class derived from torch.nn.Module. Must receive its input dimension (an integer) as its first positional argument, and its output dimension (also an integer) as its second positional argument. Extra arguments can be handled via extra_kwargs parameter.
task (task.base.BaseProbingTask) – Probing task related to the probing models.
optim_fn (t.Type[torch.optim.Optimizer], default=torch.optim.Adam) – Optimizer factory function.
lr_scheduler_fn (t.Type[torch.optim.lr_scheduler._LRScheduler], default=None) – If provided, will set up a learning rate scheduler coupled to the optimizer.
extra_kwargs (dict[str, t.Any] or None, default=None) – Extra arguments to provide to probing_model_fn.
Examples
>>> import functools
>>> import torch
>>> import curiosidade
...
>>> class ProbingModel(torch.nn.Module):
...     def __init__(self, input_dim: int, output_dim: int):
...         super().__init__()
...         self.params = torch.nn.Sequential(
...             torch.nn.Linear(input_dim, 20),
...             torch.nn.ReLU(inplace=True),
...             torch.nn.Linear(20, output_dim),
...         )
...
...     def forward(self, X):
...         return self.params(X)
...
>>> task = curiosidade.probers.base.DummyProbingTask()
>>> ProbingModelFactory(
...     probing_model_fn=ProbingModel,  # Note: do not instantiate.
...     optim_fn=functools.partial(torch.optim.Adam, lr=0.01),  # Note: do not instantiate.
...     task=task,
... )
ProbingModelFactory
  (a): probing model generator : <class 'curiosidade.probers.probers.ProbingModel'>
  (b): optimizer generator : functools.partial(<class 'torch.optim.adam.Adam'>, lr=0.01)
  (c): task : 'unnamed_task' (classification)
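The reason the factory takes uninstantiated callables can be sketched as follows: each probing model needs its own optimizer bound to its own parameters, and a scheduler bound to that optimizer. `build_training_state` is a hypothetical helper written for this illustration, and the use of `functools.partial` for lr_scheduler_fn is an assumption mirroring how optim_fn is configured in the example above.

```python
import functools

import torch

optim_fn = functools.partial(torch.optim.Adam, lr=0.01)
lr_scheduler_fn = functools.partial(
    torch.optim.lr_scheduler.ExponentialLR, gamma=0.9
)

def build_training_state(model, optim_fn, lr_scheduler_fn=None):
    """Hypothetical helper: one optimizer (and scheduler) per probing model."""
    optim = optim_fn(model.parameters())
    scheduler = lr_scheduler_fn(optim) if lr_scheduler_fn is not None else None
    return optim, scheduler

# Two probing models with different input sizes share one configuration.
probers = [torch.nn.Linear(8, 2), torch.nn.Linear(16, 2)]
states = [build_training_state(p, optim_fn, lr_scheduler_fn) for p in probers]
optim_a, optim_b = states[0][0], states[1][0]
```

Passing an already instantiated optimizer would tie it to a single model's parameters, which is why the example marks both arguments with "do not instantiate".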
- __repr__(self) → str
Return repr(self).
- create_and_attach(self, module: torch.nn.Module, probing_input_dim: Union[int, tuple[int, ...]], random_seed: Optional[int] = None) → ProbingModelWrapper
Create a brand-new probing model and attach it to module.
- Parameters
module (torch.nn.Module) – Module to attach the probing model.
probing_input_dim (int or tuple[int, ...]) – Input dimension of probing model. It should match the output dimension of module. If module has more than one output, must be a tuple with the dimensions of each output, maintaining the order. This tuple will be unpacked before being provided to the probing model.
random_seed (int or None, default=None) – Random seed set while creating the probing model, mainly to control random weight initialization and any other non-deterministic behaviour. Note that this only takes into account Torch-related pseudo-randomness. If your model depends on other independent pseudo-random generators (such as random or numpy.random), you must control their behaviour separately within the probing model code (for instance, by providing a random seed via ProbingModelFactory.extra_kwargs).
- Returns
probing_model – Probing model created and attached to module.
- Return type
ProbingModelWrapper
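The handling of probing_input_dim and random_seed described above can be sketched as follows. `sketch_create` and `TwoInputProber` are hypothetical, and two details are assumptions rather than documented behaviour: that the output dimension follows the unpacked input dimensions, and that seeding is scoped with `torch.random.fork_rng` so the caller's random stream is left undisturbed.

```python
import torch

def sketch_create(probing_model_fn, probing_input_dim, output_dim,
                  extra_kwargs=None, random_seed=None):
    """Hypothetical helper mirroring the documented creation behaviour."""
    # Normalise a single int to a 1-tuple so both cases unpack the same way.
    if isinstance(probing_input_dim, int):
        dims = (probing_input_dim,)
    else:
        dims = tuple(probing_input_dim)
    kwargs = extra_kwargs or {}

    if random_seed is None:
        return probing_model_fn(*dims, output_dim, **kwargs)

    # fork_rng restores the global Torch RNG state on exit. devices=[]
    # restricts the fork to the CPU generator (an assumption of this sketch).
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(random_seed)
        return probing_model_fn(*dims, output_dim, **kwargs)

class TwoInputProber(torch.nn.Module):
    """Probing model for a module with two outputs (dims unpacked in order)."""

    def __init__(self, dim_a, dim_b, output_dim, hidden=5):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim_a + dim_b, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, output_dim),
        )

    def forward(self, a, b):
        return self.net(torch.cat([a, b], dim=-1))

first = sketch_create(TwoInputProber, (4, 6), 3, {"hidden": 5}, random_seed=32)
second = sketch_create(TwoInputProber, (4, 6), 3, {"hidden": 5}, random_seed=32)
same_init = all(
    torch.equal(p, q) for p, q in zip(first.parameters(), second.parameters())
)
out_shape = tuple(first(torch.randn(2, 4), torch.randn(2, 6)).shape)
```

Two probers created with the same seed start from identical weights, which makes probing results comparable across runs. Any randomness drawn from random or numpy.random inside the model would not be covered by this seed.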
- __call__(self, *args: Any, **kwargs: Any) → ProbingModelWrapper
Call create_and_attach.