2026/4/18 15:32:07
网站建设
项目流程
长沙优化推广外包,seo薪酬如何,网站字体大小,对网站建设的讲话tf.keras.losses.SparseCategoricalCrossentropy 核心原理
SparseCategoricalCrossentropy#xff08;稀疏类别交叉熵#xff09;是 TensorFlow/Keras 中针对多分类任务的损失函数#xff0c;专为稀疏标签#xff08;整数型标签#xff0c;如 0,1,2#xff09;设计#…tf.keras.losses.SparseCategoricalCrossentropy核心原理SparseCategoricalCrossentropy稀疏类别交叉熵是 TensorFlow/Keras 中针对多分类任务的损失函数专为稀疏标签整数型标签如0,1,2设计核心作用是衡量模型输出的类别概率分布与真实稀疏标签的「差异」本质是交叉熵Cross-Entropy在稀疏标签场景下的优化实现。一、先理解核心背景交叉熵的本质交叉熵源于信息论用于衡量两个概率分布的「距离」差异程度。对于多分类任务真实标签的分布是「one-hot 分布」比如 3 分类中标签为 1对应分布是[0,1,0]模型输出是类别概率分布经 Softmax 归一化后和为 1如[0.1,0.8,0.1]。交叉熵的公式为H(p,q)−∑i1Cp(i)log(q(i)) H(p,q) -\sum_{i1}^C p(i) \log(q(i))H(p,q)−i1∑Cp(i)log(q(i))其中ppp真实标签的概率分布one-hot 形式仅目标类别为 1其余为 0qqq模型预测的类别概率分布CCC类别总数。由于ppp是 one-hot 分布交叉熵可简化为仅取目标类别对应的预测概率的负对数因为其他项都是0×log(q(i))00 \times \log(q(i))00×log(q(i))0。二、SparseCategoricalCrossentropy 的核心适配稀疏标签普通的CategoricalCrossentropy要求标签是one-hot 编码如 3 分类标签 1 对应[0,1,0]而SparseCategoricalCrossentropy直接支持整数型稀疏标签如 1无需手动 one-hot 编码核心优势是节省内存尤其是类别数多的场景比如 1000 类时稀疏标签仅存 1 个整数one-hot 需存 1000 维向量。三、完整计算逻辑分两种场景SparseCategoricalCrossentropy的关键参数是from_logits默认False决定模型输出是否为「原始 logits未归一化的得分」或「Softmax 归一化后的概率」两种场景的计算逻辑不同TensorFlow 内部做了优化避免数值不稳定。场景 1from_logitsFalse默认模型输出是 Softmax 概率假设类别数C3C3C3真实稀疏标签y1y1y1对应目标类别是第 2 类索引从 0 开始模型输出 Softmax 概率q[0.1,0.8,0.1]q[0.1, 0.8, 0.1]q[0.1,0.8,0.1]。计算步骤取真实标签对应的概率q(y)q(1)0.8q(y)q(1)0.8q(y)q(1)0.8计算负对数−log(q(y))−log(0.8)≈0.223-\log(q(y)) -\log(0.8) ≈ 0.223−log(q(y))−log(0.8)≈0.223最终损失值即为该结果批量数据会取均值/求和由reduction参数控制。公式简化为loss−log(q(y)) \text{loss} -\log(q(y))loss−log(q(y))场景 2from_logitsTrue模型输出是原始 logits推荐模型输出的是未经过 Softmax 归一化的原始得分logits如z[1.0,3.0,0.5]z[1.0, 3.0, 0.5]z[1.0,3.0,0.5]此时 TensorFlow 不会先单独计算 Softmax避免数值下溢/上溢而是直接用log_softmax优化计算对 logits 计算log_softmaxlog(Softmax(z))z−log(∑i1Cezi)\log(\text{Softmax}(z)) z - \log(\sum_{i1}^C e^{z_i})log(Softmax(z))z−log(∑i1Cezi)取真实标签对应的项取负数即为损失loss−(zy−log(∑i1Cezi)) \text{loss} - \left( z_y - \log(\sum_{i1}^C e^{z_i}) \right)loss−(zy−log(i1∑Cezi))示例计算z[1.0,3.0,0.5],y1z[1.0, 3.0, 0.5], y1z[1.0,3.0,0.5],y1先算∑ezie1.0e3.0e0.5≈2.71820.0851.648≈24.451\sum e^{z_i} e^{1.0} e^{3.0} e^{0.5} ≈ 2.718 20.085 1.648 ≈ 24.451∑ezie1.0e3.0e0.5≈2.71820.0851.648≈24.451log(24.451)≈3.200\log(24.451) ≈ 3.200log(24.451)≈3.200log(Softmax(z))13.0−3.200−0.200\log(\text{Softmax}(z))_1 3.0 - 3.200 -0.200log(Softmax(z))13.0−3.200−0.200损失值−(−0.200)0.200-(-0.200) 0.200−(−0.200)0.200。为什么推荐from_logitsTrueSoftmax 对大 logits 会产生e大值e^{大值}e大值如e100e^{100}e100溢出而log_softmax直接通过代数变换避免了单独计算 Softmax提升数值稳定性。四、批量数据的损失归约实际训练中输入是批量数据batch损失会通过reduction参数归约默认AUTO等价于SUM_OVER_BATCH_SIZE对每个样本计算损失值求批量内所有样本损失的均值或求和取决于reduction。示例batch_size2样本稀疏标签模型概率单样本损失11[0.1,0.8,0.1]0.22320[0.9,0.05,0.05]0.105批量损失 (0.223 0.105) / 2 ≈ 0.164。五、关键参数解析参数作用示例from_logits是否输入为原始 logits非 Softmax 概率from_logitsTrue推荐reduction损失归约方式-NONE返回每个样本的损失-SUM批量损失求和-SUM_OVER_BATCH_SIZE批量损失求均值reductionsum_over_batch_sizeignore_index忽略指定标签计算损失时跳过适用于样本标注缺失场景ignore_index-1axis类别维度默认 -1即最后一维是类别模型输出形状(batch, 3)时axis-1 对应 3 个类别六、与CategoricalCrossentropy的对比特性SparseCategoricalCrossentropyCategoricalCrossentropy标签格式整数型稀疏标签如 1,2,3one-hot 编码标签如 [0,1,0]内存占用低仅存整数高类别数维向量适用场景类别数多、标签天然为整数如图像分类的类别索引标签已做 one-hot 编码核心公式同交叉熵但直接取整数标签对应项交叉熵原始公式遍历所有类别七、注意事项标签范围稀疏标签必须是[0,C−1][0, C-1][0,C−1]范围内的整数C 是类别数否则会报错数值稳定性优先设置from_logitsTrue避免 Softmax 导致的数值溢出多标签任务该损失适用于「单标签多分类」每个样本仅属于一个类别多标签任务需用BinaryCrossentropy。示例代码验证importtensorflowastf# 1. 定义损失函数from_logitsTrue模型输出logitsloss_fntf.keras.losses.SparseCategoricalCrossentropy(from_logitsTrue)# 2. 模拟批量数据batch_size2类别数3y_truetf.constant([1,0])# 稀疏标签y_pred_logitstf.constant([[1.0,3.0,0.5],[5.0,1.0,0.1]])# 模型输出logits# 3. 计算损失lossloss_fn(y_true,y_pred_logits)print(批量损失值,loss.numpy())# 输出约 0.15手动计算验证综上SparseCategoricalCrossentropy本质是「多分类交叉熵」在稀疏标签下的高效实现核心是通过直接索引整数标签避免 one-hot 编码同时优化数值计算保证稳定性是单标签多分类任务的首选损失函数之一。