!pip install git+https://github.com/KellerJordan/Muon   # reference implementation; a single-device variant is re-implemented below
import math, torch
from torch import nn, Tensor
from torch.utils.data import DataLoader, TensorDataset
# ---------- Newton–Schulz orthogonaliser ----------
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5) -> Tensor:
    """Approximately orthogonalise G (i.e. compute U V^T from its SVD)
    with a quintic Newton-Schulz iteration, run in bfloat16 for speed."""
    a, b, c = 3.4445, -4.7750, 2.0315        # tuned quintic coefficients
    X = G.to(torch.bfloat16)
    if X.size(-2) > X.size(-1):              # iterate on the wide orientation
        X = X.mT
    X /= X.norm(dim=(-2, -1), keepdim=True).clamp(min=1e-7)  # ensure top singular value <= 1
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    return (X.mT if G.size(-2) > G.size(-1) else X).to(G.dtype)
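
# Quick sanity check (my addition, not part of the original recipe): after
# orthogonalisation every singular value should sit near 1, so the printed
# min/max should loosely bracket 1.
_G = torch.randn(64, 32)
_sv = torch.linalg.svdvals(zeropower_via_newtonschulz5(_G).float())
print(f"singular values after NS5: min={_sv.min():.2f}, max={_sv.max():.2f}")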

# ---------- single-device Muon ----------
class SimpleMuon(torch.optim.Optimizer):
    def __init__(self, params, lr=0.02, momentum=0.95,
                 weight_decay=0.01, nesterov=True, ns_steps=5):
        super().__init__(params, dict(lr=lr, momentum=momentum,
                                      weight_decay=weight_decay,
                                      nesterov=nesterov, ns_steps=ns_steps))

    @torch.no_grad()
    def step(self, closure=None):
        if closure is not None:
            with torch.enable_grad():
                closure()
        for g in self.param_groups:
            lr, mom, wd, nest, k = (g[key] for key in
                                    ("lr", "momentum", "weight_decay",
                                     "nesterov", "ns_steps"))
            for p in g["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.add(p, alpha=wd) if wd else p.grad
                buf = self.state.setdefault(p, {}).setdefault(
                    "momentum_buffer", torch.zeros_like(p))
                buf.mul_(mom).add_(grad)
                d_p = grad.add(buf, alpha=mom) if nest else buf
                if p.ndim == 4:            # conv filters: flatten to 2-D first
                    flat = d_p.view(p.size(0), -1)
                    d_p = zeropower_via_newtonschulz5(flat, k).view_as(p)
                elif p.ndim >= 2:          # weight matrices: orthogonalise directly
                    d_p = zeropower_via_newtonschulz5(d_p, k)
                p.add_(d_p, alpha=-lr)
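
# Smoke test (my addition): one Muon step on a throwaway 2-D weight, just to
# confirm the optimiser runs end to end.
_lin = nn.Linear(8, 8, bias=False)
_opt = SimpleMuon(_lin.parameters())
_lin(torch.randn(4, 8)).sum().backward()
_opt.step()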

# ---------- synthetic binary-classification data ----------
torch.manual_seed(42)
N, D = 5_000, 100
true_w = torch.randn(D)
X = torch.randn(N, D)
y = torch.bernoulli(torch.sigmoid(X @ true_w)).float()   # labels from a logistic model
loader = DataLoader(TensorDataset(X, y), batch_size=128, shuffle=True)
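
# Reference point (my addition): accuracy of the generating weights themselves,
# an upper bound the learned models should approach.
with torch.no_grad():
    bayes_acc = ((torch.sigmoid(X @ true_w) > 0.5) == y.bool()).float().mean()
print(f"accuracy of true_w: {bayes_acc:.3f}")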

# ---------- logistic-regression model ----------
class LogReg(nn.Module):
    def __init__(self, dim):
        super().__init__()
        s = math.isqrt(dim)
        assert s * s == dim, "dim must be a perfect square"
        self.W = nn.Parameter(torch.randn(s, s) * 0.01)   # 2-D parameter, so it takes the Muon path
        self.b = nn.Parameter(torch.zeros(()))
    def forward(self, x): return torch.sigmoid(x @ self.W.flatten() + self.b)
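
# Shape check (my addition): the forward pass should return one probability
# per input row.
_probe = LogReg(D)
assert _probe(X[:3]).shape == (3,)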

# ---------- training helper (now takes *list* of optimisers) ----------
def train(model, opts, epochs=15):
    loss_fn = nn.BCELoss()          # BCE on probabilities; BCEWithLogitsLoss on raw logits would be more numerically stable
    for ep in range(1, epochs + 1):
        for xb, yb in loader:
            loss = loss_fn(model(xb).squeeze(), yb)
            for o in opts: o.zero_grad(set_to_none=True)
            loss.backward()
            for o in opts: o.step()
        if ep % 5 == 0:
            with torch.no_grad():
                acc = ((model(X).squeeze() > 0.5) == y).float().mean()
            print(f"epoch {ep:2d} | loss={loss.item():.4f} | acc={acc:.3f}")

# ---------- run both experiments ----------
init = LogReg(D).state_dict()   # shared initial weights so both runs start identically

print("===> SimpleMuon + AdamW(scalar)")
mu_model = LogReg(D); mu_model.load_state_dict(init)
mu_opt   = SimpleMuon([mu_model.W])
sc_opt   = torch.optim.AdamW([mu_model.b], lr=3e-4, betas=(0.9, 0.95),
                             weight_decay=0.01)           # AdamW's default decay
train(mu_model, [mu_opt, sc_opt])

print("\n===> AdamW only")
ad_model = LogReg(D); ad_model.load_state_dict(init)
ad_opt   = torch.optim.AdamW(ad_model.parameters(), lr=3e-4, betas=(0.9, 0.95),
                             weight_decay=0.01)
train(ad_model, [ad_opt])
===> SimpleMuon + AdamW(scalar)
epoch  5 | loss=0.0364 | acc=0.946
epoch 10 | loss=0.1867 | acc=0.943
epoch 15 | loss=0.2646 | acc=0.941

===> AdamW only
epoch  5 | loss=0.6231 | acc=0.874
epoch 10 | loss=0.5356 | acc=0.903
epoch 15 | loss=0.4566 | acc=0.919