Training and Sampling

Core APIs for model training and sampling with Weaver.

Overview

Weaver's training and sampling workflow consists of:

  1. Create a service client to connect to Weaver
  2. Create a training client for your chosen model
  3. Prepare training data in the correct format
  4. Train the model using forward/backward passes and optimizer steps
  5. Sample from the model to generate text

Creating Clients

ServiceClient

The ServiceClient is your main entry point to Weaver.

python
from weaver import ServiceClient

# Create with API key from environment
service_client = ServiceClient()

# Or specify API key directly
service_client = ServiceClient(api_key="your-api-key")

# Connecting to a custom endpoint is also possible via the client constructor;
# see the Weaver documentation for the exact argument.

TrainingClient

Create a training client for a specific model:

python
training_client = service_client.create_model(
    base_model="Qwen/Qwen3-8B",
    lora_config={"rank": 32}
)

Parameters:

  • base_model (str): Model identifier (e.g., "Qwen/Qwen3-8B")
  • lora_config (dict): LoRA configuration with rank parameter

Getting the Tokenizer

Access the tokenizer through the training client:

python
tokenizer = training_client.get_tokenizer()

# Use the tokenizer
tokens = tokenizer.encode("Hello, world!", add_special_tokens=True)
text = tokenizer.decode(tokens)

Preparing Training Data

Training data in Weaver is represented as Datum objects.

Datum Structure

A Datum contains:

  • model_input: Input tokens for the model
  • loss_fn_inputs: Additional inputs for the loss function

python
from weaver import types
import torch

# Create a datum
datum = types.Datum(
    model_input=types.ModelInput.from_ints(input_tokens),
    loss_fn_inputs={
        "target_tokens": torch.tensor(target_tokens, dtype=torch.int64),
        "weights": torch.tensor(weights, dtype=torch.float32),
    }
)

Processing Examples

Here's a complete example of processing training data:

python
def process_example(prompt, completion, tokenizer):
    # Encode prompt and completion
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(completion, add_special_tokens=False)

    # Combine tokens
    tokens = prompt_tokens + completion_tokens

    # Create weights (0 for prompt, 1 for completion)
    weights = [0.0] * len(prompt_tokens) + [1.0] * len(completion_tokens)

    # Shift for next-token prediction
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = weights[1:]

    return types.Datum(
        model_input=types.ModelInput.from_ints(input_tokens),
        loss_fn_inputs={
            "target_tokens": torch.tensor(target_tokens, dtype=torch.int64),
            "weights": torch.tensor(weights, dtype=torch.float32),
        },
    )

# Process multiple examples
examples = [
    {"prompt": "Q: What is AI?", "completion": " A: Artificial Intelligence"},
    {"prompt": "Q: What is ML?", "completion": " A: Machine Learning"},
]

datums = [process_example(ex["prompt"], ex["completion"], tokenizer) for ex in examples]

Weight Masking

The weights tensor controls which tokens contribute to the loss:

  • 0.0: Token is ignored in loss calculation (e.g., prompt tokens)
  • 1.0: Token contributes fully to loss (e.g., completion tokens)
  • Values between 0.0 and 1.0: partial contribution

python
# Example: Weight completion tokens more heavily
weights = [0.0] * len(prompt_tokens) + [2.0] * len(completion_tokens)
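
For example, a fractional weight keeps a span in the loss while down-weighting it. The token id lists below are hypothetical, purely to make the sketch concrete:

python
# Hypothetical token id lists for the three spans of one example
prompt_tokens = [101, 102, 103, 104]  # instruction/prompt: masked out
completion_tokens = [201, 202, 203]   # answer to learn: full weight
suffix_tokens = [301, 302]            # boilerplate sign-off: half weight

weights = (
    [0.0] * len(prompt_tokens)
    + [1.0] * len(completion_tokens)
    + [0.5] * len(suffix_tokens)
)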

Training

forward_backward()

Runs the forward and backward passes and accumulates gradients.

python
result = training_client.forward_backward(
    datums,
    loss_fn="cross_entropy",
    wait=True,
)

Parameters:

  • datums (List[Datum]): List of training examples
  • loss_fn (str): Loss function to use ("cross_entropy" or "importance_sampling")
  • loss_fn_config (dict, optional): Configuration for the loss function
  • wait (bool): Whether to wait for completion (default: False)

Returns:

A dictionary with training metrics:

python
{
    "result": {
        "loss_fn_outputs": [...],  # Per-example outputs
        "metrics": {
            "loss": 0.5,  # Average loss
            ...
        }
    }
}
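
A short example of reading the average loss back out of this structure, mirroring the training loop later on this page (defensive .get() calls are used in case a key is absent):

python
result = training_client.forward_backward(
    datums,
    loss_fn="cross_entropy",
    wait=True,
)

# Walk the nested result dictionary to get the average loss
metrics = result.get("result", {}).get("metrics", {})
loss = metrics.get("loss", 0.0)
print(f"loss={loss:.4f}")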

optim_step()

Updates model parameters using accumulated gradients.

python
training_client.optim_step(
    types.AdamParams(learning_rate=1e-4),
    wait=True,
)

Parameters:

  • adam_params (AdamParams): Adam optimizer parameters
  • wait (bool): Whether to wait for completion (default: False)
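
Because forward_backward() accumulates gradients and optim_step() applies them, several forward/backward calls can feed a single parameter update. This is a sketch of that pattern, assuming gradients keep accumulating until optim_step() is called; the chunk size is illustrative:

python
adam_params = types.AdamParams(learning_rate=1e-4)

# Accumulate gradients over smaller chunks of the dataset
chunk_size = 8
for start in range(0, len(datums), chunk_size):
    training_client.forward_backward(
        datums[start:start + chunk_size],
        loss_fn="cross_entropy",
        wait=True,
    )

# Apply one optimizer step for all accumulated gradients
training_client.optim_step(adam_params, wait=True)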

AdamParams

Configuration for the Adam optimizer:

python
adam_params = types.AdamParams(
    learning_rate=1e-4,  # Learning rate
    beta1=0.9,           # First moment decay
    beta2=0.999,         # Second moment decay
    eps=1e-8,            # Epsilon for numerical stability
)

Complete Training Loop

python
from weaver import ServiceClient, types
import torch

# Setup
service_client = ServiceClient()
training_client = service_client.create_model(
    base_model="Qwen/Qwen3-8B",
    lora_config={"rank": 32}
)
tokenizer = training_client.get_tokenizer()

# Prepare data
examples = [
    {"input": "hello world", "output": "ello-hay orld-way"},
    {"input": "banana split", "output": "anana-bay plit-say"},
]

def process_example(example):
    prompt = f"English: {example['input']}\nPig Latin:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)

    tokens = prompt_tokens + completion_tokens
    weights = [0.0] * len(prompt_tokens) + [1.0] * len(completion_tokens)

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens[:-1]),
        loss_fn_inputs={
            "target_tokens": torch.tensor(tokens[1:], dtype=torch.int64),
            "weights": torch.tensor(weights[1:], dtype=torch.float32),
        },
    )

datums = [process_example(ex) for ex in examples]

# Training loop
adam_params = types.AdamParams(learning_rate=1e-4)
for step in range(100):
    # Forward and backward
    result = training_client.forward_backward(
        datums,
        "cross_entropy",
        wait=True,
    )

    # Optimizer step
    training_client.optim_step(adam_params, wait=True)

    # Log progress
    if step % 10 == 0:
        metrics = result.get("result", {}).get("metrics", {})
        loss = metrics.get("loss", 0.0)
        print(f"Step {step}: loss={loss:.4f}")

Sampling

SamplingClient

After training, create a sampling client to generate text:

python
sampling_client = training_client.save_weights_and_get_sampling_client(
    name="my-model"
)

Or create from saved weights:

python
sampling_client = service_client.create_sampling_client(
    model_path="/path/to/weights",
    base_model="Qwen/Qwen3-8B",
)

sample()

Generate text from a prompt:

python
from weaver import types

# Prepare prompt
prompt_tokens = tokenizer.encode("Hello, ", add_special_tokens=True)
prompt = types.ModelInput.from_ints(prompt_tokens)

# Sampling parameters
sampling_params = types.SamplingParams(
    max_tokens=50,
    temperature=0.8,
    top_p=0.95,
    stop=["\n"],
)

# Sample
result = sampling_client.sample(
    prompt=prompt,
    sampling_params=sampling_params,
    num_samples=1,
)

# Get response
sequence = result["sequences"][0]
response_tokens = sequence["tokens"]
response = tokenizer.decode(response_tokens)
print(f"Response: {response}")

Parameters:

  • prompt (ModelInput): Input prompt
  • sampling_params (SamplingParams): Sampling configuration
  • num_samples (int): Number of samples to generate
  • include_prompt_logprobs (bool): Include logprobs for prompt tokens

Returns:

python
{
    "sequences": [
        {
            "tokens": [101, 102, ...],
            "text": "generated text",
            "logprobs": [-0.5, -0.3, ...],
            "stop_reason": "stop"  # or "length"
        }
    ]
}
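
Building on the structure above, one way to use num_samples is best-of-n selection: request several candidates and keep the one with the highest average logprob. The ranking heuristic here is ordinary Python, not a Weaver API:

python
result = sampling_client.sample(
    prompt=prompt,
    sampling_params=sampling_params,
    num_samples=3,
)

def avg_logprob(sequence):
    # Mean per-token logprob of a generated sequence
    logprobs = sequence["logprobs"]
    return sum(logprobs) / max(len(logprobs), 1)

best = max(result["sequences"], key=avg_logprob)
print(tokenizer.decode(best["tokens"]))
print(f"stop_reason: {best['stop_reason']}")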

SamplingParams

Configuration for text generation:

python
sampling_params = types.SamplingParams(
    max_tokens=100,        # Maximum tokens to generate
    temperature=1.0,       # Sampling temperature (higher = more random)
    top_p=1.0,            # Nucleus sampling parameter
    top_k=0,              # Top-k sampling (0 = disabled)
    stop=["\n", "END"],   # Stop sequences
    repetition_penalty=1.0, # Penalty for repeated tokens
)

Temperature Guide:

  • 0.0: Greedy decoding (deterministic)
  • 0.7: Focused and coherent
  • 1.0: Balanced
  • 1.5+: Creative but potentially chaotic
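
As a rough illustration of the guide above, here are two presets at the extremes; the exact values are examples, not Weaver defaults:

python
# Deterministic decoding: always pick the most likely next token
greedy_params = types.SamplingParams(
    max_tokens=100,
    temperature=0.0,
)

# More exploratory decoding: diverse but potentially less coherent output
creative_params = types.SamplingParams(
    max_tokens=100,
    temperature=1.5,
    top_p=0.95,
)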

compute_logprobs()

Compute log probabilities for a sequence:

python
# Prepare sequence
tokens = tokenizer.encode("Hello, world!", add_special_tokens=True)
model_input = types.ModelInput.from_ints(tokens)

# Compute logprobs
logprobs = sampling_client.compute_logprobs(prompt=model_input)

# logprobs is a list: [None, -0.5, -0.3, ...]
# First element is None (no logprob for first token)
print(f"Token logprobs: {logprobs}")

This is useful for:

  • Evaluating model confidence
  • Computing perplexity (see the sketch below)
  • Ranking candidate responses
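
For instance, perplexity can be computed from the returned list, assuming the format described above where the first element is None:

python
import math

tokens = tokenizer.encode("Hello, world!", add_special_tokens=True)
model_input = types.ModelInput.from_ints(tokens)

logprobs = sampling_client.compute_logprobs(prompt=model_input)
scored = [lp for lp in logprobs if lp is not None]  # drop the leading None

# Perplexity = exp of the negative mean log-probability per token
perplexity = math.exp(-sum(scored) / len(scored))
print(f"Perplexity: {perplexity:.2f}")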

Next Steps

Weaver API Documentation