Source code for sam.utils

from typing import Iterable, Callable

import torch
from torch.optim import Optimizer


[docs]def compute_sam(group: dict, closure: Callable): grads = [] params_with_grads = [] rho = group['rho'] # update internal_optim's learning rate for p in group['params']: if p.grad is not None: # without clone().detach(), p.grad will be zeroed by closure() grads.append(p.grad.clone().detach()) params_with_grads.append(p) device = grads[0].device # compute \hat{\epsilon}=\rho/\norm{g}\|g\| grad_norm = torch.stack( [g.detach().norm(2).to(device) for g in grads]).norm(2) epsilon = grads # alias for readability torch._foreach_mul_(epsilon, rho / grad_norm) # virtual step toward \epsilon torch._foreach_add_(params_with_grads, epsilon) # compute g=\nabla_w L_B(w)|_{w+\hat{\epsilon}} closure() # virtual step back to the original point torch._foreach_sub_(params_with_grads, epsilon)