# Knowledge Distillation Methods Comparison

## Overview
Knowledge Distillation (KD) compresses a large teacher model into a smaller student model while preserving most of the teacher's performance.
## Method Categories

### 1. Logit-Based Methods

#### Standard KD (Hinton et al., 2015)
$$
\mathcal{L}_{\mathrm{KD}} = \alpha \, \mathcal{L}_{\mathrm{CE}}(y, p_s) + (1 - \alpha) \, T^{2} \, \mathrm{KL}\!\left(p_t^{T} \,\|\, p_s^{T}\right)
$$

Where:

- \(p_t^T, p_s^T\): softened teacher/student outputs at temperature \(T\)
- \(\alpha\): balancing coefficient
```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    # Soften teacher and student distributions at temperature T.
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    soft_student = F.log_softmax(student_logits / T, dim=-1)
    # The KL term is scaled by T^2 to keep gradient magnitudes comparable across temperatures.
    distill_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean') * T * T
    # Standard cross-entropy on the hard labels.
    ce_loss = F.cross_entropy(student_logits, labels)
    return alpha * ce_loss + (1 - alpha) * distill_loss
```
#### DKD (Decoupled KD)
Separates target class knowledge distillation (TCKD) from non-target class knowledge distillation (NCKD).
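A minimal sketch of a decoupled loss in this spirit is shown below, assuming plain classification logits; `alpha`, `beta`, and `T` are illustrative defaults rather than recommended values, and the helper names are ours, not the paper's reference code.

```python
import torch
import torch.nn.functional as F

def dkd_loss(student_logits, teacher_logits, labels, alpha=1.0, beta=8.0, T=4.0):
    # One-hot mask marking the target class of each sample.
    gt_mask = F.one_hot(labels, student_logits.size(-1)).float()

    # TCKD: KL divergence over the binary (target vs. all non-target) distribution.
    def binary_probs(logits):
        p = F.softmax(logits / T, dim=-1)
        p_target = (p * gt_mask).sum(dim=-1, keepdim=True)
        return torch.cat([p_target, 1.0 - p_target], dim=-1)

    tckd = F.kl_div(torch.log(binary_probs(student_logits) + 1e-8),
                    binary_probs(teacher_logits),
                    reduction='batchmean') * T * T

    # NCKD: KL divergence among non-target classes only
    # (the target logit is suppressed before the softmax).
    nckd = F.kl_div(F.log_softmax(student_logits / T - 1000.0 * gt_mask, dim=-1),
                    F.softmax(teacher_logits / T - 1000.0 * gt_mask, dim=-1),
                    reduction='batchmean') * T * T

    return alpha * tckd + beta * nckd
```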
### 2. Feature-Based Methods

#### FitNets
Match intermediate representations:

$$
\mathcal{L}_{\mathrm{hint}} = \left\| F_t - W_s F_s \right\|_2^2
$$

Where \(F_t, F_s\) are intermediate teacher/student feature maps and \(W_s\) is a learnable projection matrix.
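A minimal sketch of such a hint loss, assuming 4D convolutional feature maps and realising the learnable projection \(W_s\) as a 1×1 convolution; the class and argument names are illustrative:

```python
import torch.nn as nn
import torch.nn.functional as F

class HintLoss(nn.Module):
    """Regress projected student features onto teacher features (FitNets-style)."""

    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # Learnable projection W_s, realised as a 1x1 convolution.
        self.proj = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        # student_feat: (B, C_s, H, W), teacher_feat: (B, C_t, H', W')
        projected = self.proj(student_feat)
        # Resize if the spatial sizes differ between architectures.
        if projected.shape[-2:] != teacher_feat.shape[-2:]:
            projected = F.interpolate(projected, size=teacher_feat.shape[-2:],
                                      mode='bilinear', align_corners=False)
        return F.mse_loss(projected, teacher_feat.detach())
```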
#### Attention Transfer (AT)
Transfer attention maps:

$$
\mathcal{L}_{\mathrm{AT}} = \sum_{j} \left\| \frac{Q_s^{(j)}}{\|Q_s^{(j)}\|_2} - \frac{Q_t^{(j)}}{\|Q_t^{(j)}\|_2} \right\|_2
$$

Where \(Q^{(j)}\) is the vectorized spatial attention map of layer pair \(j\), commonly the channel-wise sum of squared activations.
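A minimal sketch for one pair of feature maps, using the common sum-of-squared-activations attention map; the helper names are illustrative:

```python
import torch.nn.functional as F

def attention_map(feat):
    # feat: (B, C, H, W) -> L2-normalised, flattened spatial attention map (B, H*W).
    attn = feat.pow(2).sum(dim=1).flatten(start_dim=1)
    return F.normalize(attn, p=2, dim=1)

def at_loss(student_feat, teacher_feat):
    # Spatial sizes must match; interpolate beforehand if they differ.
    return (attention_map(student_feat) - attention_map(teacher_feat)).pow(2).mean()
```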
### 3. Relation-Based Methods

#### RKD (Relational KD)
Transfers relationships between samples (e.g. pairwise distances and angles) rather than individual outputs:

$$
\mathcal{L}_{\mathrm{RKD\text{-}D}} = \sum_{(i,j)} \ell_{\delta}\!\left( \frac{\| s_i - s_j \|_2}{\mu_s}, \; \frac{\| t_i - t_j \|_2}{\mu_t} \right)
$$

Where \(s_i, t_i\) are student/teacher embeddings, \(\mu\) is the mean pairwise distance, and \(\ell_{\delta}\) is the Huber loss.
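A minimal sketch of the distance-wise term, assuming flattened embeddings of shape `(batch, dim)`; the helper names are illustrative, and the angle-wise term of RKD is omitted:

```python
import torch
import torch.nn.functional as F

def pairwise_distances(embeddings):
    # (B, D) -> (B, B) matrix of Euclidean distances, normalised by their mean.
    dist = torch.cdist(embeddings, embeddings, p=2)
    mean_dist = dist[dist > 0].mean()
    return dist / (mean_dist + 1e-8)

def rkd_distance_loss(student_emb, teacher_emb):
    # Match the normalised pairwise-distance structure with a Huber loss.
    return F.smooth_l1_loss(pairwise_distances(student_emb),
                            pairwise_distances(teacher_emb.detach()))
```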
#### CRD (Contrastive Representation Distillation)
Uses a contrastive (NCE-based) objective that pulls student and teacher representations of the same input together while pushing apart representations of different inputs.
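CRD proper draws negatives from a memory buffer; the sketch below is a simplified in-batch InfoNCE variant that only assumes student and teacher embeddings of the same dimensionality (in practice CRD adds projection heads), so treat it as an approximation rather than the original implementation:

```python
import torch
import torch.nn.functional as F

def contrastive_distill_loss(student_emb, teacher_emb, temperature=0.07):
    # L2-normalise both embedding sets: (B, D).
    s = F.normalize(student_emb, dim=1)
    t = F.normalize(teacher_emb.detach(), dim=1)
    # Similarity of every student sample to every teacher sample.
    logits = s @ t.T / temperature
    # The teacher embedding at the same index is the positive pair.
    targets = torch.arange(s.size(0), device=s.device)
    return F.cross_entropy(logits, targets)
```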
### 4. Fisher-Based Methods

#### Fisher KD
Weights parameters by their Fisher Information:

$$
F_i = \mathbb{E}\!\left[ \left( \frac{\partial \log p(y \mid x; \theta)}{\partial \theta_i} \right)^{2} \right]
$$

Fisher Information measures parameter importance: parameters to which the model's log-likelihood is most sensitive are weighted most heavily.
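How the Fisher weighting enters the distillation objective depends on the specific method; the sketch below only illustrates one ingredient, estimating a diagonal empirical Fisher for a classifier, and is not taken from the FisherKD-Unified code:

```python
import torch
import torch.nn.functional as F

def diagonal_fisher(model, data_loader, device, num_batches=10):
    # Accumulate squared gradients of the log-likelihood for every trainable parameter.
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    model.eval()
    batches_seen = 0
    for x, y in data_loader:
        if batches_seen >= num_batches:
            break
        model.zero_grad()
        log_probs = F.log_softmax(model(x.to(device)), dim=-1)
        # Coarse batch-level approximation of the per-sample expectation.
        F.nll_loss(log_probs, y.to(device)).backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        batches_seen += 1
    return {n: f / max(batches_seen, 1) for n, f in fisher.items()}
```

The resulting per-parameter (or per-feature) importances can then be used to weight logit- or feature-matching terms toward the parts of the teacher that matter most.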
## Comparison Table
| Method | Type | Pros | Cons | Best For |
|---|---|---|---|---|
| Standard KD | Logit | Simple, effective | Requires same output dim | Classification |
| DKD | Logit | Better separation | More hyperparameters | Fine-grained tasks |
| FitNets | Feature | Works across architectures | Needs layer mapping | Different archs |
| AT | Feature | Preserves attention | Limited to attention-based | Transformers |
| RKD | Relation | Architecture agnostic | Batch size sensitive | Few-shot learning |
| CRD | Relation | Strong performance | Computationally heavy | Representation learning |
| Fisher KD | Parameter | Principled selection | Expensive to compute | Critical applications |
## Implementation Tips
- **Temperature Selection**:
    - Start with T=4 and tune in the range [1, 20]
    - Higher T → softer distributions → more knowledge transfer
- **Layer Mapping**:
    - For different architectures, use learnable projections (see the FitNets sketch above)
    - Match layers with similar semantic roles
- **Loss Balancing**:
    - Use a validation set to tune \(\alpha\)
    - Balance the task-specific loss against the distillation loss
- **Multi-Method Combination**:
    - Combine several terms, e.g. `total_loss = ce_loss + 0.3 * kd_loss + 0.2 * feature_loss + 0.1 * relation_loss` (a fuller training-step sketch follows this list)
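As a usage illustration of the tips above, here is a sketch of one training step combining cross-entropy, logit KD, and a feature term; it assumes both models return `(logits, features)`, that `feature_adapter` is a learnable projection, and that the loss weights are placeholders to be tuned on validation data:

```python
import torch
import torch.nn.functional as F

def distillation_step(student, teacher, feature_adapter, batch, optimizer,
                      T=4.0, w_kd=0.3, w_feat=0.2):
    x, labels = batch
    with torch.no_grad():
        t_logits, t_feat = teacher(x)   # assumed interface: (logits, features)
    s_logits, s_feat = student(x)

    ce = F.cross_entropy(s_logits, labels)
    kd = F.kl_div(F.log_softmax(s_logits / T, dim=-1),
                  F.softmax(t_logits / T, dim=-1),
                  reduction='batchmean') * T * T
    feat = F.mse_loss(feature_adapter(s_feat), t_feat)

    loss = ce + w_kd * kd + w_feat * feat
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```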
## Resources
- FisherKD-Unified GitHub
- Original papers for each method
- PyTorch distillation libraries