Welcome to SAM PyTorch’s documentation!
Install
Stable release
pip3 install sam-pytorch
Latest code
pip3 install git+https://github.com/tourdeml/sam
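To confirm that the install succeeded and see which version pip picked up, the standard library's `importlib.metadata` (Python 3.8+) can be queried. This is a minimal sketch: the helper name `installed_version` is hypothetical, the distribution name `sam-pytorch` comes from the stable-release command, and an install from the Git repository is assumed to register under the same name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Prints a version string, or None if sam-pytorch is not installed.
print(installed_version("sam-pytorch"))
```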
Sample usage
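import torch
from torchvision.models import resnet18
from sam import SAM  # module name assumed for the sam-pytorch package

model = resnet18()
# Base optimizer that SAM wraps
optim = torch.optim.SGD(model.parameters(), 1e-3)
optim = SAM(model.parameters(), optim)

def closure():
    # Clear old gradients, recompute the loss, and backpropagate
    optim.zero_grad()
    loss = model(torch.randn(1, 3, 64, 64)).sum()
    loss.backward()
    return loss

optim.step(closure)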
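For intuition about what a sharpness-aware minimization (SAM) step does, here is a framework-free sketch of one update with plain gradient descent as the base optimizer. It is not this library's implementation: the function name `sam_step`, the learning rate, the neighborhood radius `rho=0.05`, and the use of plain Python lists are assumptions for illustration. The gradient is evaluated twice, once at the current weights and once at the perturbed weights, which is presumably why the wrapper's `step` takes a closure that can recompute the loss.

```python
import math
from typing import Callable, List

Vector = List[float]


def sam_step(
    w: Vector,
    grad_fn: Callable[[Vector], Vector],
    lr: float = 0.1,
    rho: float = 0.05,
) -> Vector:
    """One SAM update using plain gradient descent as the base optimizer (sketch)."""
    # 1) Gradient at the current weights.
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        # Zero gradient: no ascent direction, so the weights are left unchanged.
        return [wi - lr * gi for wi, gi in zip(w, g)]
    # 2) Move to the approximate worst-case point inside an L2 ball of radius rho.
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    # 3) Gradient at the perturbed weights.
    g_adv = grad_fn(w_adv)
    # 4) Apply the base step to the ORIGINAL weights using the perturbed gradient.
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


# Toy loss f(w) = w0**2 + w1**2, whose gradient is 2 * w.
def grad_quadratic(w: Vector) -> Vector:
    return [2.0 * wi for wi in w]


w = sam_step([3.0, 4.0], grad_quadratic, lr=0.1, rho=0.05)
print(w)  # approximately [2.394, 3.192]; plain gradient descent would give [2.4, 3.2]
```

In a real PyTorch setup the same four steps act on every parameter tensor, and step 4 is delegated to the wrapped optimizer (SGD in the sample above) instead of a hand-written update.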
API reference