Loss Functions
Weaver provides loss functions for different training objectives.
Available Loss Functions
Weaver currently supports two loss functions:
- cross_entropy: Standard supervised learning loss
- importance_sampling: Weighted loss for reinforcement learning and policy optimization
Cross Entropy Loss
Standard cross-entropy loss for supervised learning tasks.
Loss Function Inputs
For cross_entropy, each datum requires:
- target_tokens (Tensor[int64]): Target token IDs for next-token prediction
- weights (Tensor[float32]): Weight for each token in the loss
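As a rough illustration of these two fields, the sketch below builds them for a single prompt/completion pair. It assumes a common convention, not confirmed by this page, in which prompt positions get weight 0 so that only completion tokens contribute to the loss. The helper name `build_datum_fields` and the token IDs are hypothetical, and how the tensors are handed to Weaver is not shown.

```python
import torch

def build_datum_fields(prompt_ids, completion_ids):
    """Hypothetical helper: build target_tokens and weights for one sequence.

    Assumes prompt_ids contains at least one token.
    """
    tokens = list(prompt_ids) + list(completion_ids)
    # Next-token prediction: the target at position t is the token at t + 1.
    target_tokens = torch.tensor(tokens[1:], dtype=torch.int64)
    # Assumed convention: weight 0 for targets that are still part of the
    # prompt, weight 1 for targets that belong to the completion.
    n_prompt_targets = len(prompt_ids) - 1
    weights = torch.cat([
        torch.zeros(n_prompt_targets, dtype=torch.float32),
        torch.ones(len(completion_ids), dtype=torch.float32),
    ])
    return {"target_tokens": target_tokens, "weights": weights}

fields = build_datum_fields(prompt_ids=[101, 7, 8], completion_ids=[9, 10])
print(fields["target_tokens"])  # targets for positions 0..3
print(fields["weights"])        # zeros for prompt targets, ones for completion
```

Both tensors have the same length, one entry per predicted position.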
How It Works
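The cross-entropy loss computes:
$$\mathcal{L} = -\frac{\sum_i w_i \log p(y_i)}{\sum_i w_i}$$
Where:
- $w_i$ is the weight for token $i$
- $y_i$ is the target token
- $p(y_i)$ is the predicted probability of the target token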
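A small hand-checkable case can make the weighted loss concrete. The sketch below assumes the loss is normalized by the total weight, as in the reference implementation on this page; the function name `weighted_nll` is illustrative only. With all-zero logits over a vocabulary of 4, every token has probability 1/4, so the loss should equal log 4 whatever the weights are.

```python
import math
import torch
import torch.nn.functional as F

def weighted_nll(logits, target_tokens, weights):
    # Log-probability the model assigns to each target token.
    log_probs = F.log_softmax(logits, dim=-1)
    target_log_probs = log_probs.gather(-1, target_tokens.unsqueeze(-1)).squeeze(-1)
    # Weighted average of the negative log-probabilities.
    return -(weights * target_log_probs).sum() / weights.sum()

logits = torch.zeros(3, 4)               # uniform distribution over 4 tokens
targets = torch.tensor([0, 2, 3])
weights = torch.tensor([0.0, 1.0, 1.0])  # first position is masked out
loss = weighted_nll(logits, targets, weights)
print(loss.item(), math.log(4))
```

Positions with weight 0 drop out of both the numerator and the denominator, so masking tokens does not shrink the loss scale.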
Reference Implementation
python
import torch
import torch.nn.functional as F


def cross_entropy_loss(logits, target_tokens, weights):
    """
    Compute weighted cross-entropy loss.

    Args:
        logits: (batch_size, vocab_size) - Model output logits
        target_tokens: (batch_size,) - Target token IDs
        weights: (batch_size,) - Weight for each token

    Returns:
        loss: Scalar loss value
    """
    # Compute log probabilities
    log_probs = F.log_softmax(logits, dim=-1)

    # Gather log probabilities for target tokens
    target_log_probs = log_probs.gather(dim=-1, index=target_tokens.unsqueeze(-1)).squeeze(-1)

    # Apply weights and compute loss
    weighted_log_probs = target_log_probs * weights
    loss = -weighted_log_probs.sum() / weights.sum()
    return loss

Importance Sampling Loss
Weighted loss function for reinforcement learning and policy optimization.
Loss Function Inputs
For importance_sampling, each datum requires:
- target_tokens (Tensor[int64]): Target token IDs
- weights (Tensor[float32]): Base weight for each token
- importance_weights (Tensor[float32]): Importance sampling weights (e.g., advantages or rewards)
Loss Configuration
Optional configuration for importance_sampling:
python
loss_fn_config = {
    "clip_ratio": 1.0,  # Clip importance weights to prevent large updates
}

How It Works
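The importance sampling loss computes:
$$\mathcal{L} = -\frac{\sum_i w_i \, r_i \log p(y_i)}{\sum_i w_i}$$
Where:
- $w_i$ is the base weight
- $r_i$ is the importance weight (reward/advantage)
- $y_i$ is the target token, with predicted probability $p(y_i)$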
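Two properties of this loss are easy to check numerically, and the sketch below does so under the assumption that clipping clamps importance weights to `[-clip_ratio, clip_ratio]`, as in the reference implementation on this page. The function name `is_loss` and the sample values are illustrative only. First, importance weights of 1 recover plain cross-entropy; second, an importance weight larger than `clip_ratio` behaves exactly as if it had been `clip_ratio`.

```python
import torch
import torch.nn.functional as F

def is_loss(logits, target_tokens, weights, importance_weights, clip_ratio=None):
    log_probs = F.log_softmax(logits, dim=-1)
    target_log_probs = log_probs.gather(-1, target_tokens.unsqueeze(-1)).squeeze(-1)
    if clip_ratio is not None:
        # Assumed clipping rule: symmetric clamp around zero.
        importance_weights = torch.clamp(importance_weights, -clip_ratio, clip_ratio)
    return -(weights * importance_weights * target_log_probs).sum() / weights.sum()

torch.manual_seed(0)
logits = torch.randn(3, 5)
targets = torch.tensor([1, 4, 2])
weights = torch.ones(3)

# Importance weights of 1 reduce the loss to ordinary cross-entropy.
baseline = is_loss(logits, targets, weights, torch.ones(3))
reference = F.cross_entropy(logits, targets)  # mean reduction, uniform weights

# An advantage of 5.0 is clamped to 1.0 before it scales the log-probability.
advantages = torch.tensor([5.0, -0.5, 0.2])
clipped = is_loss(logits, targets, weights, advantages, clip_ratio=1.0)
manual = is_loss(logits, targets, weights, torch.tensor([1.0, -0.5, 0.2]))
```

Negative importance weights flip the sign of a token's contribution, so minimizing the loss pushes the probability of that token down instead of up.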
Reference Implementation
python
import torch
import torch.nn.functional as F


def importance_sampling_loss(logits, target_tokens, weights, importance_weights, clip_ratio=None):
    """
    Compute importance sampling loss with optional clipping.

    Args:
        logits: (batch_size, vocab_size) - Model output logits
        target_tokens: (batch_size,) - Target token IDs
        weights: (batch_size,) - Base weight for each token
        importance_weights: (batch_size,) - Importance sampling weights (rewards/advantages)
        clip_ratio: Optional clipping value for importance weights

    Returns:
        loss: Scalar loss value
    """
    # Compute log probabilities
    log_probs = F.log_softmax(logits, dim=-1)

    # Gather log probabilities for target tokens
    target_log_probs = log_probs.gather(dim=-1, index=target_tokens.unsqueeze(-1)).squeeze(-1)

    # Apply optional clipping to importance weights
    if clip_ratio is not None:
        importance_weights = torch.clamp(importance_weights, -clip_ratio, clip_ratio)

    # Apply both base weights and importance weights
    weighted_log_probs = target_log_probs * weights * importance_weights
    loss = -weighted_log_probs.sum() / weights.sum()
    return loss

Next Steps
- Learn about Training and Sampling - Core APIs
- Understand Saving and Loading - Model persistence
- Check Model Lineup - Supported models