BACK_TO_INDEX

OPENAI CHALLENGE // ACTIVE SUBMISSION

PARAMETER_GOLF

Working on a SOTA submission for OpenAI's Model Craft Challenge, a competition to train the best language model that fits in a 16MB compressed artifact and trains in under 10 minutes on 8xH100s. Participants fork the official repo and compete on bits-per-byte over the FineWeb validation set. The base architecture is a depth-recurrent GPT with U-Net skip connections, SmearGate, and the Muon optimizer.

PYTHON PYTORCH OPENAI CHALLENGE 8xH100 MODEL COMPRESSION

TECH_STACK

Python PyTorch CUDA Flash Attention MLX SentencePiece DDP torch.compile RunPod

KEY_FEATURES

Depth-Recurrent Transformer

Reuses a small set of unique blocks in a loop — 3 blocks looped 3x = 9 effective layers without adding parameters. U-Net skip connections with learned per-dimension weights.
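The block-reuse idea can be sketched in a few lines. This is an illustrative toy (class and parameter names are ours, not the submission's): a handful of unique blocks applied repeatedly in a loop, so effective depth grows without any new parameters.

```python
import torch
import torch.nn as nn

class RecurrentTrunk(nn.Module):
    """Toy sketch of depth recurrence: `num_blocks` unique parameter
    sets, applied `loops` times each, for num_blocks * loops effective
    layers."""
    def __init__(self, dim=64, num_blocks=3, loops=3):
        super().__init__()
        # only `num_blocks` parameter sets exist, regardless of loop count
        self.blocks = nn.ModuleList(
            nn.Linear(dim, dim) for _ in range(num_blocks)
        )
        self.loops = loops

    def forward(self, x):
        # 3 blocks looped 3x = 9 effective layers, 3 layers of parameters
        for _ in range(self.loops):
            for block in self.blocks:
                x = x + torch.relu(block(x))  # residual around each block
        return x
```

The parameter count depends only on `num_blocks`; raising `loops` buys depth at zero cost against the 16MB budget, paying only in compute.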

Muon Optimizer + Polar Express

Custom optimizer that orthogonalizes gradients via Newton-Schulz iterations with optimized polynomial coefficients, plus cautious weight decay.

Multi-Precision Quantization

GPTQ-lite post-training quantization supporting int8/int6/int5 with per-row calibration across multiple clip percentiles. QAT during final training phase.

SmearGate + BigramHash

Learned gate blending adjacent token embeddings for cheap bigram context before attention, plus hash-based bigram lookup for local co-occurrence signals.

SOURCE_CODE

train_gpt.py DEPTH-RECURRENT GPT
def _trunk(self, input_ids, bigram_ids=None):
    x = self.tok_emb(input_ids)

    # SmearGate: blend current/previous token embeddings
    if self.smeargate:
        gate = torch.sigmoid(self.smear_gate).to(dtype=x.dtype)
        x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))
        x = (1 - gate) * x + gate * x_prev

    # BigramHash injection
    if bigram_ids is not None:
        x = x + self.bigram_proj(self.bigram_embed(bigram_ids))

    x = F.rms_norm(x, (x.size(-1),))
    x0 = x
    num_enc = self.num_layers // 2
    skips: list[Tensor] = []

    # U-Net: encoder stores, decoder consumes skip connections
    for block_idx, block in enumerate(self.blocks):
        x = block(x, x0)
        if block_idx < num_enc:
            skips.append(x)
        elif skips:
            w = self.skip_weights[len(skips)-1]
            x = x + w.to(dtype=x.dtype)[None,None,:] * skips.pop()

    return self.final_norm(x)
train_gpt.py POLAR EXPRESS ORTHOGONALIZATION
_POLAR_EXPRESS_COEFFS = [
    (8.157, -22.483, 15.879),
    (4.043, -2.809,  0.500),
    (3.892, -2.772,  0.506),
    (3.286, -2.368,  0.464),
    (2.347, -1.710,  0.423),
]

def zeropower_via_newtonschulz5(G):
    """Orthogonalize gradient via Newton-Schulz iterations, one pass
    per Polar Express coefficient tuple"""
    X = G.bfloat16()
    X /= (X.norm() + 1e-7) * 1.02

    if G.size(0) > G.size(1):
        for a, b, c in _POLAR_EXPRESS_COEFFS:
            A = X.T @ X
            B = b * A + c * A @ A
            X = a * X + X @ B
    else:
        for a, b, c in _POLAR_EXPRESS_COEFFS:
            A = X @ X.T
            B = b * A + c * A @ A
            X = a * X + B @ X

    return X
train_gpt.py MULTI-PRECISION QUANTIZATION
def quantize_float_tensor(t, bits=8):
    """GPTQ-lite: try multiple clip percentiles, pick best per row"""
    max_val = (1 << (bits - 1)) - 1
    candidates = [0.999, 0.9995, 0.9999, 0.99999, 1.0]

    best_q = torch.zeros_like(t)
    best_scale = torch.zeros(t.shape[0], dtype=t.dtype)
    best_mse = torch.full((t.shape[0],), float('inf'))

    for pct in candidates:
        ca = torch.quantile(t.abs(), pct, dim=1)
        s = (ca / max_val).clamp_min(1e-10)
        clipped = torch.clamp(t, -ca[:, None], ca[:, None])
        q = torch.round(clipped / s[:, None]).clamp(-max_val, max_val)
        mse = ((t - q * s[:, None]) ** 2).mean(dim=1)

        # keep the best clip percentile independently for each row
        improved = mse < best_mse
        best_q[improved] = q[improved]
        best_scale[improved] = s[improved]
        best_mse[improved] = mse[improved]

    return best_q.to(torch.int8), best_scale.to(torch.float16)

ARCHITECTURE

Model Design

  • Depth-recurrent block reuse
  • U-Net skip connections
  • SmearGate token blending
  • BigramHash embeddings
  • Factored embedding matrix
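A factored embedding matrix is one of the cheapest parameter savings on the list. A minimal sketch (names and shapes are ours, not the repo's): store a rank-r factorization instead of the full vocab x dim table.

```python
import torch
import torch.nn as nn

class FactoredEmbedding(nn.Module):
    """Sketch of a factored embedding: (vocab x rank)(rank x dim)
    replaces a full vocab x dim table, cutting parameters when
    rank << dim."""
    def __init__(self, vocab_size, dim, rank):
        super().__init__()
        self.low = nn.Embedding(vocab_size, rank)   # vocab_size x rank
        self.up = nn.Linear(rank, dim, bias=False)  # rank x dim

    def forward(self, ids):
        return self.up(self.low(ids))
```

For example, at vocab 50k, dim 768, rank 128, this stores about 6.5M parameters instead of 38.4M for the full table.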

Training

  • Muon optimizer + Polar Express
  • WSD-S learning rate schedule
  • EMA + LAWA weight averaging
  • QAT in final phase
  • 8xH100 DDP via torchrun
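The warmup-stable-decay shape behind the WSD-S schedule can be sketched as a plain function (fractions and names below are illustrative defaults, not the submission's values): linear warmup, a long flat plateau, then a linear cooldown to zero.

```python
def wsd_lr(step, total_steps, peak_lr,
           warmup_frac=0.02, decay_frac=0.2):
    """Illustrative warmup-stable-decay LR schedule: ramp up, hold
    flat, then decay linearly over the final fraction of training."""
    warmup = int(total_steps * warmup_frac)
    decay_start = int(total_steps * (1 - decay_frac))
    if step < warmup:
        return peak_lr * (step + 1) / warmup   # linear warmup
    if step < decay_start:
        return peak_lr                          # stable plateau
    remaining = total_steps - step
    return peak_lr * remaining / (total_steps - decay_start)  # cooldown
```

The appeal under a 10-minute wall clock is that the plateau can be cut short at any point and only the short decay phase must be re-run to land a usable checkpoint.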

Compression

  • int8/int6/int5 quantization
  • Per-row GPTQ-lite calibration
  • Tied embeddings
  • zlib/zstandard compression
  • 16MB budget calculator
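The budget check at the end of the pipeline reduces to compressing the serialized weight bytes and comparing against the cap. A minimal sketch using zlib (this is our illustration, not the repo's calculator, which may use zstandard):

```python
import zlib

BUDGET_BYTES = 16 * 1024 * 1024  # 16 MiB artifact limit

def budget_report(blobs, level=9):
    """Sketch of a budget calculator: zlib-compress the serialized
    weight blobs and report bytes used and fraction of the cap."""
    payload = zlib.compress(b"".join(blobs), level)
    used = len(payload)
    return used, used / BUDGET_BYTES
```

Quantizing to int8/int6/int5 before this step matters twice over: the raw bytes shrink, and low-bit rows compress better because their value distribution is narrower.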