import math

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset


# ---------- Newton–Schulz orthogonaliser ----------
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5) -> Tensor:
    # Quintic Newton–Schulz iteration that pushes the singular values of G
    # towards 1, giving an approximately orthogonalised update.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.to(torch.bfloat16)
    if G.size(-2) > G.size(-1):          # work on the wide orientation
        X = X.mT
    X /= X.norm(dim=(-2, -1), keepdim=True).clamp(min=1e-7)
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    if G.size(-2) > G.size(-1):          # undo the transpose
        X = X.mT
    return X.to(G.dtype)


# ---------- single-device Muon ----------
class SimpleMuon(torch.optim.Optimizer):
    def __init__(self, params, lr=0.02, momentum=0.95, weight_decay=0.01,
                 nesterov=True, ns_steps=5):
        super().__init__(params, dict(lr=lr, momentum=momentum,
                                      weight_decay=weight_decay,
                                      nesterov=nesterov, ns_steps=ns_steps))

    @torch.no_grad()
    def step(self, closure=None):
        if closure is not None:
            with torch.enable_grad():
                closure()
        for g in self.param_groups:
            lr, mom, wd, nest, k = (g[key] for key in
                                    ("lr", "momentum", "weight_decay",
                                     "nesterov", "ns_steps"))
            for p in g["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.add(p, alpha=wd) if wd else p.grad
                buf = self.state.setdefault(p, {}).setdefault(
                    "momentum_buffer", torch.zeros_like(p))
                buf.mul_(mom).add_(grad)
                d_p = grad.add(buf, alpha=mom) if nest else buf
                if p.ndim == 4:            # conv filters: flatten, orthogonalise, reshape
                    flat = d_p.view(p.size(0), -1)
                    d_p = zeropower_via_newtonschulz5(flat, k).view_as(p)
                elif p.ndim >= 2:          # ordinary weight matrices
                    d_p = zeropower_via_newtonschulz5(d_p, k)
                p.add_(d_p, alpha=-lr)


# ---------- synthetic binary-classification data ----------
torch.manual_seed(42)
N, D = 5_000, 100
true_w = torch.randn(D)
X = torch.randn(N, D)
y = torch.bernoulli(torch.sigmoid(X @ true_w)).float()   # labels drawn from a logistic model
loader = DataLoader(TensorDataset(X, y), batch_size=128, shuffle=True)


# ---------- logistic-regression model ----------
class LogReg(nn.Module):
    def __init__(self, dim):
        super().__init__()
        s = math.isqrt(dim)
        assert s * s == dim
        self.W = nn.Parameter(torch.randn(s, s) * 0.01)  # 2-D parameter → Muon path
        self.b = nn.Parameter(torch.zeros(()))

    def forward(self, x):
        return torch.sigmoid(x @ self.W.flatten() + self.b)


# ---------- training helper (takes a *list* of optimisers) ----------
def train(model, opts, epochs=15):
    loss_fn = nn.BCELoss()   # classic but stable for sigmoid outputs
    for ep in range(1, epochs + 1):
        for xb, yb in loader:
            loss = loss_fn(model(xb).squeeze(), yb)
            for o in opts:
                o.zero_grad(set_to_none=True)
            loss.backward()
            for o in opts:
                o.step()
        if ep % 5 == 0:
            with torch.no_grad():
                acc = ((model(X).squeeze() > 0.5) == y).float().mean()
            print(f"epoch {ep:2d} | loss={loss.item():.4f} | acc={acc:.3f}")


# ---------- run both experiments ----------
init = LogReg(D).state_dict()

print("===> SimpleMuon + AdamW(scalar)")
mu_model = LogReg(D); mu_model.load_state_dict(init)
mu_opt = SimpleMuon([mu_model.W])
sc_opt = torch.optim.AdamW([mu_model.b], lr=3e-4, betas=(0.9, 0.95),
                           weight_decay=0.01)   # AdamW handles the scalar bias
train(mu_model, [mu_opt, sc_opt])

print("\n===> AdamW only")
ad_model = LogReg(D); ad_model.load_state_dict(init)
ad_opt = torch.optim.AdamW(ad_model.parameters(), lr=3e-4, betas=(0.9, 0.95),
                           weight_decay=0.01)
train(ad_model, [ad_opt])
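

# ---------- optional sanity check (added illustration, not part of the experiments above) ----------
# A minimal sketch showing what the Newton–Schulz orthogonaliser does: it pushes
# the singular values of its input towards 1, so the returned matrix behaves like
# an approximate orthogonal factor. The bf16 quintic iteration is deliberately a
# loose approximation, so expect values clustered near 1 rather than exactly 1.
# G_demo / Q_demo are names introduced only for this demonstration.
G_demo = torch.randn(10, 10)
Q_demo = zeropower_via_newtonschulz5(G_demo, steps=5)
print("\n===> Newton–Schulz sanity check")
print("input  singular values:", torch.linalg.svdvals(G_demo).round(decimals=2))
print("output singular values:", torch.linalg.svdvals(Q_demo).round(decimals=2))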