Loss
-
T2I损失 (loss_t2i) - 文本到图像生成
-
LM损失 (loss_lm) - 语言建模
-
MMU损失 (loss_mmu) - 多模态理解
-
图像编辑损失 (loss_image_editing) - 图像编辑任务
-
掩码预测损失 (loss_prediction) - 一致性预测
先掩码再算交叉熵
Remask
对输入的tokens按照取决于时间步t的比例随机跳跃式盖上mask,预测mask的token,计算这部分token的loss
初始全 mask 或部分 mask,一次可以并行预测多个 token,每轮迭代根据置信度替换低置信度 token
解码过程:统计一共要解码多少token,提取图像部分token,Classifier-Free Guidance,计算置信度,
下面把那一大段话用大白话说清楚(忽略里面“Ran tool / Search files”这类日志):
一句话
把整张图的 token 一起预测,每一轮留下更有把握的,把没把握的再盖回 [MASK]
继续猜,迭代多轮直到全都靠谱为止。
-
模型全局预测:对当前还是
[MASK]
的位置给出概率分布。 -
数多少个要定下来:
-
要么平均分配(比如 1024 个 token、18 轮 → 前 16 轮每次定 57 个,后 2 轮每次定 56 个);
-
要么用调度函数(如余弦)决定这一轮大概定多少/重盖多少。
-
-
挑谁留下:算每个位置的置信度(被选中 id 的概率),取 Top-K 个置信度最高的定下来; 其余置信度低的用规则(
mask_by_random_topk
+ 温度退火 + Gumbel 噪声)重新盖回[MASK]
。 -
边界保护:每轮至少定一个、也至少留一个待定,保证还能继续迭代。
-
已知不动:输入里原本就给定的 token 一律不参与重盖(用
unknown_map
锁死)。
文本处理
编码(Encode):
文本 → 去掉特殊符号后正则化 → 分词 → token → token ID → 张量。
解码(Decode):
token ID → token → 拼接还原文本(去掉特殊符号)。
图片处理
原始图像: [B, 3, H, W]
resize: [B, 3, 256, 256]
VQ encode: [B, 32, 32]
扁平化: [B, 1024]
词表偏移: [B, 1024](ID 范围不同)