Loss Functions

Weaver provides loss functions for different training objectives.

Available Loss Functions

Weaver currently supports two loss functions:

  • cross_entropy: Standard supervised learning loss
  • importance_sampling: Weighted loss for reinforcement learning and policy optimization

Cross Entropy Loss

Standard cross-entropy loss for supervised learning tasks.

Loss Function Inputs

For cross_entropy, each datum requires:

  • target_tokens (Tensor[int64]): Target token IDs for next-token prediction
  • weights (Tensor[float32]): Weight for each token in the loss
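
As an illustration, a single cross_entropy datum might be assembled as follows. This is a minimal sketch: the tensor values are made up, and the dict container is an assumption rather than Weaver's confirmed input format.

python
import torch

# Hypothetical datum for a 4-token sequence; the field names follow the
# list above, but the dict container itself is an assumption
datum = {
    "target_tokens": torch.tensor([101, 2023, 2003, 102], dtype=torch.int64),
    # Zero weights exclude tokens from the loss (e.g., prompt tokens)
    "weights": torch.tensor([0.0, 1.0, 1.0, 1.0], dtype=torch.float32),
}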

How It Works

The cross-entropy loss computes:

\[
\mathcal{L} = -\frac{\sum_i w_i \log p_\theta(t_i)}{\sum_i w_i}
\]

Where:

  • $w_i$ is the weight for token $i$
  • $t_i$ is the target token at position $i$
  • $p_\theta(t_i)$ is the predicted probability of the target token
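
For example, with two tokens weighted $w = (1, 1)$ and predicted target probabilities of 0.9 and 0.5, the loss is $-(\log 0.9 + \log 0.5)/2 \approx 0.40$.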

Reference Implementation

python
import torch
import torch.nn.functional as F

def cross_entropy_loss(logits, target_tokens, weights):
    """
    Compute weighted cross-entropy loss.

    Args:
        logits: (batch_size, vocab_size) - Model output logits
        target_tokens: (batch_size,) - Target token IDs
        weights: (batch_size,) - Weight for each token

    Returns:
        loss: Scalar loss value
    """
    # Compute log probabilities
    log_probs = F.log_softmax(logits, dim=-1)

    # Gather log probabilities for target tokens
    target_log_probs = log_probs.gather(dim=-1, index=target_tokens.unsqueeze(-1)).squeeze(-1)

    # Apply weights and compute loss
    weighted_log_probs = target_log_probs * weights
    loss = -weighted_log_probs.sum() / weights.sum()

    return loss
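
A quick sanity check of the reference implementation on toy tensors (the values are illustrative only):

python
logits = torch.randn(4, 1000)                # batch of 4, vocab of 1000
target_tokens = torch.tensor([5, 42, 7, 9])  # int64 by default
weights = torch.tensor([0.0, 1.0, 1.0, 1.0]) # first token masked out

loss = cross_entropy_loss(logits, target_tokens, weights)
print(loss.item())  # roughly log(1000) ≈ 6.9 for random logits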

Importance Sampling Loss

Weighted loss function for reinforcement learning and policy optimization.

Loss Function Inputs

For importance_sampling, each datum requires:

  • target_tokens (Tensor[int64]): Target token IDs
  • weights (Tensor[float32]): Base weight for each token
  • importance_weights (Tensor[float32]): Importance sampling weights (e.g., advantages or rewards)
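
For illustration, an importance_sampling datum might look like the following; as with the cross_entropy example above, the dict container and the values are assumptions.

python
import torch

datum = {
    "target_tokens": torch.tensor([101, 2023, 2003, 102], dtype=torch.int64),
    "weights": torch.tensor([0.0, 1.0, 1.0, 1.0], dtype=torch.float32),
    # Per-token advantages from the RL algorithm; negative values push
    # probability mass away from those tokens
    "importance_weights": torch.tensor([0.0, 0.8, -0.2, 1.5], dtype=torch.float32),
}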

Loss Configuration

Optional configuration for importance_sampling:

python
loss_fn_config = {
    "clip_ratio": 1.0,  # Clip importance weights to prevent large updates
}

How It Works

The importance sampling loss computes:

\[
\mathcal{L} = -\frac{\sum_i w_i \rho_i \log p_\theta(t_i)}{\sum_i w_i}
\]

Where:

  • $w_i$ is the base weight
  • $\rho_i$ is the importance weight (reward/advantage)
  • $t_i$ is the target token at position $i$

Reference Implementation

python
import torch
import torch.nn.functional as F

def importance_sampling_loss(logits, target_tokens, weights, importance_weights, clip_ratio=None):
    """
    Compute importance sampling loss with optional clipping.

    Args:
        logits: (batch_size, vocab_size) - Model output logits
        target_tokens: (batch_size,) - Target token IDs
        weights: (batch_size,) - Base weight for each token
        importance_weights: (batch_size,) - Importance sampling weights (rewards/advantages)
        clip_ratio: Optional clipping value for importance weights

    Returns:
        loss: Scalar loss value
    """
    # Compute log probabilities
    log_probs = F.log_softmax(logits, dim=-1)

    # Gather log probabilities for target tokens
    target_log_probs = log_probs.gather(dim=-1, index=target_tokens.unsqueeze(-1)).squeeze(-1)

    # Apply optional clipping to importance weights
    if clip_ratio is not None:
        importance_weights = torch.clamp(importance_weights, -clip_ratio, clip_ratio)

    # Apply both base weights and importance weights
    weighted_log_probs = target_log_probs * weights * importance_weights
    loss = -weighted_log_probs.sum() / weights.sum()

    return loss
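
A toy invocation of the reference implementation with clipping enabled (the values are illustrative only):

python
logits = torch.randn(4, 1000)
target_tokens = torch.tensor([5, 42, 7, 9])
weights = torch.tensor([1.0, 1.0, 1.0, 1.0])
importance_weights = torch.tensor([0.5, -2.0, 1.2, 3.0])

# With clip_ratio=1.0, the -2.0 and 3.0 entries are clamped to -1.0 and 1.0
loss = importance_sampling_loss(logits, target_tokens, weights,
                                importance_weights, clip_ratio=1.0)
print(loss.item())

Symmetric clamping bounds how strongly any single sample can move the policy, playing a role similar to the clipped surrogate objective in PPO.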

Next Steps

Weaver API Documentation