Welcome to SAM PyTorch’s documentation!

Install

Stable release

pip3 install sam-pytorch

Latest code

pip3 install git+https://github.com/tourdeml/sam

Sample usage

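import torch
from torchvision.models import resnet18
from sam import SAM  # import path assumed; check the installed package

model = resnet18()
base_optim = torch.optim.SGD(model.parameters(), lr=1e-3)
optim = SAM(model.parameters(), base_optim)

def closure():
    # Recompute the loss and its gradients; step() invokes this callable.
    optim.zero_grad()
    loss = model(torch.randn(1, 3, 64, 64)).sum()
    loss.backward()
    return loss

optim.step(closure)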
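
Why step takes a closure: SAM (Sharpness-Aware Minimization) needs the gradient twice per update, once at the current weights and once at a nearby perturbed point, so the optimizer has to be able to re-run the forward and backward pass itself. The following is a minimal pure-Python sketch of that two-pass update on a toy quadratic. It is not this package's implementation; `sam_step`, `grad_quadratic`, and the `lr` and `rho` values are hypothetical names and settings chosen only for illustration.

```python
import math

def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One SAM-style update on a list of floats (illustrative sketch only)."""
    # First pass: gradient at the current weights.
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) or 1.0  # guard against a zero gradient
    # Step to the perturbed point w + rho * g / ||g||.
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    # Second pass: gradient at the perturbed weights.
    g_adv = grad_fn(w_adv)
    # Apply the second gradient to the original weights.
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]

def grad_quadratic(w):
    # Gradient of the toy loss f(w) = sum(w_i ** 2).
    return [2.0 * wi for wi in w]

w = [1.0, -2.0]
for _ in range(50):
    w = sam_step(w, grad_quadratic)
```

In the PyTorch usage above, the closure plays the role of `grad_fn`: each call recomputes the loss and gradients at whatever weights the optimizer has currently set.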