Loss

  1. T2I损失 (loss_t2i) - 文本到图像生成

  2. LM损失 (loss_lm) - 语言建模

  3. MMU损失 (loss_mmu) - 多模态理解

  4. 图像编辑损失 (loss_image_editing) - 图像编辑任务

  5. 掩码预测损失 (loss_prediction) - 一致性预测

先掩码再算交叉熵

Remask

对输入的tokens按照取决于时间步t的比例随机跳跃式盖上mask,预测mask的token,计算这部分token的loss

初始全 mask 或部分 mask,一次可以并行预测多个 token,每轮迭代根据置信度替换低置信度 token

解码过程:统计一共要解码多少token,提取图像部分token,Classifier-Free Guidance,计算置信度,

下面把那一大段话用大白话说清楚(忽略里面“Ran tool / Search files”这类日志):

一句话

把整张图的 token 一起预测,每一轮留下更有把握的,把没把握的再盖回 [MASK] 继续猜,迭代多轮直到全都靠谱为止。

  1. 模型全局预测:对当前还是 [MASK] 的位置给出概率分布。

  2. 数多少个要定下来

    1. 要么平均分配(比如 1024 个 token、18 轮 → 前 16 轮每次定 57 个,后 2 轮每次定 56 个);

    2. 要么用调度函数(如余弦)决定这一轮大概定多少/重盖多少。

  3. 挑谁留下:算每个位置的置信度(被选中 id 的概率),取 Top-K 个置信度最高的定下来; 其余置信度低的用规则(mask_by_random_topk + 温度退火 + Gumbel 噪声)重新盖回 [MASK]

  4. 边界保护:每轮至少定一个、也至少留一个待定,保证还能继续迭代。

  5. 已知不动:输入里原本就给定的 token 一律不参与重盖(用 unknown_map 锁死)。

文本处理

编码(Encode):

文本 → 去掉特殊符号后正则化 → 分词 → token → token ID → 张量。

解码(Decode):

token ID → token → 拼接还原文本(去掉特殊符号)。

图片处理

原始图像: [B, 3, H, W]

resize: [B, 3, 256, 256]

VQ encode: [B, 32, 32]

扁平化: [B, 1024]

词表偏移: [B, 1024](ID 范围不同)