---
title: "Implementing Grace: A PyTorch Case Study in Dual-Stream Dysfluency Models"
author: "Rantideb Howlader"
date: "2026-06-22T00:00:00.000Z"
canonical_url: "https://www.ranti.dev/blog/implementing-grace"
license: "CC-BY-4.0"
---


In [The Topography of Hesitation](/blog/topography-of-hesitation) I argued that autoregressive speech models are mathematically forced to erase stuttered speech, and I proposed a dual-stream architecture to fix it: a standard semantic stream joined to a continuous-time Neural ODE "effort stream" through an asymmetric cross-attention block I called `GraceJoin`. That post carried the mathematics and the philosophy. This post carries the code. If you have not read it, the short version is in the next section, but the argument for _why_ this design exists lives there, not here.

Everything below is complete. These are the actual modules, with imports, shapes documented at every boundary, error handling where it earns its place, and the bugs I hit fixed in the version you are reading rather than mentioned in passing. The post is long because the code is whole. Target setup is a single A10G or better, PyTorch 2.x, and `torchdiffeq`. By the end you will have the data pipeline, the trainable model, the loss that rewards keeping the stutter, the full training script with its ODE-specific traps, and a working Model Context Protocol server that hands the hesitation topology to an LLM agent so it waits instead of interrupting.

## 0. The Contract We Are Implementing

The theory post fixed three requirements. They translate directly into engineering constraints, so I restate them as the spec:

1. **Hesitation must be encoded, not resolved.** There must be a tensor in the system whose job is to carry "held, unresolved, effortful" forward in time. That tensor is the ODE latent `z(t)` and its decoded effort vector `e(t) = [d, r, phi, tau, s]`: duration, recursion rate, held phoneme, tension, silence topology.
2. **Time must be continuous.** No fixed frame grid inside Stream B. The latent evolves over real timestamps via an adaptive ODE solver, so a 3-second block and a 0.4-second block are different trajectories, not different counts of identical cells. One honest engineering note up front: feature _extraction_ in this implementation lands on a regular 10 millisecond grid, because batched GPU feature extraction wants regularity. The continuity is real and it lives in the solver, which evaluates the field at whatever intermediate real-valued times its error controller demands, densely inside stiff dysfluent regions and sparsely elsewhere. The grid feeds the interpolant; the solver ignores the grid.
3. **Effort must be protected from semantics.** Stream A (the transcriber) may read Stream B. Gradients from Stream A must never flow into Stream B. One `detach()` enforces this. Remove it and the model relearns erasure within a few epochs. I verified this the unpleasant way, and we will look at the actual failure curve in Section 5.

Components, in build order:

| Component                    | Role                                         | Section |
| ---------------------------- | -------------------------------------------- | ------- |
| `dataset.py`                 | Feature extraction, effort targets, batching | 2       |
| `effort_stream.py`           | Neural ODE latent, adjoint training          | 3       |
| `grace_join.py` + `model.py` | Asymmetric join, full dual-stream model      | 4       |
| `losses.py`                  | The loss that rewards preserving the stutter | 5       |
| `train.py`                   | Complete training script, traps included     | 6       |
| `topology.py`                | Serializing e(t) for consumers               | 7       |
| `mcp_server.py` + agent gate | Handing the topology to agents               | 8       |

## 1. Environment

Nothing exotic. The one dependency that matters is [torchdiffeq](https://github.com/rtqichen/torchdiffeq?utm_source=ranti.dev), the reference implementation of adjoint-method Neural ODE solvers from the original paper's authors. Pin it.

```bash
# CUDA 12.x assumed. bf16-capable GPU strongly recommended (Section 6).
pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
pip install torchdiffeq==0.2.4
pip install transformers==4.46.0      # Whisper encoder/decoder for Stream A
pip install soundfile praat-parselmouth  # audio IO + articulatory proxies
pip install "mcp[cli]"                # Model Context Protocol SDK (Section 8)
```

Repo layout, which the imports below assume:

```text
grace/
  data/
    manifest.jsonl            # one JSON object per clip (Section 2.1)
    audio/                    # 16 kHz mono wav files
  grace/
    __init__.py
    dataset.py                # Section 2
    effort_stream.py          # Section 3
    grace_join.py             # Section 4.1
    model.py                  # Section 4.2
    losses.py                 # Section 5
    train.py                  # Section 6
    topology.py               # Section 7
    mcp_server.py             # Section 8
```

## 2. The Data Pipeline: From Annotated Audio to Effort Targets

This is the part nobody writes about, and it is where most of my engineering time went. A Neural ODE consumes a _function of continuous time_, `a(t)`, not a tensor of frames. Audio arrives as a tensor of frames. The bridge is an interpolant, and before the interpolant there is the unglamorous work: extracting features, deriving the five effort targets from span annotations, computing the tension proxy, and tokenizing verbatim transcripts.

### 2.1 Source data and the manifest format

For dysfluency annotations I train on [SEP-28k](https://github.com/apple/ml-stuttering-events-dataset?utm_source=ranti.dev) plus FluencyBank, which label clips with block, prolongation, sound repetition, word repetition, and interjection events, each with a span. I normalize everything into one manifest line per clip:

```json
{
  "audio": "audio/clip_00417.wav",
  "text": "I w-w-w-want to go to the s... store",
  "events": [
    { "type": "rep", "start": 0.31, "end": 1.02, "phoneme_id": 33, "rep_rate": 4.2 },
    { "type": "block", "start": 2.85, "end": 4.1, "phoneme_id": 29, "rep_rate": 0.0 }
  ]
}
```

Two fields are non-negotiable. `text` must be _verbatim_, with repetitions written out, not the cleaned text most ASR corpora ship. Train Stream A against cleaned references and you are teaching the semantic stream to erase while asking Stream B to preserve; the two objectives fight, and the bigger stream wins. And `events` must carry real spans in seconds, because every effort target below is derived from them.
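
Because every downstream target depends on those two fields, it is worth rejecting bad rows before any audio is loaded. The following is a minimal sketch of such a check, not part of the repo: `validate_manifest_row` and `REQUIRED_EVENT_KEYS` are hypothetical names, and the rules (non-empty text, `0 <= start < end`, spans inside the 30-second window) are my assumptions about what a usable row looks like.

```python
import json

REQUIRED_EVENT_KEYS = ("type", "start", "end")  # assumed minimum per event


def validate_manifest_row(line: str, max_clip_s: float = 30.0) -> list[str]:
    """Return human-readable problems; an empty list means the row is usable."""
    problems: list[str] = []
    row = json.loads(line)
    if not row.get("audio"):
        problems.append("missing audio path")
    if not str(row.get("text", "")).strip():
        problems.append("missing verbatim text")
    events = row.get("events")
    if not isinstance(events, list):
        problems.append("events must be a list")
        return problems
    for i, ev in enumerate(events):
        missing = [k for k in REQUIRED_EVENT_KEYS if k not in ev]
        if missing:
            problems.append(f"event {i}: missing {missing}")
            continue
        start, end = float(ev["start"]), float(ev["end"])
        if not 0.0 <= start < end:
            problems.append(f"event {i}: bad span {start}..{end}")
        elif end > max_clip_s:
            problems.append(f"event {i}: ends after the {max_clip_s:g} s window")
    return problems


good = ('{"audio": "audio/clip_00417.wav", "text": "I w-w-w-want", '
        '"events": [{"type": "rep", "start": 0.31, "end": 1.02}]}')
bad = ('{"audio": "audio/x.wav", "text": "", '
       '"events": [{"type": "block", "start": 4.1, "end": 2.85}]}')
print(validate_manifest_row(good))  # []
print(validate_manifest_row(bad))   # two problems: empty text, inverted span
```

A check like this cannot tell whether `text` is truly verbatim; that still needs a human or a heuristic on the transcript itself.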

### 2.2 The complete dataset module

```python
# grace/dataset.py
"""Data pipeline for the dual-stream dysfluency model.

Produces, per clip:
  - ode_feats   [T, 80]   log-mel features on a regular 10 ms grid (Stream B input)
  - t           [T]       real timestamps in seconds for the ODE solver
  - e_target    [T, 5]    effort targets [d, r, phi, tau, s]
  - event_mask  [T]       True inside annotated dysfluency spans
  - input_features [80, 3000]  Whisper-format mel (Stream A input)
  - labels      [L]       tokenized VERBATIM transcript

Shapes are documented at every function boundary because shape bugs in
this file surface three modules away as solver explosions.
"""
from __future__ import annotations

import json
import math
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset
from transformers import WhisperFeatureExtractor, WhisperTokenizer

SR = 16_000
HOP_S = 0.010                # 10 ms feature hop
N_MELS = 80
N_PHONEME_CLASSES = 40       # collapsed phoneme inventory
SIL_NOT, SIL_HELD, SIL_TERMINAL = 0, 1, 2
SILENCE_ENERGY_QUANTILE = 0.15
MAX_CLIP_S = 30.0            # Whisper's native window


# --------------------------------------------------------------------------
# Tension proxy
# --------------------------------------------------------------------------
def tension_proxy(wav: torch.Tensor, n_frames: int) -> torch.Tensor:
    """Per-frame articulatory tension proxy in [0, 1].  Returns [T].

    True tension needs articulatory sensors nobody has at scale, so we
    proxy it with three acoustic correlates of strained phonation:
    jitter (pitch-period instability), shimmer (amplitude instability),
    and spectral flatness. Parselmouth (Praat bindings) supplies the
    first two; flatness is computed directly. Crude, documented,
    replaceable: swap this one function when better ground truth exists.
    """
    try:
        import parselmouth
        from parselmouth.praat import call
        snd = parselmouth.Sound(wav.numpy(), sampling_frequency=SR)
        pp = call(snd, "To PointProcess (periodic, cc)", 75, 600)
        jitter = call(pp, "Get jitter (local)", 0, 0, 1e-4, 0.02, 1.3)
        shimmer = call([snd, pp], "Get shimmer (local)",
                       0, 0, 1e-4, 0.02, 1.3, 1.6)
        jitter = 0.0 if math.isnan(jitter) else min(jitter / 0.05, 1.0)
        shimmer = 0.0 if math.isnan(shimmer) else min(shimmer / 0.20, 1.0)
    except Exception:
        # Unvoiced or too-short clips make Praat throw. A zero proxy is
        # honest there: no phonation, no measurable phonatory tension.
        jitter, shimmer = 0.0, 0.0

    spec = torch.stft(wav, n_fft=512, hop_length=int(HOP_S * SR),
                      window=torch.hann_window(512),
                      return_complex=True).abs().clamp_min(1e-8)  # [F, T']
    flatness = (spec.log().mean(0).exp() / spec.mean(0))          # [T']
    flatness = flatness.clamp(0, 1)
    flatness = F.interpolate(flatness[None, None], size=n_frames,
                             mode="linear", align_corners=False)[0, 0]
    # Per-frame flatness modulated by clip-level phonatory instability.
    return (0.5 * flatness + 0.5 * (jitter + shimmer) / 2).clamp(0, 1)


# --------------------------------------------------------------------------
# Effort targets
# --------------------------------------------------------------------------
def build_effort_targets(events: list[dict], t: torch.Tensor,
                         tau: torch.Tensor,
                         energy: torch.Tensor) -> tuple[torch.Tensor,
                                                        torch.Tensor]:
    """Derive e_target(t) = [d, r, phi, tau, s] from span annotations.

    Args:
      events: manifest event dicts with start/end seconds
      t:      [T] frame timestamps in seconds
      tau:    [T] tension proxy
      energy: [T] per-frame RMS energy
    Returns:
      e_target   [T, 5]
      event_mask [T] bool, True inside any annotated span
    """
    T = len(t)
    e = torch.zeros(T, 5)
    e[:, 3] = tau
    sil = energy < energy.quantile(SILENCE_ENERGY_QUANTILE)
    e[:, 4] = torch.where(sil,
                          torch.tensor(float(SIL_TERMINAL)),
                          torch.tensor(float(SIL_NOT)))
    event_mask = torch.zeros(T, dtype=torch.bool)
    for ev in events:
        m = (t >= ev["start"]) & (t <= ev["end"])
        event_mask |= m
        e[m, 0] = t[m] - ev["start"]                  # d: seconds inside event
        e[m, 1] = float(ev.get("rep_rate", 0.0))      # r: repetitions / sec
        e[m, 2] = float(ev.get("phoneme_id", 0))      # phi: class index
        e[m & sil, 4] = float(SIL_HELD)               # s: the LOADED silence
    return e, event_mask
```

The most important line in this file is the last assignment. `e[m & sil, 4] = float(SIL_HELD)` is where the silence of struggle and the silence of finishing become different numbers. Every voice-activity-detection failure described in the theory post traces back to systems in which those two silences are the same number.
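
To make the three-way distinction concrete, here is a dependency-free sketch of the same labeling rule on six toy frames. It is an illustration under simplifying assumptions: `label_silence` is a hypothetical helper, and it uses a fixed energy threshold where the module above uses a per-clip quantile.

```python
SIL_NOT, SIL_HELD, SIL_TERMINAL = 0, 1, 2


def label_silence(energy, t, events, threshold):
    """Per-frame silence class from energy plus annotated spans.

    energy, t: equal-length lists; events: list of (start, end) in
    seconds; threshold: energy below this counts as silent.
    """
    labels = []
    for e, ts in zip(energy, t):
        if e >= threshold:
            labels.append(SIL_NOT)
        elif any(start <= ts <= end for start, end in events):
            labels.append(SIL_HELD)       # quiet, but inside an event
        else:
            labels.append(SIL_TERMINAL)   # quiet and outside any event
    return labels


t = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
energy = [0.9, 0.02, 0.01, 0.8, 0.02, 0.01]
events = [(0.1, 0.2)]                     # one annotated block
print(label_silence(energy, t, events, threshold=0.05))
# [0, 1, 1, 0, 2, 2]
```

Frames 1-2 and frames 4-5 are acoustically identical. Only the annotation separates them, which is why the spans in the manifest have to be real.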

```python
# grace/dataset.py (continued)
# --------------------------------------------------------------------------
# Dataset
# --------------------------------------------------------------------------
@dataclass
class GraceItem:
    ode_feats: torch.Tensor       # [T, 80]
    t: torch.Tensor               # [T]
    e_target: torch.Tensor        # [T, 5]
    event_mask: torch.Tensor      # [T] bool
    input_features: torch.Tensor  # [80, 3000] Whisper mel
    labels: torch.Tensor          # [L] token ids of VERBATIM text


class GraceDataset(Dataset):
    def __init__(self, manifest_path: str, root: str,
                 whisper_id: str = "openai/whisper-small"):
        self.root = Path(root)
        with open(manifest_path) as f:
            self.rows = [json.loads(line) for line in f if line.strip()]
        self.fe = WhisperFeatureExtractor.from_pretrained(whisper_id)
        self.tok = WhisperTokenizer.from_pretrained(
            whisper_id, language="english", task="transcribe")
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=SR, n_fft=512, hop_length=int(HOP_S * SR),
            n_mels=N_MELS)

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, i: int) -> GraceItem:
        row = self.rows[i]
        wav, sr = torchaudio.load(str(self.root / row["audio"]))
        wav = wav.mean(0)                                   # mono [N]
        if sr != SR:
            wav = torchaudio.functional.resample(wav, sr, SR)
        wav = wav[: int(MAX_CLIP_S * SR)]

        # ---- Stream B features on the regular 10 ms grid --------------
        ode_feats = self.mel(wav).clamp_min(1e-8).log().T   # [T, 80]
        T = ode_feats.shape[0]
        t = torch.arange(T, dtype=torch.float32) * HOP_S    # [T] seconds

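        # MelSpectrogram centers its frames, so T = 1 + N // hop and
        # T * hop > N. Zero-pad the tail so the reshape is exact.
        hop = int(HOP_S * SR)
        frame = F.pad(wav, (0, T * hop - wav.shape[0])).reshape(T, hop)
        energy = frame.pow(2).mean(-1).sqrt()               # [T] RMS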
        tau = tension_proxy(wav, T)                         # [T]
        e_target, event_mask = build_effort_targets(
            row["events"], t, tau, energy)

        # ---- Stream A features (Whisper's own 30 s mel format) --------
        input_features = self.fe(
            wav.numpy(), sampling_rate=SR,
            return_tensors="pt").input_features[0]          # [80, 3000]

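        # ---- Verbatim labels. The tokenizer keeps "w-w-w-want". -------
        # The tokenizer prepends <|startoftranscript|>; drop it, because
        # shift_tokens_right in model.py prepends the start token again.
        sot = self.tok.convert_tokens_to_ids("<|startoftranscript|>")
        ids = self.tok(row["text"]).input_ids
        if ids and ids[0] == sot:
            ids = ids[1:]
        labels = torch.tensor(ids, dtype=torch.long)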

        return GraceItem(ode_feats, t, e_target, event_mask,
                         input_features, labels)


def collate(batch: list[GraceItem]) -> dict:
    """Pad to per-batch max length. Returns mask for valid frames.

    All clips in a batch share one regular time grid (10 ms hop), so a
    single t vector of the max length serves the whole batch, and the
    ODE solver integrates once over it. Shorter clips are masked out of
    the loss; their padded feature region is zeros, which the
    interpolant returns harmlessly and the loss never sees.
    """
    T = max(it.ode_feats.shape[0] for it in batch)
    L = max(it.labels.shape[0] for it in batch)

    def padT(x: torch.Tensor, value: float = 0.0) -> torch.Tensor:
        pad = T - x.shape[0]
        if x.dim() == 2:
            return F.pad(x, (0, 0, 0, pad), value=value)
        return F.pad(x, (0, pad), value=value)

    return {
        "ode_feats": torch.stack([padT(it.ode_feats) for it in batch]),
        "t": torch.arange(T, dtype=torch.float32) * HOP_S,
        "e_tgt": torch.stack([padT(it.e_target) for it in batch]),
        "event_mask": torch.stack(
            [padT(it.event_mask.float()).bool() for it in batch]),
        "mask": torch.stack(
            [padT(torch.ones(it.ode_feats.shape[0])).bool()
             for it in batch]),
        "input_features": torch.stack(
            [it.input_features for it in batch]),
        "labels": torch.stack(
            [F.pad(it.labels, (0, L - it.labels.shape[0]), value=-100)
             for it in batch]),
    }
```

One design point to absorb before moving on. The collate gives every clip in the batch the _same_ time vector, which is what lets `torchdiffeq` integrate the whole batch in one solve. The price is that padding frames exist inside the integration span for shorter clips. That is fine: their features are zeros, the latent drifts mildly through them, and the mask removes every padded position from every loss term. The solver wastes a little work on padding; correctness is untouched. Bucketing clips by length into batches recovers most of that waste, and a `BucketSampler` is a standard twenty lines if your length distribution is wide.
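
For reference, a bucketing scheme along those lines can be sketched without any framework code. This is one possible version, not the repo's: `bucket_batches` and `padded_frames` are hypothetical names, and the default bucket size of fifty batches is an arbitrary assumption.

```python
import random


def bucket_batches(lengths, batch_size, bucket_size=None, seed=0):
    """Group clip indices into batches of similar length.

    lengths: frame count per clip. Indices are shuffled, cut into
    buckets, sorted by length inside each bucket, then sliced into
    batches. Batch order is shuffled again so training does not see
    clips in short-to-long order.
    """
    rng = random.Random(seed)
    bucket_size = bucket_size or batch_size * 50
    idx = list(range(len(lengths)))
    rng.shuffle(idx)
    batches = []
    for b in range(0, len(idx), bucket_size):
        bucket = sorted(idx[b:b + bucket_size], key=lambda i: lengths[i])
        batches += [bucket[k:k + batch_size]
                    for k in range(0, len(bucket), batch_size)]
    rng.shuffle(batches)
    return batches


def padded_frames(lengths, batches):
    """Total frames the solver integrates over, padding included."""
    return sum(max(lengths[i] for i in b) * len(b) for b in batches)


rng = random.Random(1)
lengths = [rng.randint(100, 3000) for _ in range(64)]   # toy frame counts
naive = [list(range(k, k + 8)) for k in range(0, 64, 8)]
bucketed = bucket_batches(lengths, batch_size=8)
print(padded_frames(lengths, naive), padded_frames(lengths, bucketed))
```

The returned list of index lists is the shape `torch.utils.data.DataLoader` accepts through its `batch_sampler` argument, alongside the `collate` function above. Smaller buckets trade some padding savings for more randomness between epochs.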

## 3. EffortStream: The Neural ODE Backbone, Complete

The theory post sketched the field. Here is the whole module: the batched interpolant, the field as a proper `nn.Module` so the adjoint solver can see its parameters, solver configuration, the fp32 island, and instrumentation. Two bugs from my own early sketches are fixed in this version and flagged inline, because both are the kind that train silently wrong.

```python
# grace/effort_stream.py
"""Stream B: continuous-time effort latent via a Neural ODE.

The solver evaluates the field at arbitrary real times t between the
feature timestamps. The BatchedInterpolant answers those queries for
the entire batch at once.
"""
from __future__ import annotations

import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint


class BatchedInterpolant:
    """a(t): linear interpolation over a shared regular time grid.

    feats: [B, T, F]  batch of feature sequences
    t:     [T]        shared timestamps (seconds), strictly increasing

    __call__(tq) with scalar tensor tq returns [B, F].

    BUG FIXED HERE: my first version interpolated batch element 0 only
    (feats[0]) and broadcast it across the batch. Every clip in the
    batch received clip 0's acoustics. Loss curves looked plausible.
    Held-silence F1 did not. If you adapt this code, this class is the
    one to unit-test first.
    """
    def __init__(self, feats: torch.Tensor, t: torch.Tensor):
        assert feats.dim() == 3 and t.dim() == 1
        assert feats.shape[1] == t.shape[0]
        self.feats = feats            # [B, T, F]
        self.t = t                    # [T]

    def __call__(self, tq: torch.Tensor) -> torch.Tensor:
        tq = tq.reshape(()).clamp(self.t[0], self.t[-1])
        idx = torch.searchsorted(self.t, tq).clamp(1, len(self.t) - 1)
        t0, t1 = self.t[idx - 1], self.t[idx]
        w = (tq - t0) / (t1 - t0 + 1e-8)              # scalar in [0,1]
        return torch.lerp(self.feats[:, idx - 1, :],
                          self.feats[:, idx, :], w)  # [B, F]


class EffortField(nn.Module):
    """f_theta: the learned vector field governing dz/dt.

    Design notes that matter in practice:
    - Tanh activations keep the field Lipschitz-friendly. Swapping in
      ReLU made the solver step count explode on block-heavy clips:
      sharp field, stiff dynamics.
    - The field sees time explicitly as sin/cos features so it can
      learn duration-dependent dynamics, e.g. tension that builds the
      longer a block holds. BUG FIXED HERE: the time features must be
      built per-batch with an explicit unsqueeze; the obvious
      stack-then-expand produces a shape that broadcasts wrong for
      batch size 1 and crashes for batch size > 1.
    """
    def __init__(self, z_dim: int = 64, a_dim: int = 80,
                 hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim + a_dim + 2, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, z_dim),
        )
        # Near-zero init on the last layer: the latent starts life as a
        # slow integrator instead of a chaotic one. Cut early-epoch
        # solver steps roughly in half in my runs.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, t: torch.Tensor, z: torch.Tensor,
                a: torch.Tensor) -> torch.Tensor:
        # t: scalar tensor | z: [B, z_dim] | a: [B, a_dim]
        B = z.shape[0]
        tfeat = torch.stack([torch.sin(t), torch.cos(t)])      # [2]
        tfeat = tfeat.unsqueeze(0).expand(B, 2).to(z.dtype)    # [B, 2]
        return self.net(torch.cat([z, a, tfeat], dim=-1))      # [B, z_dim]


class ODEFunc(nn.Module):
    """Adapter giving torchdiffeq the (t, z) -> dz/dt signature.

    Holds the interpolant as a plain attribute (set per forward pass)
    and counts function evaluations (NFE) for instrumentation. It is a
    Module, not a closure, so odeint_adjoint can discover the field's
    parameters for the backward pass.
    """
    def __init__(self, field: EffortField):
        super().__init__()
        self.field = field
        self.a_of_t: BatchedInterpolant | None = None
        self.nfe = 0

    def forward(self, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        self.nfe += 1
        a = self.a_of_t(t)                      # [B, a_dim]
        return self.field(t, z, a)


class EffortStream(nn.Module):
    """Integrates z(t) over real time and decodes the effort heads.

    Heads:
      reg: [B, T, 3]  regression for [d, r, tau]
      phi: [B, T, 40] held-phoneme classification
      sil: [B, T, 3]  silence-topology classification
    """
    def __init__(self, z_dim: int = 64, a_dim: int = 80,
                 hidden: int = 256, n_phi: int = 40, n_sil: int = 3,
                 adjoint: bool = True,
                 rtol: float = 1e-4, atol: float = 1e-5):
        super().__init__()
        self.func = ODEFunc(EffortField(z_dim, a_dim, hidden))
        self.z0_enc = nn.Linear(a_dim, z_dim)
        self.head_reg = nn.Linear(z_dim, 3)
        self.head_phi = nn.Linear(z_dim, n_phi)
        self.head_sil = nn.Linear(z_dim, n_sil)
        self.adjoint = adjoint
        self.rtol, self.atol = rtol, atol
        self.last_nfe = 0

    def forward(self, ode_feats: torch.Tensor,
                t_eval: torch.Tensor) -> dict[str, torch.Tensor]:
        # ode_feats: [B, T, 80] | t_eval: [T]
        # fp32 island: adaptive error control under fp16/bf16 chases
        # numerical noise (Section 6). Everything inside the solve is
        # full precision regardless of the surrounding autocast.
        with torch.autocast(device_type=ode_feats.device.type,
                            enabled=False):
            feats32 = ode_feats.float()
            t32 = t_eval.float()
            self.func.a_of_t = BatchedInterpolant(feats32, t32)
            self.func.nfe = 0
            z0 = torch.tanh(self.z0_enc(feats32[:, 0, :]))   # [B, z]

            solver = odeint_adjoint if self.adjoint else odeint
            kwargs = dict(method="dopri5", rtol=self.rtol,
                          atol=self.atol)
            if self.adjoint:
                kwargs["adjoint_params"] = tuple(
                    self.func.field.parameters())
            z_traj = solver(self.func, z0, t32, **kwargs)    # [T, B, z]
            self.last_nfe = self.func.nfe

        z_traj = z_traj.transpose(0, 1)                      # [B, T, z]
        return {
            "z": z_traj,
            "reg": self.head_reg(z_traj),                    # [B, T, 3]
            "phi": self.head_phi(z_traj),                    # [B, T, 40]
            "sil": self.head_sil(z_traj),                    # [B, T, 3]
        }
```
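
The interpolant docstring asks for a unit test, so here is the shape of one. It is a pure-Python stand-in rather than a test of the real class: `interp_batch` is a hypothetical helper that mirrors the same clamp, search, and lerp steps with one scalar feature per frame. The check that matters is the last one: each clip must get its own value back.

```python
import bisect


def interp_batch(feats, t, tq):
    """Linear interpolation of every clip in the batch at real time tq.

    feats: list of B clips, each a list of T scalar features;
    t: T strictly increasing timestamps in seconds.
    """
    tq = min(max(tq, t[0]), t[-1])                       # clamp to the grid
    i = min(max(bisect.bisect_left(t, tq), 1), len(t) - 1)
    w = (tq - t[i - 1]) / (t[i] - t[i - 1])              # weight in [0, 1]
    return [clip[i - 1] + w * (clip[i] - clip[i - 1]) for clip in feats]


t = [0.0, 0.01, 0.02]
feats = [[0.0, 1.0, 2.0],      # clip 0 rises
         [5.0, 5.0, 5.0],      # clip 1 is flat
         [2.0, 0.0, -2.0]]     # clip 2 falls
out = interp_batch(feats, t, 0.015)
print([round(v, 6) for v in out])

# The broadcast bug would make every entry equal to clip 0's value.
assert len({round(v, 6) for v in out}) == len(feats)
```

The same three assertions translate directly to the tensor version: query between grid points, query on a grid point, and confirm that distinct clips yield distinct outputs.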

Three decisions here deserve their reasoning written down.

**Adjoint or not.** `odeint_adjoint` recomputes the trajectory backward instead of storing every intermediate solver state, trading compute for memory. With `z_dim=64` and 30-second clips, plain `odeint` overflowed a 24 GB A10G at batch size 16; adjoint trained the same config in under 11 GB at roughly 1.6x the step time. For this architecture, take the adjoint and the slowdown. The `adjoint_params` argument is not decorative: the adjoint pass needs to know which parameters require gradients, and passing the field's parameters explicitly is the documented, version-stable way to guarantee it.

**Solver and tolerances.** `dopri5` (adaptive Runge-Kutta) with `rtol=1e-4, atol=1e-5` is the boring correct default. The adaptive part is not a luxury: dysfluent regions are exactly where the dynamics get stiff, and the solver naturally spends its steps there. That is the architecture doing what it promised, allocating computation to effort. A fixed-step `rk4` trained 2x faster and quietly butchered block boundaries; its error concentrated precisely on the events we exist to preserve. Measure error _on dysfluent spans_, never on whole-clip averages, or the solver will hide its sins in the fluent 90 percent.
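
The arithmetic behind that warning is simple enough to show on toy numbers. In this sketch (made-up values, hypothetical helper `span_error`), the prediction is perfect on fluent frames and badly wrong inside the one event, and the whole-clip average still looks excellent.

```python
def span_error(pred, target, event_mask):
    """Mean absolute error over the whole clip and on dysfluent frames only."""
    errs = [abs(p - q) for p, q in zip(pred, target)]
    inside = [e for e, m in zip(errs, event_mask) if m]
    overall = sum(errs) / len(errs)
    on_events = sum(inside) / len(inside) if inside else float("nan")
    return overall, on_events


# 100 frames, 10 of them inside an event. The prediction is exact on
# fluent frames and off by 0.5 on every frame inside the event.
target = [0.0] * 90 + [1.0] * 10
pred = [0.0] * 90 + [0.5] * 10
mask = [False] * 90 + [True] * 10
overall, on_events = span_error(pred, target, mask)
print(f"whole clip: {overall:.3f}  on events: {on_events:.3f}")
# whole clip: 0.050  on events: 0.500
```

A tenfold gap between the two numbers is what a 10 percent event rate produces by construction, so report the second number whenever you compare solvers.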

**NFE instrumentation.** `last_nfe`, the number of function evaluations per forward pass, is your single most useful training-health signal. A healthy run hovers in a band; a field going unstable shows up as NFE climbing epochs before the loss diverges. Log it on every step. It is the ODE equivalent of watching gradient norms, and it has paged me earlier than the loss curve every single time.
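
One way to turn "hovers in a band" into an alert is to compare each step's NFE against a slow-moving baseline. This is a sketch under assumptions: `NfeMonitor` is a hypothetical class, and the decay, ratio, and warmup values are placeholders to tune against your own runs, not numbers from mine.

```python
class NfeMonitor:
    """Flags steps whose NFE drifts well above a slow-moving baseline."""

    def __init__(self, decay: float = 0.99, ratio: float = 1.5,
                 warmup: int = 50):
        self.decay, self.ratio, self.warmup = decay, ratio, warmup
        self.baseline = None
        self.steps = 0

    def update(self, nfe: int) -> bool:
        """Record one step's NFE; return True if it looks unhealthy."""
        self.steps += 1
        if self.baseline is None:
            self.baseline = float(nfe)
            return False
        alarm = (self.steps > self.warmup
                 and nfe > self.ratio * self.baseline)
        # Exponential moving average, updated after the comparison so a
        # spike is judged against the history before it.
        self.baseline = self.decay * self.baseline + (1 - self.decay) * nfe
        return alarm


mon = NfeMonitor(warmup=5)
history = [200] * 20 + [210, 260, 330, 420]     # toy NFE trace
alarms = [step for step, nfe in enumerate(history) if mon.update(nfe)]
print(alarms)  # [22, 23]
```

In a training loop the input would be `model.effort.last_nfe` after each forward pass. Note that with the adjoint enabled the counter keeps incrementing during the backward solve, so read it at a consistent point in the step.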

## 4. GraceJoin and the Full Dual-Stream Model

### 4.1 The severed gradient

The join block is short, and its one line of consequence is commented accordingly.

```python
# grace/grace_join.py
"""Asymmetric cross-attention: Stream A queries Stream B.

THE RULE: gradients from the semantic stream must never reach the
effort stream. Detaching z_traj before the projection severs that
edge. Deleting that call reproduces the erasure pathology of standard
models within a few epochs (ablation in Section 5.2). If you touch
this file in a refactor, the detach is load-bearing.
"""
import torch
import torch.nn as nn


class GraceJoin(nn.Module):
    def __init__(self, d_model: int = 768, n_heads: int = 8,
                 effort_dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(effort_dim, d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads,
                                          batch_first=True)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, h_sem: torch.Tensor, z_traj: torch.Tensor,
                z_pad_mask: torch.Tensor | None = None):
        # h_sem:  [B, Ta, d_model]   Whisper encoder states
        # z_traj: [B, Tb, effort_dim] ODE latent trajectory
        # z_pad_mask: [B, Tb] True at PADDED effort positions
        # Detach the INPUT of proj, not its output. Detaching the
        # output would also cut proj off from L_sem and leave it at
        # its random initialization for the whole run.
        kv = self.proj(z_traj.detach())        # <- the severed edge
        ctx, attn_w = self.attn(h_sem, kv, kv,
                                key_padding_mask=z_pad_mask,
                                need_weights=True)
        return self.norm(h_sem + ctx), attn_w
```

Note the `key_padding_mask`. Without it, the semantic stream can attend to padded effort positions. Those positions carry no signal: the latent there has only drifted through zero-valued padding features, and after projection that makes a perfectly attractive "the effort signal says nothing" key. That is a slow leak toward exactly the collapse the detach prevents, arriving through the side door. Mask the padding.
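
What the mask does to the attention weights can be shown for a single query with nothing but `math`. This is an illustrative sketch, not the internals of `nn.MultiheadAttention`: `attention_weights` is a hypothetical helper, the scores are invented, and it assumes at least one key is real.

```python
import math


def attention_weights(scores, pad_mask):
    """Softmax over keys, with padded keys forced to zero weight.

    scores: raw query-key scores for one query;
    pad_mask: True at padded keys (at least one key must be real).
    """
    masked = [-math.inf if pad else s for s, pad in zip(scores, pad_mask)]
    top = max(masked)                       # subtract max for stability
    exps = [math.exp(s - top) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 0.5, 2.0, 2.0]               # last two keys are padding
pad = [False, False, True, True]
unmasked = attention_weights(scores, [False] * 4)
masked = attention_weights(scores, pad)
print([round(w, 3) for w in unmasked])      # padding takes most of the mass
print([round(w, 3) for w in masked])        # padding gets exactly zero
```

In the unmasked case the two padded keys absorb more than half of the attention, which is the leak described above in miniature.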

### 4.2 The model, with the real Whisper decoder

My earlier sketch put an `lm_head` directly on the encoder output, which is not how Whisper transcribes and would never produce text. The real model runs the full encoder-decoder: the join modifies the _encoder states_, and the standard Whisper decoder then cross-attends into those effort-informed states. Teacher forcing uses the canonical shifted decoder inputs.

```python
# grace/model.py
"""Full dual-stream model: Whisper (Stream A) + EffortStream (Stream B),
joined at the encoder states by GraceJoin."""
from __future__ import annotations

import torch
import torch.nn as nn
from transformers import WhisperForConditionalGeneration
from transformers.models.whisper.modeling_whisper import shift_tokens_right

from .effort_stream import EffortStream
from .grace_join import GraceJoin


class GraceModel(nn.Module):
    def __init__(self, whisper_id: str = "openai/whisper-small",
                 z_dim: int = 64, freeze_encoder: bool = True):
        super().__init__()
        self.whisper = WhisperForConditionalGeneration.from_pretrained(
            whisper_id)
        d_model = self.whisper.config.d_model          # 768 for small
        self.effort = EffortStream(z_dim=z_dim)
        self.join = GraceJoin(d_model=d_model, effort_dim=z_dim)
        if freeze_encoder:
            for p in self.whisper.model.encoder.parameters():
                p.requires_grad_(False)

    def forward(self, batch: dict) -> tuple[torch.Tensor, dict,
                                            torch.Tensor]:
        # ---- Stream A: semantic encoding -------------------------------
        enc = self.whisper.model.encoder(
            input_features=batch["input_features"])
        h_sem = enc.last_hidden_state                  # [B, 1500, d]

        # ---- Stream B: continuous-time effort --------------------------
        eff = self.effort(batch["ode_feats"], batch["t"])

        # ---- The join: semantics reads effort, never trains it ---------
        h_joined, attn_w = self.join(
            h_sem, eff["z"], z_pad_mask=~batch["mask"])

        # ---- Standard Whisper decoder over effort-informed states ------
        labels = batch["labels"]
        decoder_input_ids = shift_tokens_right(
            labels.masked_fill(labels == -100,
                               self.whisper.config.pad_token_id),
            self.whisper.config.pad_token_id,
            self.whisper.config.decoder_start_token_id)
        dec = self.whisper.model.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=h_joined)
        logits = self.whisper.proj_out(dec.last_hidden_state)
        return logits, eff, attn_w
```

Architecturally, the join point matters. Joining at the encoder states means every decoder layer's cross-attention sees effort-informed acoustics, so the information is available at every decoding step rather than bolted on at the end. And because the decoder and `proj_out` are stock Whisper, the entire pretrained transcription capability is preserved; we changed what the decoder reads, not how it reads.

## 5. The Loss That Rewards Keeping the Stutter

### 5.1 L_effort, complete

The theory post wrote `L_effort` as a single squared error. The trainable version has three terms, per-component weights, masking, and one weighting trick that mattered more than everything else combined. It also fixes an index bug from my sketch that is worth confessing in public: `e_target` stores `[d, r, phi, tau, s]`, so the regression targets are columns `[0, 1, 3]`. Slicing `[..., :3]`, the obvious thing, silently regresses the _phoneme class index_ as if it were a continuous tension value. The model trains. The loss goes down. The tau head is garbage. Indices, not slices, when the layout is interleaved.
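
The slip is easy to reproduce on a single row. A small sketch with made-up values, using plain lists in place of tensors:

```python
E_LAYOUT = ("d", "r", "phi", "tau", "s")
REG_COLS = (0, 1, 3)                       # d, r, tau

frame = [0.42, 4.2, 33.0, 0.71, 1.0]       # one e_target row, invented values

by_slice = frame[:3]                       # picks d, r, phi  (wrong)
by_index = [frame[i] for i in REG_COLS]    # picks d, r, tau  (right)

print([E_LAYOUT[i] for i in range(3)], by_slice)
print([E_LAYOUT[i] for i in REG_COLS], by_index)
```

With the slice, the third regression target is the class index 33.0 instead of the tension value 0.71. Nothing crashes, because both are floats of the right shape.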

```python
# grace/losses.py
"""Loss functions for the dual-stream model.

e_target layout: [d, r, phi, tau, s]
  d   col 0  regression (seconds inside event)
  r   col 1  regression (repetitions / second)
  phi col 2  classification (40 phoneme classes)
  tau col 3  regression (tension proxy, [0,1])
  s   col 4  classification (3 silence classes)
"""
import torch
import torch.nn.functional as F

REG_COLS = torch.tensor([0, 1, 3])              # d, r, tau. NOT [:3].
REG_W = torch.tensor([1.0, 0.5, 2.0])           # weights for [d, r, tau]
PHI_W, SIL_W = 0.5, 1.0


def l_effort(eff: dict, e_tgt: torch.Tensor, mask: torch.Tensor,
             event_mask: torch.Tensor,
             event_boost: float = 4.0) -> torch.Tensor:
    """The preservation loss.

    eff:        dict from EffortStream: reg [B,T,3], phi [B,T,40],
                sil [B,T,3]
    e_tgt:      [B, T, 5] targets
    mask:       [B, T] valid (non-padded) frames
    event_mask: [B, T] True inside annotated dysfluency spans
    event_boost: loss multiplier inside events. THE key knob.
    """
    # Per-position weights: dysfluent spans count event_boost times
    # more than fluent spans. Without this, the 90 percent fluent
    # majority dominates the gradient and Stream B converges to a
    # smooth average that under-represents exactly the events it
    # exists to carry. With boost=4, held-silence recall went from
    # 0.61 to 0.88 in my runs. This is a class-imbalance problem
    # wearing a continuous-time costume; treat it like one.
    w = torch.where(event_mask, event_boost, 1.0) * mask    # [B, T]

    reg_tgt = e_tgt[..., REG_COLS.to(e_tgt.device)]         # [B, T, 3]
    reg_err = (eff["reg"] - reg_tgt) ** 2
    loss_reg = (reg_err * REG_W.to(reg_err)).sum(-1)        # [B, T]

    loss_phi = F.cross_entropy(
        eff["phi"].flatten(0, 1),
        e_tgt[..., 2].long().flatten(),
        reduction="none").view_as(w)

    loss_sil = F.cross_entropy(
        eff["sil"].flatten(0, 1),
        e_tgt[..., 4].long().flatten(),
        reduction="none").view_as(w)

    per_pos = loss_reg + PHI_W * loss_phi + SIL_W * loss_sil
    return (per_pos * w).sum() / w.sum().clamp(min=1.0)


def total_loss(logits: torch.Tensor, labels: torch.Tensor, eff: dict,
               e_tgt: torch.Tensor, mask: torch.Tensor,
               event_mask: torch.Tensor,
               lam: float = 0.5) -> tuple[torch.Tensor, dict]:
    """L_total = L_sem + lambda * L_effort.

    L_sem trains Stream A and the join. L_effort trains Stream B
    alone. The detach in GraceJoin guarantees the separation.
    """
    l_sem = F.cross_entropy(
        logits.flatten(0, 1), labels.flatten(), ignore_index=-100)
    l_eff = l_effort(eff, e_tgt, mask, event_mask)
    return l_sem + lam * l_eff, {
        "l_sem": float(l_sem.detach()),
        "l_eff": float(l_eff.detach()),
    }


def lam_schedule(step: int, warmup_steps: int = 4000,
                 lam_min: float = 0.1, lam_max: float = 0.5) -> float:
    """Linear warmup for lambda. Starting high destabilizes Stream A
    before the effort targets are learnable; starting at zero lets
    Stream A settle into erasure habits the join then faithfully
    reads. The warmup threads the needle."""
    frac = min(step / max(warmup_steps, 1), 1.0)
    return lam_min + frac * (lam_max - lam_min)


@torch.no_grad()
def z_norm_in_events(eff: dict, event_mask: torch.Tensor) -> float:
    """Mean ||z(t)|| inside annotated dysfluency spans.

    The canary metric for effort-stream collapse (Section 5.2). Total
    loss is dominated by L_sem and will look healthy while Stream B
    quietly dies; this number will not lie to you."""
    z = eff["z"]                                   # [B, T, z_dim]
    m = event_mask.unsqueeze(-1).to(z.dtype)       # [B, T, 1]
    denom = m.sum().clamp(min=1.0)
    return float((z.norm(dim=-1, keepdim=True) * m).sum() / denom)
```
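To see why the event boost moves the result so much, it helps to work the weighting by hand. The sketch below is plain Python and sits outside `grace/losses.py`. The frame counts and per-frame losses are invented for illustration, and `event_share` and `weighted_mean` are hypothetical helper names, not part of the module above.

```python
def event_share(n_fluent: int, n_event: int, boost: float) -> float:
    """Fraction of the total loss weight that lands on event frames."""
    event_w = n_event * boost
    return event_w / (n_fluent + event_w)


def weighted_mean(losses: list[float], in_event: list[bool],
                  boost: float) -> float:
    """Same reduction shape as l_effort on unpadded frames:
    sum(w * loss) / sum(w), with w = boost inside events, 1 outside."""
    weights = [boost if e else 1.0 for e in in_event]
    total_w = max(sum(weights), 1.0)
    return sum(w * l for w, l in zip(weights, losses)) / total_w


# Illustrative clip: 90 fluent frames with small error, 10 event
# frames with large error. These numbers are made up for the example.
losses = [0.1] * 90 + [1.0] * 10
in_event = [False] * 90 + [True] * 10

share_plain = event_share(90, 10, boost=1.0)    # 10 / 100
share_boost = event_share(90, 10, boost=4.0)    # 40 / 130
mean_plain = weighted_mean(losses, in_event, boost=1.0)
mean_boost = weighted_mean(losses, in_event, boost=4.0)

print(f"event share of weight: {share_plain:.3f} -> {share_boost:.3f}")
print(f"weighted mean loss:    {mean_plain:.3f} -> {mean_boost:.3f}")
```

With a 90/10 split, a boost of 4 roughly triples the share of the gradient budget spent on dysfluent frames, which is the class-imbalance correction the comment in `l_effort` describes.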

### 5.2 The ablation that proves the detach

The strongest single piece of evidence I can offer that the severed gradient is the load-bearing element: train two otherwise identical models, one with `kv.detach()`, one without, and watch `z_norm_in_events` by epoch.

```text
||z(t)|| inside dysfluency spans, by epoch (batch-mean):

epoch     with detach     without detach
  1          2.31              2.28
  3          3.04              2.41
  5          3.22              1.67
  8          3.19              0.74
 12          3.25              0.21    <- effort stream collapsing
 16          3.21              0.06    <- erasure, relearned
```

Without the detach, the semantic stream discovers that the cheapest way to reduce total loss is to attend to an effort signal that says nothing, and it back-propagates that preference until Stream B obliges. The architecture quietly reverts to the exact pathology from the theory post, while the transcript metrics look fine the whole way down. If you take one number away from this post, take that 0.06.
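If you want to turn the table into an automated check, here is a minimal sketch. It applies a 30 percent drop rule against the running peak, which is a simpler rule than the trailing-mean alert in the training script, and `first_collapse_epoch` is a hypothetical helper. The numbers are the ones from the table above.

```python
def first_collapse_epoch(epochs, z_norms, drop=0.30):
    """Return the first epoch whose z-norm sits more than `drop` below
    the running peak, or None if that never happens."""
    peak = float("-inf")
    for epoch, zn in zip(epochs, z_norms):
        peak = max(peak, zn)
        if zn < (1.0 - drop) * peak:
            return epoch
    return None


EPOCHS = [1, 3, 5, 8, 12, 16]
WITH_DETACH = [2.31, 3.04, 3.22, 3.19, 3.25, 3.21]
WITHOUT_DETACH = [2.28, 2.41, 1.67, 0.74, 0.21, 0.06]

collapse_with = first_collapse_epoch(EPOCHS, WITH_DETACH)
collapse_without = first_collapse_epoch(EPOCHS, WITHOUT_DETACH)
print("with detach:   ", collapse_with)
print("without detach:", collapse_without)
```

On these numbers the rule never fires for the detached model and fires at epoch 5 for the ablation, well before the norm reaches 0.06.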

## 6. The Training Script, Complete

Everything in one place: seeding, data, optimizer with sensible parameter groups, scheduler, the bf16/fp32 split, clipping, validation with the metric that actually matters, and checkpointing. This is `train.py` in full minus the argparse.

```python
# grace/train.py
"""End-to-end training for the dual-stream dysfluency model.

Run:  python -m grace.train
"""
from __future__ import annotations

import random
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

from .dataset import GraceDataset, collate
from .losses import lam_schedule, total_loss, z_norm_in_events
from .model import GraceModel

# ---------------------------------------------------------------------
CFG = {
    "manifest_train": "data/manifest_train.jsonl",
    "manifest_val": "data/manifest_val.jsonl",
    "root": "data",
    "whisper_id": "openai/whisper-small",
    "batch_size": 16,
    "epochs": 20,
    "lr_new": 3e-4,          # EffortStream + GraceJoin (fresh weights)
    "lr_whisper": 1e-5,      # unfrozen Whisper parts (pretrained)
    "weight_decay": 0.01,
    "clip_norm": 1.0,
    "seed": 1861,            # Tagore's birth year. Old habits.
    "ckpt_dir": "checkpoints",
    "z_norm_alert_drop": 0.30,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
# ---------------------------------------------------------------------


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def to_device(batch: dict, device: str) -> dict:
    return {k: (v.to(device, non_blocking=True)
                if torch.is_tensor(v) else v)
            for k, v in batch.items()}


def build_optimizer(model: GraceModel) -> torch.optim.AdamW:
    """Two parameter groups: fresh modules at a real learning rate,
    pretrained Whisper parts at a gentle one. One uniform LR either
    cooks Whisper or starves the ODE; this split is not optional."""
    fresh, pretrained = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (pretrained if name.startswith("whisper") else fresh).append(p)
    return torch.optim.AdamW(
        [{"params": fresh, "lr": CFG["lr_new"]},
         {"params": pretrained, "lr": CFG["lr_whisper"]}],
        weight_decay=CFG["weight_decay"])


@torch.no_grad()
def validate(model: GraceModel, loader: DataLoader,
             device: str) -> dict:
    """Validation focused on the metric that matters: held-silence
    F1 on annotated spans. WER on fluent references would grade the
    erasure as a success; we do not use it as the headline."""
    model.eval()
    tp = fp = fn = 0
    sem_loss_sum, n_batches = 0.0, 0
    for batch in loader:
        batch = to_device(batch, device)
        logits, eff, _ = model(batch)
        loss, parts = total_loss(
            logits, batch["labels"], eff, batch["e_tgt"],
            batch["mask"], batch["event_mask"], lam=0.5)
        sem_loss_sum += parts["l_sem"]; n_batches += 1
        pred_held = (eff["sil"].argmax(-1) == 1) & batch["mask"]
        true_held = (batch["e_tgt"][..., 4] == 1) & batch["mask"]
        tp += int((pred_held & true_held).sum())
        fp += int((pred_held & ~true_held).sum())
        fn += int((~pred_held & true_held).sum())
    prec = tp / max(tp + fp, 1)
    rec = tp / max(tp + fn, 1)
    f1 = 2 * prec * rec / max(prec + rec, 1e-8)
    model.train()
    return {"held_silence_f1": f1,
            "val_l_sem": sem_loss_sum / max(n_batches, 1)}


def main() -> None:
    set_seed(CFG["seed"])
    device = CFG["device"]
    Path(CFG["ckpt_dir"]).mkdir(exist_ok=True)

    train_ds = GraceDataset(CFG["manifest_train"], CFG["root"],
                            CFG["whisper_id"])
    val_ds = GraceDataset(CFG["manifest_val"], CFG["root"],
                          CFG["whisper_id"])
    train_loader = DataLoader(
        train_ds, batch_size=CFG["batch_size"], shuffle=True,
        collate_fn=collate, num_workers=4, pin_memory=True,
        drop_last=True)
    val_loader = DataLoader(
        val_ds, batch_size=CFG["batch_size"], shuffle=False,
        collate_fn=collate, num_workers=2)

    model = GraceModel(CFG["whisper_id"]).to(device)
    opt = build_optimizer(model)
    total_steps = CFG["epochs"] * len(train_loader)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=total_steps)

    # NO fp16 GradScaler anywhere in this file. The Transformer runs
    # under bf16 autocast; the ODE forces itself to fp32 internally
    # (EffortStream.forward). fp16's error resolution is coarser than
    # the solver's atol, so the step-size controller chases numerical
    # noise: NFE triples and training crawls. bf16 + fp32 island only.
    step, best_f1, z_norm_trail = 0, 0.0, None
    for epoch in range(CFG["epochs"]):
        for batch in train_loader:
            batch = to_device(batch, device)
            with torch.autocast(device_type="cuda",
                                dtype=torch.bfloat16,
                                enabled=device == "cuda"):
                logits, eff, _ = model(batch)
                loss, parts = total_loss(
                    logits, batch["labels"], eff, batch["e_tgt"],
                    batch["mask"], batch["event_mask"],
                    lam=lam_schedule(step))

            loss.backward()
            # Adjoint backprop through a long trajectory occasionally
            # emits a gradient spike when the solver re-traces a stiff
            # region slightly differently backward than forward. Clip
            # and the spikes are a non-event; skip it and one spike
            # every few thousand steps slowly poisons the field.
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), CFG["clip_norm"])
            opt.step(); sched.step()
            opt.zero_grad(set_to_none=True)

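            # --- health monitors -------------------------------------
            zn = z_norm_in_events(eff, batch["event_mask"])
            # A batch with no annotated event frames reports 0.0 by
            # construction. Skip it, or it fires a false alert and
            # drags the trailing mean toward zero.
            if bool(batch["event_mask"].any()):
                drop = CFG["z_norm_alert_drop"]
                if (z_norm_trail is not None
                        and zn < (1 - drop) * z_norm_trail):
                    print(f"[ALERT] z_norm_events {zn:.3f} dropped "
                          f">{drop:.0%} below trailing mean "
                          f"{z_norm_trail:.3f}. "
                          f"Check the detach. Check the join mask.")
                z_norm_trail = (zn if z_norm_trail is None
                                else 0.99 * z_norm_trail + 0.01 * zn)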
            if step % 50 == 0:
                print(f"e{epoch} s{step} loss={float(loss):.4f} "
                      f"l_sem={parts['l_sem']:.4f} "
                      f"l_eff={parts['l_eff']:.4f} "
                      f"nfe={model.effort.last_nfe} "
                      f"z_norm={zn:.3f} lam={lam_schedule(step):.2f}")
            step += 1

        metrics = validate(model, val_loader, device)
        print(f"[val e{epoch}] {metrics}")
        if metrics["held_silence_f1"] > best_f1:
            best_f1 = metrics["held_silence_f1"]
            torch.save(
                {"model": model.state_dict(), "cfg": CFG,
                 "epoch": epoch, "metrics": metrics},
                Path(CFG["ckpt_dir"]) / "best.pt")
            print(f"[ckpt] new best held_silence_f1={best_f1:.3f}")


if __name__ == "__main__":
    main()
```

Throughput reality, for budgeting. On a single A10G: ~0.9 steps/sec at batch 16 on 10-second clips with adjoint dopri5, versus ~3.1 steps/sec for the same model with Stream B disabled. Enabling the effort stream therefore makes each step roughly 3.4x slower (3.1 / 0.9), and the cost concentrates on dysfluent clips, which is the design behaving as specified: computation follows effort. If you are provisioning cluster GPUs for runs like this, my [vLLM on EKS guide](https://www.ranti.dev/blog/vllm-on-eks) covers the node-group and autoscaling side, which is identical for training boxes.
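Those two rates convert into wall-clock estimates with one line of arithmetic. The sketch below assumes a training set of 16,000 clips, which is a made-up figure for illustration only; the batch size, epoch count, and steps/sec values are the ones quoted in this post, and `train_hours` is a hypothetical helper.

```python
def train_hours(n_clips: int, batch_size: int, epochs: int,
                steps_per_sec: float) -> float:
    """Estimated wall-clock hours for a run, ignoring validation and
    checkpointing time."""
    steps_per_epoch = n_clips // batch_size      # matches drop_last=True
    return epochs * steps_per_epoch / steps_per_sec / 3600.0


N_CLIPS = 16_000          # hypothetical dataset size, not from my runs
dual_stream = train_hours(N_CLIPS, batch_size=16, epochs=20,
                          steps_per_sec=0.9)
stream_a_only = train_hours(N_CLIPS, batch_size=16, epochs=20,
                            steps_per_sec=3.1)

print(f"dual-stream:   {dual_stream:.1f} h")
print(f"Stream A only: {stream_a_only:.1f} h")
print(f"slowdown:      {dual_stream / stream_a_only:.2f}x")
```

Swap in your own clip count; the slowdown ratio stays at about 3.4 regardless of dataset size because it depends only on the two step rates.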

The full loop as a diagram:

```mermaid
flowchart TD
    A[Manifest + audio clips] -->|annotated spans| B[build_effort_targets]
    B --> C[GraceItem: ode_feats, t, e_target, verbatim labels]
    C -->|collate with masks| D[Batch on shared 10 ms grid]
    D --> E[BatchedInterpolant a_of_t]
    D --> F[Whisper encoder: Stream A]
    E --> G[EffortStream ODE solve: Stream B, fp32 island]
    G -->|z trajectory as detached kv| H[GraceJoin with key padding mask]
    F -->|h_sem queries| H
    H --> I[Whisper decoder + proj_out]
    I --> J[L_sem cross-entropy vs verbatim labels]
    G -->|reg, phi, sil heads| K[L_effort with event boost]
    J --> L[total = L_sem + lambda warmup * L_effort]
    K --> L
    L -->|backward + clip 1.0| M[AdamW two param groups + cosine]
    M -->|log NFE and z_norm_events| N[Health monitors + alert]
    N -->|next batch| D
```

Note where the two losses attach. `L_sem` trains Stream A and the join. `L_effort` trains Stream B alone. The only bridge between the streams is the detached key-value path, which carries information forward and no gradient backward. The diagram is the contract from Section 0, drawn.

## 7. Serializing the Topology: From e(t) to JSON, Complete

A trained model emits an effort trajectory at solver resolution, hundreds of vectors per clip. Downstream consumers, an LLM agent above all, need something smaller and semantically chunked. The serializer compresses `e(t)` into a list of _hesitation events_ plus a live state summary, and it includes the boundary-stitching merge pass that streaming inference needs (Section 9). Design rule: the JSON must answer the only question the agent will ask, which is "is the speaker done, or still working."

```python
# grace/topology.py
"""Compress an effort trajectory into the hesitation-topology wire
format consumed by the MCP server (Section 8)."""
from __future__ import annotations

import torch

SIL_NAMES = ["not_silence", "held_silence", "terminal_silence"]
SCHEMA = "hesitation-topology/v1"


def _merge_adjacent(events: list[dict],
                    max_gap_s: float = 0.10) -> list[dict]:
    """Stitch events separated by tiny gaps.

    Needed for streaming: a block spanning a window boundary briefly
    serializes as two events. Any gap under max_gap_s between events
    holding the same phoneme is an artifact of windowing, not a real
    release-and-reblock, so merge them and keep the larger tau peak."""
    if not events:
        return events
    merged = [events[0]]
    for ev in events[1:]:
        prev = merged[-1]
        same_phi = ev["phoneme_id"] == prev["phoneme_id"]
        if same_phi and (ev["start"] - prev["end"]) <= max_gap_s:
            prev["end"] = ev["end"]
            prev["tau_peak"] = max(prev["tau_peak"], ev["tau_peak"])
            prev["recursion_rate"] = max(prev["recursion_rate"],
                                         ev["recursion_rate"])
            prev["silence"] = ev["silence"]
        else:
            merged.append(ev)
    return merged


@torch.no_grad()
def serialize_topology(eff: dict, t: torch.Tensor, mask: torch.Tensor,
                       batch_index: int = 0,
                       min_event_s: float = 0.15,
                       d_active_thresh: float = 0.05) -> dict:
    """Compress one clip's effort trajectory into discrete hesitation
    events plus a live tail-state.

    Args:
      eff:  dict from EffortStream (reg [B,T,3], phi, sil)
      t:    [T] shared timestamps (seconds)
      mask: [B, T] valid frames
      batch_index: which clip in the batch to serialize
    Returns: the wire-format dict (see example below).
    """
    m = mask[batch_index]
    reg = eff["reg"][batch_index][m].float().cpu()   # [T, 3] d, r, tau
    sil = eff["sil"][batch_index][m].argmax(-1).cpu()
    phi = eff["phi"][batch_index][m].argmax(-1).cpu()
    tt = t[: int(m.sum())].float().cpu()

    events: list[dict] = []
    cur: dict | None = None
    for i in range(len(tt)):
        active = bool(reg[i, 0] > d_active_thresh) or \
                 bool(sil[i] == 1)                   # d > eps OR held
        if active and cur is None:
            cur = {"start": round(float(tt[i]), 3), "tau_peak": 0.0,
                   "recursion_rate": 0.0, "phoneme_id": int(phi[i]),
                   "silence": SIL_NAMES[int(sil[i])],
                   "end": round(float(tt[i]), 3)}
        if active and cur is not None:
            cur["tau_peak"] = max(cur["tau_peak"],
                                  round(float(reg[i, 2]), 3))
            cur["recursion_rate"] = max(cur["recursion_rate"],
                                        round(float(reg[i, 1]), 2))
            cur["end"] = round(float(tt[i]), 3)
            cur["phoneme_id"] = int(phi[i])
            cur["silence"] = SIL_NAMES[int(sil[i])]
        elif cur is not None:
            if cur["end"] - cur["start"] >= min_event_s:
                events.append(cur)
            cur = None
    if cur is not None and cur["end"] - cur["start"] >= min_event_s:
        events.append(cur)                           # clip ended MID-event

    events = _merge_adjacent(events)
    tail_open = bool(events) and \
        events[-1]["end"] >= float(tt[-1]) - 1e-3

    return {
        "schema": SCHEMA,
        "clip_duration_s": round(float(tt[-1]), 3),
        "events": events,
        "state": {
            # The field agents act on. unresolved == the speaker is
            # mid-effort RIGHT NOW. Interrupting here is the machine
            # equivalent of finishing someone's sentence for them.
            "unresolved": tail_open,
            "current_tension": round(float(reg[-1, 2]), 3),
            "held_silence": SIL_NAMES[int(sil[-1])] == "held_silence",
        },
    }
```

Example output for a clip that ends mid-block:

```json
{
  "schema": "hesitation-topology/v1",
  "clip_duration_s": 6.21,
  "events": [
    {
      "start": 1.22,
      "end": 2.08,
      "recursion_rate": 4.6,
      "phoneme_id": 7,
      "tau_peak": 0.83,
      "silence": "not_silence"
    },
    {
      "start": 4.31,
      "end": 6.21,
      "recursion_rate": 0.0,
      "phoneme_id": 7,
      "tau_peak": 0.91,
      "silence": "held_silence"
    }
  ],
  "state": {
    "unresolved": true,
    "current_tension": 0.88,
    "held_silence": true
  }
}
```

Read the second event. Zero recursion rate, high tension, held silence: that is a block, the loaded silence. A standard VAD pipeline reports this exact acoustic situation as "speech ended." This schema reports it as "speech very much in progress, effort high, do not act." Same audio, opposite conclusion, and the difference is the entire point of the architecture.
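A consumer can apply that reading rule mechanically. The sketch below is one way to do it, not part of the schema: `describe_event` is a hypothetical helper, the 0.7 tension threshold is an assumption I picked for the example, and labeling any event with a positive recursion rate as a "repetition" is my shorthand for it.

```python
def describe_event(ev: dict, tau_high: float = 0.7) -> str:
    """Label one hesitation event from the wire format.

    block:      held silence, no recursion, high peak tension
    repetition: any positive recursion rate
    other:      everything else
    """
    held = ev["silence"] == "held_silence"
    if held and ev["recursion_rate"] == 0.0 and ev["tau_peak"] >= tau_high:
        return "block"
    if ev["recursion_rate"] > 0.0:
        return "repetition"
    return "other"


# The two events from the example payload above.
events = [
    {"start": 1.22, "end": 2.08, "recursion_rate": 4.6,
     "phoneme_id": 7, "tau_peak": 0.83, "silence": "not_silence"},
    {"start": 4.31, "end": 6.21, "recursion_rate": 0.0,
     "phoneme_id": 7, "tau_peak": 0.91, "silence": "held_silence"},
]

labels = [describe_event(ev) for ev in events]
print(labels)
```

The threshold deserves tuning against your own annotations before anything depends on it.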

## 8. Agentic Integration: An MCP Server That Teaches Agents to Wait

The last mile. The model is only useful if the systems that sit on top of it change their behavior, and the dominant failure of voice agents with dysfluent speakers is interruption: VAD says "done," the agent starts talking over a person who is mid-block. We fix that by exposing the topology through a [Model Context Protocol](https://modelcontextprotocol.io/?utm_source=ranti.dev) server, so any MCP-capable agent can ask "is the speaker still working" before it opens its mouth.

The server holds the latest topology per session, exposes an ingestion tool for the inference pipeline, and two read tools for agents. `get_hesitation_topology` returns the full structure for reasoning. `should_yield_turn` is the cheap fast-path the agent calls before every response. Sessions expire after 5 seconds without a fresh ingest (`STALE_AFTER_S`) so a crashed pipeline cannot leave an agent waiting forever on stale state; the fail-open choice is deliberate and documented inline.

```python
# grace/mcp_server.py
"""MCP server exposing hesitation topology to LLM agents.

Run:  python -m grace.mcp_server
Configure your MCP-capable agent client to launch this as a stdio
server, or run it behind streamable HTTP for shared deployments.
"""
from __future__ import annotations

import time
from typing import Any

from mcp.server.fastmcp import FastMCP

mcp = FastMCP("hesitation-topology")

SESSIONS: dict[str, dict[str, Any]] = {}   # session_id -> entry
STALE_AFTER_S = 5.0


def _fresh(entry: dict[str, Any]) -> bool:
    return (time.monotonic() - entry["ts"]) <= STALE_AFTER_S


@mcp.tool()
def ingest_topology(session_id: str, topology: dict) -> dict:
    """Called by the inference pipeline after each audio window with
    the serialized hesitation-topology/v1 payload. Overwrites the
    session's previous state. Returns an ack with the stored schema."""
    if topology.get("schema") != "hesitation-topology/v1":
        return {"ok": False, "error": "unknown schema"}
    SESSIONS[session_id] = {"topology": topology,
                            "ts": time.monotonic()}
    return {"ok": True, "schema": topology["schema"]}


@mcp.tool()
def get_hesitation_topology(session_id: str) -> dict:
    """Full hesitation topology for the session: every dysfluency
    event (span, recursion rate, peak tension, silence class) plus
    live state. Use when reasoning about HOW the user is speaking,
    not only what they said."""
    entry = SESSIONS.get(session_id)
    if entry is None or not _fresh(entry):
        return {"schema": "hesitation-topology/v1", "events": [],
                "state": {"unresolved": False,
                          "current_tension": 0.0,
                          "held_silence": False},
                "stale": entry is not None}
    return entry["topology"]


@mcp.tool()
def should_yield_turn(session_id: str) -> dict:
    """Cheap pre-response check. Returns wait=True when the speaker
    is mid-effort (unresolved block or held silence). An agent MUST
    NOT begin speaking while wait is True. Poll until released.

    Fail-open on stale or missing state: if the pipeline stopped
    feeding us, we release the floor rather than freezing the agent
    forever. An agent that occasionally interrupts annoys; a
    permanently frozen one is an outage."""
    entry = SESSIONS.get(session_id)
    if entry is None or not _fresh(entry):
        return {"wait": False, "reason": "no fresh topology, fail-open",
                "recheck_ms": None}
    s = entry["topology"].get("state", {})
    wait = bool(s.get("unresolved") or s.get("held_silence"))
    return {
        "wait": wait,
        "reason": ("speaker mid-block, silence is held not terminal"
                   if wait else "floor is open"),
        "tension": s.get("current_tension", 0.0),
        "recheck_ms": 250 if wait else None,
    }


if __name__ == "__main__":
    mcp.run()
```

And the consuming side, complete enough to paste into a voice agent. The turn-taking guard is a small async function, and the system prompt change is a few short sentences:

```python
# voice_agent/turn_gate.py
"""The agent-side gate. Call before EVERY spoken response."""
import asyncio


async def wait_for_floor(mcp_client, session_id: str,
                         max_wait_s: float = 30.0) -> bool:
    """Block until the speaker releases the floor or max_wait_s
    elapses. Returns True if the floor is open, False on timeout
    (caller decides how to degrade, e.g. a gentle check-in)."""
    # mcp_client is assumed to return the tool's dict payload directly.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait_s
    while True:
        verdict = await mcp_client.call_tool(
            "should_yield_turn", {"session_id": session_id})
        if not verdict["wait"]:
            return True
        if loop.time() >= deadline:
            return False
        # Fall back to 250 ms if the server omits recheck_ms.
        await asyncio.sleep((verdict.get("recheck_ms") or 250) / 1000)
```

```text
# added to the agent system prompt
Before responding in voice sessions, call should_yield_turn. If
wait is true, the user is still speaking with effort. Do not
respond, do not summarize what they "probably meant." Wait.
```

That is the whole integration. The agent does not need to understand Neural ODEs. It needs one boolean, delivered by a tool, derived from a representation that knows the difference between a finished silence and a fighting one. If you want the deeper pattern of how agents should consume tools in a loop like this, I covered it in [What Is Agent Looping?](https://www.ranti.dev/blog/what-is-agent-looping); the topology server slots into that loop as just another tool call, which is exactly the level of boring it should be.
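One practical wrinkle: `wait_for_floor` indexes the verdict like a dictionary, so whatever client you pass in has to return the decoded payload. The sketch below is a thin adapter for clients that instead return a result object. It assumes that object exposes an `isError` flag and a `content` list whose text items carry the JSON, which is my understanding of the `mcp` Python SDK's `ClientSession.call_tool`; verify against the SDK version you run. `DictToolClient` and `FakeSession` are hypothetical names, and the fake exists only so the example runs without a server.

```python
import asyncio
import json
from types import SimpleNamespace


class DictToolClient:
    """Wraps a session so call_tool returns the decoded dict payload."""

    def __init__(self, session):
        self._session = session

    async def call_tool(self, name: str, arguments: dict) -> dict:
        result = await self._session.call_tool(name, arguments)
        if getattr(result, "isError", False):
            raise RuntimeError(f"tool {name} reported an error")
        for item in result.content:
            text = getattr(item, "text", None)
            if text is not None:
                return json.loads(text)      # first text item wins
        raise ValueError(f"tool {name} returned no text content")


class FakeSession:
    """Stand-in for a real client session, for this demo only."""

    async def call_tool(self, name: str, arguments: dict):
        payload = {"wait": False, "reason": "floor is open",
                   "tension": 0.0, "recheck_ms": None}
        item = SimpleNamespace(type="text", text=json.dumps(payload))
        return SimpleNamespace(isError=False, content=[item])


verdict = asyncio.run(
    DictToolClient(FakeSession()).call_tool(
        "should_yield_turn", {"session_id": "demo"}))
print(verdict["wait"], verdict["reason"])
```

Pass the wrapped client to `wait_for_floor` and the gate code stays exactly as written.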

End-to-end data flow:

```mermaid
flowchart LR
    A[Microphone audio] -->|streaming windows| B[Feature extraction]
    B -->|BatchedInterpolant a_of_t| C[EffortStream ODE]
    B -->|Whisper mel frames| D[Whisper Stream A]
    C -->|z trajectory| E[GraceJoin]
    D -->|h_sem| E
    E -->|verbatim transcript| F[Agent context]
    C -->|e of t trajectory| G[Topology serializer + merge pass]
    G -->|hesitation-topology/v1 JSON| H[MCP server ingest_topology]
    H -->|get_hesitation_topology| I[LLM agent reasoning]
    H -->|should_yield_turn| J[Turn-taking gate wait_for_floor]
    J -->|wait=true: hold floor| K[Agent stays silent]
    J -->|wait=false: respond| L[Agent speaks]
    F --> I
```

Follow the two paths out of the model. The transcript path carries _what_ was said, stutter preserved verbatim. The topology path carries _how_ it was said and whether it is still being said. Standard pipelines have only the first path, with the stutter scrubbed out of it. The second path is what the entire theoretical argument was for.

## 9. Deployment Notes: Streaming, Latency, and Where the ODE Hurts

Training is offline; the agent use case is live. Three production realities deserve their own section, because they are where a reviewer of this design will push hardest.

**Streaming means windowed integration with carried state.** You cannot wait for a full utterance before solving. The serving loop processes audio in overlapping windows (I use 1.0-second windows with a 0.2-second overlap) and, instead of re-initializing `z0` from scratch each window, carries the final latent of window k in as the initial condition of window k+1. The ODE formulation makes this natural: the latent is a state, and a state can be resumed. The interpolant is rebuilt per window from a small ring buffer of features. Continuity across window boundaries was within noise in my checks, with one caveat: a block that spans a boundary briefly reads as two events until the serializer's `_merge_adjacent` pass stitches same-phoneme spans whose gap is at most 100 milliseconds. That pass is in `topology.py` above, a short function, already written.
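The window bookkeeping is simple enough to sketch in isolation. The code below is a minimal illustration under my own assumptions: times are integer milliseconds to avoid float drift, `window_bounds_ms` and `run_stream` are hypothetical names, and `solve_window` stands in for the real per-window ODE solve. How the overlap region is handled inside the solve is left to the real implementation; this only shows the latent being resumed instead of reset.

```python
from typing import Callable


def window_bounds_ms(total_ms: int, win_ms: int = 1000,
                     overlap_ms: int = 200) -> list[tuple[int, int]]:
    """Overlapping [start, end) windows covering total_ms of audio."""
    hop = win_ms - overlap_ms
    if hop <= 0:
        raise ValueError("overlap must be smaller than the window")
    bounds, start = [], 0
    while start < total_ms:
        bounds.append((start, min(start + win_ms, total_ms)))
        if start + win_ms >= total_ms:
            break                      # this window reached the end
        start += hop
    return bounds


def run_stream(total_ms: int, solve_window: Callable, z0):
    """Carry the final latent of each window into the next one."""
    z, finals = z0, []
    for start, end in window_bounds_ms(total_ms):
        z = solve_window(z, start, end)    # resumes from previous state
        finals.append(z)
    return finals


def toy_solve(z, start_ms, end_ms):
    """Toy stand-in: counts how many windows the state has lived through."""
    return z + 1


bounds = window_bounds_ms(2600)
finals = run_stream(2600, toy_solve, z0=0)
print(bounds)
print(finals)
```

If the state were re-initialized per window, every entry in `finals` would be 1; the increasing sequence is the carried latent.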

**The latency budget works, with margin to spare.** On an A10G serving one stream: feature extraction ~4 ms per window, the windowed ODE solve 18 to 40 ms depending on stiffness (blocks cost more, as designed), serialization under 1 ms, MCP round trip on localhost ~2 ms. Worst case lands near 50 ms per window against a 200 ms turn-taking budget, so the topology path is comfortably real-time for a handful of concurrent streams per GPU. The semantic stream remains the heavy consumer. Note what the solver cost profile means operationally: your p99 latency is set by your most dysfluent moments. Provision against p99 on dysfluent audio, not against the fluent mean, or the system will be slowest exactly when its output matters most.
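The budget arithmetic, written out so you can substitute your own measurements. The stage numbers are the ones quoted above, with serialization rounded up to 1 ms; `StageBudget` is a hypothetical container for this sketch, not something in the repo.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageBudget:
    """Per-window latency of the topology path, in milliseconds."""
    feature_ms: float = 4.0
    ode_ms_best: float = 18.0
    ode_ms_worst: float = 40.0
    serialize_ms: float = 1.0      # quoted as "under 1 ms"; upper bound
    mcp_ms: float = 2.0

    def total_ms(self, worst: bool = True) -> float:
        ode = self.ode_ms_worst if worst else self.ode_ms_best
        return self.feature_ms + ode + self.serialize_ms + self.mcp_ms

    def headroom_ms(self, budget_ms: float = 200.0) -> float:
        """Slack left in the turn-taking budget at worst-case cost."""
        return budget_ms - self.total_ms(worst=True)


budget = StageBudget()
print(f"best case:  {budget.total_ms(worst=False):.0f} ms")
print(f"worst case: {budget.total_ms(worst=True):.0f} ms")
print(f"headroom:   {budget.headroom_ms():.0f} ms of 200 ms")
```

Feed it p99 numbers measured on dysfluent audio, per the note above, and not fluent-mean numbers.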

**Export is the genuine weak point, and I will not pretend otherwise.** Adaptive ODE solvers are data-dependent control flow, which ONNX export and most graph compilers handle badly. Your three options, in declining order of fidelity: serve the PyTorch model directly behind a thin async server, which is what I do and which `torch.compile` accepts with the solve left in eager mode; switch inference to fixed-step `rk4` with a step size validated against dopri5 outputs on a dysfluent holdout, which exports cleanly and cost me a measurable but tolerable degradation on block boundaries; or distill the trajectory into a discrete-time student model, which abandons the continuous-time substrate at the edge and should be treated as a last resort, because it quietly reintroduces the grid the architecture exists to escape. Pick consciously. The second option is the pragmatic middle for most deployments, provided the validation set is dysfluency-weighted.
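For the second option, the validation step is the part worth sketching. The code below is a toy under stated assumptions: a scalar decay field stands in for the learned vector field, the analytic solution stands in for the dopri5 reference, and `rk4_solve` and `largest_valid_step` are hypothetical names. The shape of the procedure is what carries over: try candidate step sizes from coarse to fine and keep the largest one that stays within tolerance of the reference.

```python
import math
from typing import Callable, Optional


def rk4_solve(f: Callable[[float, float], float], z0: float,
              t0: float, t1: float, h: float) -> float:
    """Integrate dz/dt = f(t, z) from t0 to t1 with classic fixed-step RK4."""
    n = max(1, round((t1 - t0) / h))
    h = (t1 - t0) / n                 # land exactly on t1
    z, t = z0, t0
    for _ in range(n):
        k1 = f(t, z)
        k2 = f(t + h / 2, z + h / 2 * k1)
        k3 = f(t + h / 2, z + h / 2 * k2)
        k4 = f(t + h, z + h * k3)
        z += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return z


def largest_valid_step(f, z0, t0, t1, reference: float,
                       candidates: list[float],
                       tol: float) -> Optional[float]:
    """Largest candidate step whose endpoint is within tol of reference."""
    for h in sorted(candidates, reverse=True):
        if abs(rk4_solve(f, z0, t0, t1, h) - reference) <= tol:
            return h
    return None


K = 5.0                               # toy decay rate, not a model value


def field(t: float, z: float) -> float:
    return -K * z


reference = math.exp(-K * 1.0)        # analytic stand-in for dopri5
step = largest_valid_step(field, 1.0, 0.0, 1.0, reference,
                          candidates=[0.5, 0.25, 0.1, 0.05], tol=1e-4)
print("largest step within tolerance:", step)
```

In the real setting the reference would be dopri5 trajectories on a dysfluency-weighted holdout, and the error would be measured on the effort heads, not on a single scalar.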

One more operational note. The effort stream and the MCP server are stateless across sessions apart from the per-session latent and topology, so horizontal scaling is ordinary: shard by session ID, keep the carried latent in the worker's memory, and let your usual autoscaling logic handle the rest. Nothing about Stream B changes the deployment story you already know; it only changes what the system can hear.
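Sharding by session ID needs a hash that is stable across processes, which Python's built-in `hash()` is not for strings by default. A minimal sketch, with `worker_for` as a hypothetical helper and plain modulo assignment as a simplifying assumption:

```python
import hashlib


def worker_for(session_id: str, n_workers: int) -> int:
    """Map a session to a worker index, stable across processes."""
    if n_workers <= 0:
        raise ValueError("n_workers must be positive")
    digest = hashlib.sha256(session_id.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % n_workers


assignments = {sid: worker_for(sid, 4)
               for sid in ("sess-a", "sess-b", "sess-c")}
print(assignments)
```

Modulo assignment remaps most sessions when the worker count changes, and a remapped session loses its carried latent. If you autoscale mid-conversation, a consistent-hashing scheme or sticky routing is the safer choice.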

## 10. What I Would Tell You Before You Build This

Numbers and judgments from my runs, stated plainly so you can calibrate.

**Where the wins were real.** Held-silence classification is the headline: the dual-stream model separates loaded silence from terminal silence with high reliability once the event-boost weighting is in (Section 5.1), and that single capability eliminates the interruption failure in the agent demo. Verbatim transcription of repetitions improved as well, which I attribute to the join giving Stream A permission, in effect, to emit what it hears instead of resolving it.

**Where the costs were real.** Training is ~3.4x slower than the Stream-A-only baseline, inference adds the ODE solve, and the data work, verbatim transcripts and span annotations above all, is the true bottleneck. The model is the easy part. It always is.

**The four mistakes to not repeat.** One: do not train against cleaned transcripts, ever; you will spend a week discovering your streams are fighting. Two: do not run the solver under fp16; you will spend a day staring at NFE counts that make no sense. Three: do not delete the detach during a refactor; you will spend an epoch budget proving the theory post's central claim by accident, and the transcript metrics will smile at you the whole way down. Four: unit-test the interpolant against a hand-computed example before anything else touches it; mine fed every clip in the batch the acoustics of clip zero for two days, and nothing in the loss curve said so.

**The honest status.** This is a working research implementation with one author and a demo agent that waits. It is not a clinical system, and the tension proxy in particular is a placeholder for proper articulatory ground truth. The evaluation protocol that would settle the bigger claim, Dysfluency Information Retention against a matched Whisper baseline, is specified in [the theory post](/blog/topography-of-hesitation), Section 6, and running it at scale is the obvious next post.

The theory post ended by saying that the capability to dwell in unresolved latency has to be built in, because it will never be optimized in. This post is the receipt. It is built. One dataset module, one ODE field, one detach, one weighted loss, one JSON schema, one MCP server, and an agent that, for the first time in my testing, shut up and waited while a speaker fought a block to its end. Build it, break it, and send me what you find: the repo structure above is exactly how mine is laid out, and the parts I am least sure about are labeled as such in the comments. The fastest way to improve this architecture is for someone to prove a piece of it wrong in public.


---

<!-- METADATA_START -->
## Metadata & Citations

### Further Reading
- [The Topography of Hesitation: Non-Markovian Ruptures and the Mathematical Failure of Autoregressive Models on Dysfluent Speech](https://www.ranti.dev/blog/topography-of-hesitation.md)
- [Beyond RAG: Using Multi-Agent Systems for Deep Cultural and Literary Analysis](https://www.ranti.dev/blog/beyond-rag-tagore.md)
- [Building Neuroinclusive AI with Model Context Protocol (MCP)](https://www.ranti.dev/blog/neuroinclusive-mcp.md)

### BibTeX
```bibtex
@article{implementing-grace_2026,
  author = {Rantideb Howlader},
  title = {Implementing Grace: A PyTorch Case Study in Dual-Stream Dysfluency Models},
  journal = {Rantideb Howlader Portfolio},
  year = {2026},
  url = {https://www.ranti.dev/blog/implementing-grace},
  note = {Accessed: 2026-06-24}
}
```

### IEEE
Rantideb Howlader, "Implementing Grace: A PyTorch Case Study in Dual-Stream Dysfluency Models," Rantideb Howlader Portfolio, 2026. [Online]. Available: https://www.ranti.dev/blog/implementing-grace. [Accessed: 2026-06-24].

### APA
Rantideb Howlader. (2026). Implementing Grace: A PyTorch Case Study in Dual-Stream Dysfluency Models. Rantideb Howlader Portfolio. Retrieved from https://www.ranti.dev/blog/implementing-grace

---
*This content is provided in research-grade Markdown format. Required Attribution: Cite as Rantideb Howlader (2026).*
<!-- METADATA_END -->