sam-pytorch package

Submodules

sam.sam module

class sam.sam.SAM(params: Iterable[torch.Tensor], optim: torch.optim.optimizer.Optimizer, rho: float = 0.05)

Bases: torch.optim.optimizer.Optimizer

Sharpness-Aware Minimization (SAM) wrapper for PyTorch optimizers.

All credit: https://github.com/moskomule/sam.pytorch

Parameters

params (Iterable) – Tensors to be optimized.

optim (torch.optim.Optimizer) – The base PyTorch optimizer to wrap.

rho (float, optional) – Neighbourhood size. Default: 0.05.
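For reference, this is how the neighbourhood size rho enters the standard SAM update. It is written under the assumption that this wrapper follows the usual formulation used by the implementation it credits; here optim(w, g) is shorthand (not an API) for one step of the wrapped optimizer applied to parameters w with gradient g, and implementations differ in details such as whether the norm is taken per parameter group.

```latex
\begin{aligned}
g &= \nabla_w L(w_t) \\
\hat{\epsilon} &= \rho \, \frac{g}{\lVert g \rVert_2} \\
g_{\mathrm{SAM}} &= \nabla_w L(w)\,\big|_{\,w = w_t + \hat{\epsilon}} \\
w_{t+1} &= \mathrm{optim}\big(w_t,\; g_{\mathrm{SAM}}\big)
\end{aligned}
```

The perturbation always has L2 norm rho, so rho sets the radius of the neighbourhood in which the worst-case loss is approximated; with rho = 0 the update reduces to the base optimizer alone.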

step(closure) → torch.Tensor
Parameters

closure – A closure that reevaluates the model and returns the loss.

Returns

The loss evaluated at the original (unperturbed) parameters.
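The closure contract can be illustrated with a small pure-Python sketch. It is not this package's implementation: ToySAM and Param are hypothetical names, Param stands in for a torch.Tensor with a .grad field, and plain SGD stands in for the wrapped optim. It shows why step needs a closure (the loss and gradients must be recomputed at the perturbed point) and which loss is returned.

```python
import math
from typing import Callable, List


class Param:
    """Hypothetical stand-in for a torch.Tensor parameter: a value plus its gradient."""

    def __init__(self, data: float) -> None:
        self.data = data
        self.grad = 0.0


class ToySAM:
    """Hypothetical pure-Python SAM that uses plain SGD as the base optimizer."""

    def __init__(self, params: List[Param], lr: float, rho: float = 0.05) -> None:
        self.params = params
        self.lr = lr
        self.rho = rho

    def step(self, closure: Callable[[], float]) -> float:
        # 1) Loss and gradients at the original point w.
        loss = closure()
        grad_norm = math.sqrt(sum(p.grad ** 2 for p in self.params))
        scale = self.rho / (grad_norm + 1e-12)  # small constant avoids division by zero
        # 2) Move to w + e with e = rho * g / ||g||, remembering e for each parameter.
        eps = []
        for p in self.params:
            e = scale * p.grad
            p.data += e
            eps.append(e)
        # 3) Re-evaluate: the closure overwrites .grad with gradients at w + e.
        closure()
        # 4) Restore w, then take an SGD step using the perturbed-point gradients.
        for p, e in zip(self.params, eps):
            p.data -= e
            p.data -= self.lr * p.grad
        # Return the loss at the original point, not at w + e.
        return loss


# Usage: minimize L(w) = w0^2 + 3 * w1^2 starting from w = (1, 1).
w = [Param(1.0), Param(1.0)]


def closure() -> float:
    # Recompute the loss and overwrite the gradients (like zero_grad() + backward()).
    w[0].grad = 2 * w[0].data
    w[1].grad = 6 * w[1].data
    return w[0].data ** 2 + 3 * w[1].data ** 2


opt = ToySAM(w, lr=0.1, rho=0.05)
first = opt.step(closure)
print(first)  # 4.0
```

In real PyTorch code the closure would typically call optimizer.zero_grad(), compute the loss, call loss.backward(), and return the loss, so that the optimizer can trigger the second forward/backward pass itself.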

sam.utils module

sam.utils.compute_sam(group: dict, closure: Callable)