🌀

Mind The Gap! Overcoming The Training-inference Divide of Masked Diffusion Language Models

Haoyu He, Katrin Renz, Yong Cao, Andreas Geiger

18 Aug, 2025

Paper GitHub


Masked Diffusion Language Models (MDLMs) are a promising alternative to autoregressive (AR) LLMs, enabling bidirectional conditioning, faster generation, and parallel decoding.

However, current MDLMs face two fundamental problems:

  1. A training-inference divide: models are trained only to predict masked tokens from partially masked data and are never supervised on the intermediate denoising steps that inference actually traverses.
  1. Inflexible remasking: once a token is unmasked, it stays frozen for the rest of the denoising process, so early mistakes cannot be corrected.

These two problems have been largely overlooked by previous works, yet they hinder MDLMs from yielding effective denoising trajectories. In this work, we address both problems with two complementary methods. Our contributions are summarized below:

💡
  1. We observe a previously unreported phenomenon in MDLMs: models occasionally produce correct answers at intermediate steps but `refine' them into incorrect results. We refer to this phenomenon as Answer Backslide; it inspires us to supervise models not only on the final results but also on intermediate denoising steps.
  1. We propose Masked Diffusion Policy Optimization (MDPO) to learn effective denoising trajectories without direct supervision. By exploiting the fact that MDLMs yield complete generations at every inference step, MDPO optimizes the model with policy gradients driven by intermediate-step rewards. Unlike prior RL fine-tuning of MDLMs, MDPO explicitly targets the training-inference divide overlooked by previous works.
  1. We introduce Running Confidence Remasking (RCR), a training-free remasking strategy. Instead of freezing tokens based on their single-step confidence, RCR allows flexible remasking by tracking each position's running confidence over the denoising trajectory. Our experiments demonstrate consistent performance improvements from RCR on both the LLaDA pre-trained model and MDPO fine-tuned models.

1. Answer Backslide

One key difference between MDLMs and AR models is that MDLMs yield all tokens at each step, whereas AR models yield one token per step. This paradigm difference led us to inspect the model outputs at intermediate steps in our early experiments. Surprisingly, we found a very interesting phenomenon:

😲

Finding: LLaDA answers correctly at intermediate steps, while further denoising turns correct answers into wrong ones.

Note that inference with MDLMs (e.g., LLaDA) alternates between predicting all masked tokens and selectively re-masking a fraction of the predictions for iterative refinement. This results in multiple iterative denoising steps guided by the model's confidence scores, forming a trajectory in which fewer and fewer tokens remain masked and more and more structure is revealed. Below is a live demo of how the model (LLaDA) evolves its answer over 64 denoising steps. The model is prompted with:

What is the last nonzero digit to the right of the decimal point in the decimal expansion of $\frac{137}{500}$?
Let's think step by step and output the final answer within \boxed{}.

which is a mathematical task where the model is asked to put its answer in a pre-defined structure, “\boxed{}”. The ground-truth answer is “\boxed{4}”.

By dragging the step bar, you can see how the predictions change at each step. Clicking “prev” or “next” also shows which tokens are remasked at every step. We use a math verifier to check whether the model yields the correct answer, at all intermediate steps and at the final step. As shown in the demo, the correct answer appears a few times at the beginning of the denoising, but the answer tokens are then remasked and lost in further denoising. We refer to this phenomenon as Answer Backslide.
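The check itself is straightforward. Below is a minimal Python sketch of such a detector, where `verify_answer` is a hypothetical stand-in for the math verifier and `step_completions` holds the decoded completion after every denoising step (both names are illustrative, not our exact implementation):

```python
from typing import Callable, List

def detect_answer_backslide(
    step_completions: List[str],           # decoded completion after each denoising step
    verify_answer: Callable[[str], bool],  # hypothetical math verifier, e.g. checks \boxed{...}
) -> bool:
    """Flag trajectories where an intermediate step is correct but the final step is not."""
    correct_per_step = [verify_answer(text) for text in step_completions]
    return any(correct_per_step[:-1]) and not correct_per_step[-1]
```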

We analyse the Answer Backslide phenomenon in more detail below.

Analysis of samples where Answer Backslide occurs. (Left) Maximum length of consecutive correct steps (span) in the denoising trajectory. (Middle) Number of separate correct spans in a trajectory. (Right) Heatmap of the relative position of correct steps across the denoising process.

The left and middle subfigures show that correct spans are typically short and fragmented, with trajectories often containing multiple separate correct spans instead of steadily accumulating correct steps. Such fragmentation increases the likelihood of losing correct tokens before reaching the final step. In addition, the heatmap shows that correct answers often appear surprisingly early but tend to decay over subsequent steps rather than persisting until the end.
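For reference, the span statistics in the left and middle panels can be computed from the per-step correctness flags alone; a minimal sketch (not the analysis code itself) is:

```python
from itertools import groupby
from typing import List, Tuple

def correct_span_stats(correct_per_step: List[bool]) -> Tuple[int, int]:
    """Return (max length of consecutive correct steps, number of separate correct spans)."""
    span_lengths = [
        sum(1 for _ in group)
        for is_correct, group in groupby(correct_per_step)
        if is_correct
    ]
    if not span_lengths:
        return 0, 0
    return max(span_lengths), len(span_lengths)

# Example: correct early, lost, briefly recovered, wrong at the end.
print(correct_span_stats([False, True, True, False, True, False]))  # (2, 2)
```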

These findings not only underline the necessity of addressing the training-inference divide of MDLMs, but also reveal that Answer Backslide provides highly informative signals at intermediate steps. This inspires us to introduce a policy gradient method that uses intermediate-step rewards to optimize MDLMs toward effective denoising trajectories.

2. Training MDLMs as a Sequential Decision-making Problem

A simple solution to address the training-inference discrepancy would be to fine-tune the model with ground-truth trajectories. However, such trajectories are inherently unavailable, as human-generated data does not capture iterative denoising paths.

Based on the Markov property that diffusion possesses, we propose to frame the problem of learning effective denoising trajectories as a sequential decision-making problem and use the resulting framework to apply reinforcement learning. As Answer Backslide suggests, intermediate steps can provide rich, informative signals for updating the model. We therefore propose Masked Diffusion Policy Optimization (MDPO) to explicitly train the mask-prediction network as an RL policy. Specifically, rollouts are created by sampling answers with the policy given a prompt, and a reward model (e.g., a math verifier) is then used to evaluate the completions at intermediate and final steps. These rewards are used to estimate the advantage of each step for updating the policy.

In addition, we leverage group-relative advantage estimation from GRPO. MDPO is illustrated below:

MDPO generates a group of answers given a query for RL rollouts. All completions at intermediate and final steps are then verified with a reward model. Based on the verified rewards, MDPO estimates the advantage of step $t$ by considering the rewards of the other steps in the current rollout and of step $t$ from the other rollouts in the group. These estimated advantages are used for policy optimization.
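To make this concrete, the sketch below shows one way such a step-level, group-relative baseline could be computed: the reward of step $t$ in rollout $g$ is normalized against the rewards of the other steps in the same rollout together with the rewards of step $t$ in the other rollouts. This is an illustrative NumPy sketch under that assumption, not the exact estimator used in the paper:

```python
import numpy as np

def mdpo_advantages(rewards: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Step-level, group-relative advantage estimation (illustrative sketch).

    rewards: array of shape (G, T) holding the verified reward of every
             intermediate/final completion, for G rollouts and T steps.
    """
    G, T = rewards.shape
    advantages = np.zeros_like(rewards, dtype=np.float64)
    for g in range(G):
        for t in range(T):
            # Baseline: other steps of the same rollout + step t of the other rollouts.
            same_rollout = np.delete(rewards[g], t)
            same_step = np.delete(rewards[:, t], g)
            baseline = np.concatenate([same_rollout, same_step])
            advantages[g, t] = (rewards[g, t] - baseline.mean()) / (baseline.std() + eps)
    return advantages
```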

We observe that MDPO matches the performance of the previous state-of-the-art (SOTA) method with 60× fewer gradient updates, while achieving average improvements of 9.6% on MATH500 and 54.2% on Countdown over SOTA when trained with the same number of weight updates.

3. Flexible Remasking Enables MDLMs to Correct Early Mistakes

For remasking tokens at intermediate steps, LLaDA proposes random remasking and low-confidence remasking, and empirically finds that the confidence-based strategy works better. However, both strategies share one crucial limitation: they assign masking scores only to the tokens predicted at the current step, so unmasked tokens remain fixed until the end of the denoising process. This is problematic because predicted tokens, particularly in the early steps, tend to be highly noisy due to the limited structure revealed at that stage. Freezing these noisy tokens until the end of the denoising process makes it harder to produce high-quality generations in practice.

To address this, we make a simple yet effective change, turning Low-Confidence Remasking (LCR) into Running Confidence Remasking (RCR). Instead of deciding based solely on the confidence at the current step, we track for each position the highest confidence it has achieved so far along the denoising process. At each step, we identify the required number of positions whose running maximum confidence is lowest and remask the tokens at these positions. See the figure below for a demonstration.

Comparison between Low-Confidence Remasking (LCR) and our proposed Running Confidence Remasking (RCR) during iterative denoising. For each step, we show the tokens that are not remasked. LCR freezes low-confidence tokens once unmasked, preventing further refinement and potentially accumulating early-stage noise. For example, the token `problem', predicted and frozen by LCR at step 1, is wrong but maintained until the end of the denoising, which leads to the wrong final answer. RCR, in contrast, tracks the running maximum confidence for each position, allowing persistently low-confidence tokens to be refined in later steps and leading to higher-quality completions.

Empirically, as more structure is revealed over the course of denoising, earlier steps often yield low-confidence predictions, whereas later steps tend to converge to higher confidence. Under the LCR strategy, early tokens are retained despite their relatively low confidence if they happen to fall within the top-n set to be `kept' at that step, and they cannot be refined in future steps. In contrast, as the above figure shows, RCR allows such tokens to be remasked in later steps if tokens at other positions surpass them in running confidence, enabling the model to revise uncertain predictions before producing the final output.
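A rough sketch of the RCR selection step is given below. It assumes per-position confidences from the mask predictor, ignores details such as prompt positions and block boundaries, and the helper itself is illustrative rather than the actual implementation:

```python
import torch

def rcr_keep_mask(
    running_conf: torch.Tensor,  # (L,) running max confidence per position (-inf if never predicted)
    step_conf: torch.Tensor,     # (L,) confidence of the current step's prediction at each position
    predicted: torch.Tensor,     # (L,) bool, True where the model produced a prediction this step
    num_to_keep: int,            # number of positions left unmasked after this step
):
    """Running Confidence Remasking: keep the positions with the highest running confidence.

    Unlike low-confidence remasking (LCR), which ranks only the current step's
    predictions and freezes everything already unmasked, RCR ranks all positions
    by the best confidence they have achieved so far, so a persistently
    low-confidence token can still be remasked and refined in a later step.
    """
    # Update the per-position running maximum with this step's confidences.
    running_conf = torch.where(predicted, torch.maximum(running_conf, step_conf), running_conf)
    # Keep the top-k positions by running confidence; every other position is remasked.
    keep = torch.zeros_like(predicted, dtype=torch.bool)
    keep[torch.topk(running_conf, k=num_to_keep).indices] = True
    return keep, running_conf
```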

4. Results

We conduct our experiments on two reasoning tasks: MATH-500 and Countdown. A key factor affecting the generation performance of MDLMs is whether denoising is performed over the full sequence at once or in blocks. Previous works apply a semi-autoregressive strategy in which the sequence is divided into several blocks and generated from left to right; within each block, tokens are denoised in parallel. We refer to this setting as semi-AR, and contrast it with denoising all tokens of the entire sequence simultaneously (pure-Diff).
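The two settings differ only in how the generation length and the step budget are partitioned. A minimal sketch of the decoding loop is shown below, with a hypothetical `denoise` routine standing in for the iterative mask prediction within one span:

```python
from typing import Callable, List, Tuple

def generate(
    denoise: Callable[[List[int], Tuple[int, int], int], List[int]],  # hypothetical denoising routine
    prompt_ids: List[int],
    mask_id: int,
    gen_length: int = 512,
    block_size: int = 512,
    num_steps: int = 256,
) -> List[int]:
    """Sketch of semi-AR vs. pure-Diff decoding.

    block_size == gen_length -> pure-Diff: the whole response is denoised jointly.
    block_size <  gen_length -> semi-AR: blocks are filled left to right, each with
    an equal share of the step budget; within a block, tokens are denoised in
    parallel, conditioned on the prompt and all previously generated blocks.
    """
    num_blocks = gen_length // block_size
    steps_per_block = num_steps // num_blocks
    sequence = list(prompt_ids) + [mask_id] * gen_length
    for b in range(num_blocks):
        start = len(prompt_ids) + b * block_size
        end = start + block_size
        # `denoise` iteratively predicts and remasks tokens inside [start, end).
        sequence = denoise(sequence, (start, end), steps_per_block)
    return sequence
```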

Model performance on Mathematics and Countdown. Best and second-best methods in each setting are shaded. For each task, we report results for semi-AR (block size = 128) and pure-Diff (block size = 512) with a generation length of 512. We also compare performance across multiple choices of denoising steps.

We show that across all configurations, both MDPO and RCR individually improve substantially upon the LLaDA initialization, with RCR often achieving performance comparable to MDPO despite requiring no additional training. We remark that RCR, as a training-free method, even outperforms the training-based baselines in most settings on the MATH task. Notably, combining MDPO with RCR consistently yields further gains over either method alone, achieving the best or second-best performance in nearly all settings, which demonstrates that MDPO and RCR are complementary. We further observe that the relative performance gains are more pronounced in settings with fewer inference steps, indicating improved sampling efficiency.

4.1 Answer Backslide Is an Effective Data Filter

One surprising observation is that training the model with MDPO on only Answer Backslide samples yields the best results. Specifically, we first use the initialized model to run inference on the whole dataset to identify Answer Backslide samples and then train only on this subset, which constitutes roughly 10% of the original data. The comparison between MDPO-all-data and MDPO in the following figure shows that the model trained on only Answer Backslide samples excels in most settings, highlighting the effectiveness of Answer Backslide as a data filter for MDPO.

Comparison of MDPO variants. We report final accuracy and the proportion of Answer Backslide cases for all models. MDPO-all-data is trained on all data samples, whereas MDPO is trained only on Answer Backslide data, both with rollouts sampled from a mixture of semi-AR and pure-Diff. MDPO-pure-Diff and MDPO-semi-AR denote models trained on rollouts sampled only from pure-Diff and only from semi-AR, respectively.
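For completeness, the data-filtering step described in this subsection could look roughly as follows; `run_inference`, `verify_answer`, and the sample schema are assumptions for illustration:

```python
from typing import Callable, Dict, List

def filter_backslide_samples(
    dataset: List[Dict[str, str]],              # assumed schema: items with "prompt" and "answer"
    run_inference: Callable[[str], List[str]],  # hypothetical: decoded completion after every denoising step
    verify_answer: Callable[[str, str], bool],  # hypothetical math verifier (completion, reference) -> bool
) -> List[Dict[str, str]]:
    """Keep only samples where Answer Backslide occurs under the initial policy."""
    kept = []
    for sample in dataset:
        steps = run_inference(sample["prompt"])
        correct = [verify_answer(text, sample["answer"]) for text in steps]
        # Backslide: correct at some intermediate step, wrong at the final step.
        if any(correct[:-1]) and not correct[-1]:
            kept.append(sample)
    return kept
```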

4.2 Impact of Sampling Settings on MDPO

Another question we investigate is how different sampling settings during RL rollout collection affect final performance. To this end, we compare MDPO variants trained on (i) pure-Diff rollouts only, (ii) semi-AR rollouts only, and (iii) an even mixture of both. The above figure shows that training on a single mode yields the largest improvement in that mode's evaluation setting, but often at the cost of performance in the other mode. The mixture strategy used in our main MDPO setting achieves a more balanced performance, matching or exceeding the best single-mode results in several configurations. This suggests that mixed sampling allows the policy to learn denoising behaviors that generalize across both inference modes.

Citation

If you find this blog or our codebase useful, please consider citing:

@misc{He2025MDPO,
  title={MDPO: Overcoming the Training-Inference Divide of Masked Diffusion Language Models},
  author={Haoyu He and Katrin Renz and Yong Cao and Andreas Geiger},
  year={2025},
  eprint={2508.13148},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2508.13148},
}

Acknowledgements

We thank Zehao Yu, Markus Flicke, and Madhav Iyengar for fruitful discussions in the preparation of the draft. We also thank the International Max Planck Research School for Intelligent Systems (IMPRS-IS) for supporting K. Renz.