# Training and Sampling
Core APIs for model training and sampling with Weaver.
## Overview
Weaver's training and sampling workflow consists of the following steps:

1. Create a service client to connect to Weaver
2. Create a training client for your chosen model
3. Prepare training data in the correct format
4. Train the model using forward/backward passes and optimizer steps
5. Sample from the model to generate text
## Creating Clients
### ServiceClient

The `ServiceClient` is your main entry point to Weaver.
```python
from weaver import ServiceClient

# Create with API key from environment
service_client = ServiceClient()

# Or specify API key directly
service_client = ServiceClient(api_key="your-api-key")
```

### TrainingClient
Create a training client for a specific model:
```python
training_client = service_client.create_model(
    base_model="Qwen/Qwen3-8B",
    lora_config={"rank": 32}
)
```

Parameters:

- `base_model` (str): Model identifier (e.g., `"Qwen/Qwen3-8B"`)
- `lora_config` (dict): LoRA configuration with a `rank` parameter
## Getting the Tokenizer
Access the tokenizer through the training client:
```python
tokenizer = training_client.get_tokenizer()

# Use the tokenizer
tokens = tokenizer.encode("Hello, world!", add_special_tokens=True)
text = tokenizer.decode(tokens)
```

## Preparing Training Data
Training data in Weaver is represented as `Datum` objects.

### Datum Structure

A `Datum` contains:

- `model_input`: Input tokens for the model
- `loss_fn_inputs`: Additional inputs for the loss function
```python
from weaver import types
import torch

# Create a datum
datum = types.Datum(
    model_input=types.ModelInput.from_ints(input_tokens),
    loss_fn_inputs={
        "target_tokens": torch.tensor(target_tokens, dtype=torch.int64),
        "weights": torch.tensor(weights, dtype=torch.float32),
    }
)
```

### Processing Examples
Here's a complete example of processing training data:
```python
def process_example(prompt, completion, tokenizer):
    # Encode prompt and completion
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(completion, add_special_tokens=False)

    # Combine tokens
    tokens = prompt_tokens + completion_tokens

    # Create weights (0 for prompt, 1 for completion)
    weights = [0.0] * len(prompt_tokens) + [1.0] * len(completion_tokens)

    # Shift for next-token prediction
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = weights[1:]

    return types.Datum(
        model_input=types.ModelInput.from_ints(input_tokens),
        loss_fn_inputs={
            "target_tokens": torch.tensor(target_tokens, dtype=torch.int64),
            "weights": torch.tensor(weights, dtype=torch.float32),
        },
    )

# Process multiple examples
examples = [
    {"prompt": "Q: What is AI?", "completion": " A: Artificial Intelligence"},
    {"prompt": "Q: What is ML?", "completion": " A: Machine Learning"},
]

datums = [process_example(ex["prompt"], ex["completion"], tokenizer) for ex in examples]
```

### Weight Masking
The `weights` tensor controls which tokens contribute to the loss:

- `0.0`: Token is ignored in loss calculation (e.g., prompt tokens)
- `1.0`: Token contributes fully to loss (e.g., completion tokens)
- Values between 0 and 1: Partial contribution
```python
# Example: Weight completion tokens more heavily
weights = [0.0] * len(prompt_tokens) + [2.0] * len(completion_tokens)
```
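The same masking idea extends to multi-turn data, where typically only the assistant's replies should contribute to the loss. Below is a minimal sketch; the plain-text turn format and the `process_chat` helper are illustrative assumptions, not part of the Weaver API:

```python
def process_chat(turns, tokenizer):
    """Build a Datum where only the assistant's turns receive loss weight."""
    tokens, weights = [], []
    for turn in turns:
        # Hypothetical plain-text turn format; adapt to your own chat template
        text = f"{turn['role']}: {turn['content']}\n"
        turn_tokens = tokenizer.encode(text, add_special_tokens=False)
        turn_weight = 1.0 if turn["role"] == "assistant" else 0.0
        tokens += turn_tokens
        weights += [turn_weight] * len(turn_tokens)

    # Shift for next-token prediction, as in the examples above
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens[:-1]),
        loss_fn_inputs={
            "target_tokens": torch.tensor(tokens[1:], dtype=torch.int64),
            "weights": torch.tensor(weights[1:], dtype=torch.float32),
        },
    )

datum = process_chat(
    [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    tokenizer,
)
```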
## Training

### forward_backward()

Computes a forward and backward pass and accumulates gradients.
```python
result = training_client.forward_backward(
    datums,
    loss_fn="cross_entropy",
    wait=True,
)
```

Parameters:

- `datums` (List[Datum]): List of training examples
- `loss_fn` (str): Loss function to use (`"cross_entropy"` or `"importance_sampling"`)
- `loss_fn_config` (dict, optional): Configuration for the loss function
- `wait` (bool): Whether to wait for completion (default: `False`)
Returns:
A dictionary with training metrics:
```python
{
    "result": {
        "loss_fn_outputs": [...],  # Per-example outputs
        "metrics": {
            "loss": 0.5,  # Average loss
            ...
        }
    }
}
```

### optim_step()
Updates model parameters using accumulated gradients.
```python
training_client.optim_step(
    types.AdamParams(learning_rate=1e-4),
    wait=True,
)
```

Parameters:

- `adam_params` (AdamParams): Adam optimizer parameters
- `wait` (bool): Whether to wait for completion (default: `False`)
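Since `forward_backward()` accumulates gradients and `optim_step()` applies them, one common pattern is to split a large batch into smaller chunks and run several forward/backward passes before a single optimizer update. A sketch of that pattern, assuming gradients accumulate across calls until the next optimizer step (the chunk size is arbitrary):

```python
chunk_size = 8  # arbitrary micro-batch size, for illustration only

# Accumulate gradients over several smaller forward/backward passes
for i in range(0, len(datums), chunk_size):
    training_client.forward_backward(
        datums[i:i + chunk_size],
        loss_fn="cross_entropy",
        wait=True,
    )

# Apply a single optimizer update using the accumulated gradients
training_client.optim_step(types.AdamParams(learning_rate=1e-4), wait=True)
```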
### AdamParams
Configuration for the Adam optimizer:
```python
adam_params = types.AdamParams(
    learning_rate=1e-4,  # Learning rate
    beta1=0.9,           # First moment decay
    beta2=0.999,         # Second moment decay
    eps=1e-8,            # Epsilon for numerical stability
)
```
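Because `optim_step()` receives its `AdamParams` on every call, hyperparameters such as the learning rate can be varied from step to step on the client side. A sketch of a simple linear warmup; the schedule and the `lr_for_step` helper are illustrative assumptions, not part of the Weaver API:

```python
base_lr = 1e-4
warmup_steps = 10  # illustrative values

def lr_for_step(step):
    # Linear warmup to base_lr, then constant (illustrative schedule)
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr

for step in range(100):
    training_client.forward_backward(datums, loss_fn="cross_entropy", wait=True)
    training_client.optim_step(
        types.AdamParams(learning_rate=lr_for_step(step)),
        wait=True,
    )
```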
### Complete Training Loop

```python
from weaver import ServiceClient, types
import torch

# Setup
service_client = ServiceClient()
training_client = service_client.create_model(
    base_model="Qwen/Qwen3-8B",
    lora_config={"rank": 32}
)
tokenizer = training_client.get_tokenizer()

# Prepare data
examples = [
    {"input": "hello world", "output": "ello-hay orld-way"},
    {"input": "banana split", "output": "anana-bay plit-say"},
]

def process_example(example):
    prompt = f"English: {example['input']}\nPig Latin:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    tokens = prompt_tokens + completion_tokens
    weights = [0.0] * len(prompt_tokens) + [1.0] * len(completion_tokens)
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens[:-1]),
        loss_fn_inputs={
            "target_tokens": torch.tensor(tokens[1:], dtype=torch.int64),
            "weights": torch.tensor(weights[1:], dtype=torch.float32),
        },
    )

datums = [process_example(ex) for ex in examples]

# Training loop
adam_params = types.AdamParams(learning_rate=1e-4)

for step in range(100):
    # Forward and backward pass
    result = training_client.forward_backward(
        datums,
        loss_fn="cross_entropy",
        wait=True,
    )

    # Optimizer step
    training_client.optim_step(adam_params, wait=True)

    # Log progress
    if step % 10 == 0:
        metrics = result.get("result", {}).get("metrics", {})
        loss = metrics.get("loss", 0.0)
        print(f"Step {step}: loss={loss:.4f}")
```

## Sampling
### SamplingClient
After training, create a sampling client to generate text:
```python
sampling_client = training_client.save_weights_and_get_sampling_client(
    name="my-model"
)
```

Or create from saved weights:
```python
sampling_client = service_client.create_sampling_client(
    model_path="/path/to/weights",
    base_model="Qwen/Qwen3-8B",
)
```

### sample()
Generate text from a prompt:
```python
from weaver import types

# Prepare prompt
prompt_tokens = tokenizer.encode("Hello, ", add_special_tokens=True)
prompt = types.ModelInput.from_ints(prompt_tokens)

# Sampling parameters
sampling_params = types.SamplingParams(
    max_tokens=50,
    temperature=0.8,
    top_p=0.95,
    stop=["\n"],
)

# Sample
result = sampling_client.sample(
    prompt=prompt,
    sampling_params=sampling_params,
    num_samples=1,
)

# Get response
sequence = result["sequences"][0]
response_tokens = sequence["tokens"]
response = tokenizer.decode(response_tokens)
print(f"Response: {response}")
```

Parameters:

- `prompt` (ModelInput): Input prompt
- `sampling_params` (SamplingParams): Sampling configuration
- `num_samples` (int): Number of samples to generate
- `include_prompt_logprobs` (bool): Include logprobs for prompt tokens
Returns:
```python
{
    "sequences": [
        {
            "tokens": [101, 102, ...],
            "text": "generated text",
            "logprobs": [-0.5, -0.3, ...],
            "stop_reason": "stop"  # or "length"
        }
    ]
}
```
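When `num_samples > 1`, each candidate appears as a separate entry in `sequences`. As a sketch of how the return structure can be used, you could generate several candidates and keep the one with the highest total logprob:

```python
result = sampling_client.sample(
    prompt=prompt,
    sampling_params=sampling_params,
    num_samples=4,
)

# Rank candidates by the sum of their token logprobs and keep the best one
best = max(result["sequences"], key=lambda seq: sum(seq["logprobs"]))
print(tokenizer.decode(best["tokens"]))
```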
### SamplingParams

Configuration for text generation:
```python
sampling_params = types.SamplingParams(
    max_tokens=100,          # Maximum tokens to generate
    temperature=1.0,         # Sampling temperature (higher = more random)
    top_p=1.0,               # Nucleus sampling parameter
    top_k=0,                 # Top-k sampling (0 = disabled)
    stop=["\n", "END"],      # Stop sequences
    repetition_penalty=1.0,  # Penalty for repeated tokens
)
```

Temperature Guide:

- `0.0`: Greedy decoding (deterministic)
- `0.7`: Focused and coherent
- `1.0`: Balanced
- `1.5+`: Creative but potentially chaotic
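As a quick illustration of the guide above, here is a sketch that decodes the same prompt greedily and then at a higher temperature (the specific values are examples only; `prompt`, `tokenizer`, and `sampling_client` are as defined in the `sample()` section):

```python
for temperature in (0.0, 1.2):
    params = types.SamplingParams(max_tokens=50, temperature=temperature)
    result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
    text = tokenizer.decode(result["sequences"][0]["tokens"])
    print(f"temperature={temperature}: {text}")
```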
### compute_logprobs()
Compute log probabilities for a sequence:
```python
# Prepare sequence
tokens = tokenizer.encode("Hello, world!", add_special_tokens=True)
model_input = types.ModelInput.from_ints(tokens)

# Compute logprobs
logprobs = sampling_client.compute_logprobs(prompt=model_input)

# logprobs is a list: [None, -0.5, -0.3, ...]
# First element is None (no logprob for first token)
print(f"Token logprobs: {logprobs}")
```

This is useful for:
- Evaluating model confidence
- Computing perplexity
- Ranking candidate responses
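For instance, a sketch that converts the returned logprobs into a perplexity score for a piece of text (skipping the leading `None` entry):

```python
import math

tokens = tokenizer.encode("Hello, world!", add_special_tokens=True)
logprobs = sampling_client.compute_logprobs(
    prompt=types.ModelInput.from_ints(tokens)
)

# Skip the leading None (the first token has no logprob), then average
scored = [lp for lp in logprobs if lp is not None]
perplexity = math.exp(-sum(scored) / len(scored))
print(f"Perplexity: {perplexity:.2f}")
```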
## Next Steps
- Explore Loss Functions - Available loss functions
- Learn about Saving and Loading - Model persistence
- Check Model Lineup - Supported models