This guide will help you get started with pico-analyze, a framework for analyzing language models during their training process.
Overview
Introduction
pico-analyze is a lightweight, flexible toolkit for studying how language models evolve over the course of training. Designed as a companion to pico-train, it reads the structured checkpoints your training jobs produce—covering model weights, gradients, and optional activations—and provides a rich set of metrics and utilities to track changes at every step of your model's learning process.
Recent advances in mechanistic interpretability have highlighted the value of examining model internals beyond the final checkpoint. Traditional approaches often focus solely on end-of-training analysis, missing the rich developmental trajectory models follow. pico-analyze addresses this by integrating neatly with the training loop, logging key tensors and metrics in situ. This approach, which we call developmental interpretability, makes it easy to see exactly where models begin to represent certain linguistic features, how representations stabilize (or oscillate) over training, and which parts of the network converge faster than others.
pico-analyze allows you to:
- Track how different components of your model change throughout training
- Apply various metrics to measure properties like sparsity, effective rank, and similarity between checkpoints
- Compare model states across training steps
- Visualize learning dynamics through beautiful logs and graphs
- Analyze both simple components (individual layers) and compound components (specialized circuits)
Abstractions
At its core, pico-analyze follows a simple abstraction: it applies metrics to components. Metrics provide quantitative insights into various aspects of model behavior, while components define the specific model elements being analyzed. This design allows for flexible and fine-grained analysis of training dynamics.
Metrics
Out of the box, pico-analyze supports a range of built-in metrics, including:
- Sparsity Measures: Gini coefficient and Hoyer sparsity gauge how concentrated the values of a tensor are around zero.
- Rank-Based Metrics: Proportional Effective Rank (PER) captures a matrix's "effective dimensionality," while Condition Number evaluates its numerical stability.
- Representation Similarity: Centered Kernel Alignment (CKA) and Projection Weighted Canonical Correlation Analysis (PWCCA) compare activation patterns across layers or checkpoints, revealing how internal representations evolve.
- Norms: Frobenius, Nuclear, and Infinity norms measure the scale of a tensor, spotlighting issues such as vanishing or exploding parameters.
Components
Metrics can be computed on different types of components, which fall into two categories:
- Simple components: Individual weight matrices, gradients, or activations from a single layer.
- Compound components: Higher-level structures that combine multiple model elements. One example is the OV circuit, which tracks how information flows in transformer models by combining the value projection and output projection matrices in self-attention layers.
This two-step abstraction is designed for extensibility—new metrics and component types can be easily defined, allowing researchers to tailor analyses to specific hypotheses about language model learning.
To streamline experimentation, pico-analyze employs a YAML-based configuration system, enabling users to specify which layers, metrics, and training steps to run analyses over. Results can be stored locally or uploaded to Weights & Biases (wandb) for visualization, allowing for efficient tracking of learning dynamics across multiple runs.
Components System
Components are functional objects generated from a given checkpoint state and a component configuration. They are at the heart of developmental interpretability, as they specify which parts of the model to analyze.
When defining components in your configuration, you need to specify three key aspects:
- Layer indices: Which specific layers of the model to analyze (e.g., layers: [0, 1, 2, 3])
- Layer suffixes: What specific parts within those layers to target (e.g., layer_suffixes: "attention.q_proj")
- Data type: Whether you're analyzing weights, activations, or gradients (e.g., data_type: "weights")
The data type is particularly important as it determines what aspect of the component you're analyzing:
- weights: Analyzing the learned parameters themselves
- activations: Analyzing the outputs when running a forward pass (requires input data)
- gradients: Analyzing how parameters are updated during training
Simple Components
Simple components target individual model elements from a single layer, such as:
- Weight matrices from any layer (attention projections, MLP layers, etc.)
- Gradients flowing through specific parts of the network
- Activations produced during forward passes with training or validation data
These components can be extracted directly from the stored checkpoint data without much additional computation. The implementation logic for simple components can be found in src/components/simple.py.
For example, the following configuration specifies analyzing the weight matrix of the second SwiGLU projection in all layers:
components:
  - component_name: simple
    data_type: "weights"
    layer_suffixes: "swiglu.w_2"
    layers: [0, 1, 2, 3, 4, 5]
Compound Components
Compound components combine multiple simple components to represent higher-level structures in the model. These often correspond to functional circuits that implement specific capabilities. The distinction between simple and compound components is implemented in src/components/base.py.
Examples of compound components that might be interesting to study include:
- Attention Circuits: Combinations of Key and Value or Output and Value projections that operate in unison.
- Induction Heads: Specialized attention circuits that detect and complete patterns.
- Residual Stream: Information and memory flow through a model.
Compound components are computed through matrix multiplications or other operations on different activations, weights, or gradients to produce the desired analytical unit.
Out of the box, pico-analyze implements an OV Circuit Component in src/components/ov_circuit.py; the following section provides more details on this circuit.
OV Circuit
A key compound component in pico-analyze is the Output-Value (OV) Circuit. In the Transformer Attention Module, the OV circuit determines how much attending to a token affects the logits of a model.
The idea of an OV-Circuit stems from the observation that in an attention module, the value and output projections always operate jointly and write into the 'residual stream'. Thus, it makes sense to treat the Output and Value matrices as a single 'OV-Circuit' matrix.
The implementation of OV Circuits can be found in src/components/ov_circuit.py. Because the OV circuit operates 'per head' of the attention module, pico-analyze computes and returns the OV activations, weights, and gradients for each head separately and then concatenates them into one aggregated tensor.
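To make the per-head computation concrete, here is a minimal sketch of how per-head OV matrices could be formed from the value and output projection weights. It assumes standard PyTorch nn.Linear weight layouts of shape (out_features, in_features) and hypothetical names; it is illustrative, not the actual pico-analyze implementation:

import torch

def per_head_ov_weights(w_v: torch.Tensor, w_o: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Sketch: form W_OV^h = W_O^h @ W_V^h for each head and stack the results.

    Assumes w_v and w_o both have shape (d_model, d_model), with heads laid
    out as contiguous blocks of size d_head = d_model // n_heads along w_v's
    output dimension and w_o's input dimension.
    """
    d_model = w_v.shape[1]
    d_head = d_model // n_heads
    per_head = []
    for h in range(n_heads):
        w_v_h = w_v[h * d_head : (h + 1) * d_head, :]  # (d_head, d_model)
        w_o_h = w_o[:, h * d_head : (h + 1) * d_head]  # (d_model, d_head)
        per_head.append(w_o_h @ w_v_h)                 # (d_model, d_model)
    return torch.stack(per_head)                       # (n_heads, d_model, d_model)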
Here's an example configuration for analyzing the OV circuit:
components:
  - component_name: ov_circuit
    data_type: "weights" # or "activations" or "gradients"
    layer_suffixes:
      output_layer: "attention.o_proj"
      value_layer: "attention.v_proj"
    layers: [0, 1, 2, 3]
Metrics
Metrics are quantitative measures that help you understand what's happening inside your model during training. pico-analyze provides a diverse set of metrics, each designed to capture different aspects of model learning dynamics. These metrics serve as diagnostic tools to identify patterns, track development, and gain insight into the internal representations formed throughout the training process.
The framework includes four primary categories of metrics:
- Rank Metrics: Measure the effective dimensionality of model representations
- Representation Similarity Metrics: Capture how activations relate across different steps or layers
- Norm Metrics: Quantify the magnitudes of weights, activations, or gradients
- Sparsity Metrics: Assess how distributed or concentrated values are within a tensor
Metric-Component Compatibility: It's important to note that not all metrics can be applied to all component data types. For instance, the CKA metric only works with the activations data type. The pico-analyze code will raise an error if you attempt to apply a metric to an incompatible component type.
Rank Metrics
Rank metrics help understand the dimensionality and stability of model components. pico-analyze implements two primary rank-based metrics:
Proportional Effective Rank (PER)
PER lets you compare the effective rank of layers of different sizes on a consistent scale, indicating how effectively the model's parameters are used. It is computed from the entropy of a matrix's normalized singular values and can be applied to any layer in your model.
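For intuition, the following sketch computes an entropy-based effective rank from the normalized singular values and divides by the number of singular values to make layers of different sizes comparable. This is an illustrative reading of the description above; the exact normalization in pico-analyze may differ:

import torch

def proportional_effective_rank(matrix: torch.Tensor) -> float:
    # Singular values describe how variance is spread across dimensions
    s = torch.linalg.svdvals(matrix)
    p = s / s.sum()  # normalize to a probability distribution
    entropy = -(p * torch.log(p + 1e-12)).sum()
    effective_rank = torch.exp(entropy)  # Roy & Vetterli effective rank
    return (effective_rank / s.numel()).item()  # proportional: in (0, 1]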
Example configuration:
# PER metric
- metric_name: per
  data_split: "train"
  components:
    - component_name: simple
      data_type: "weights" # or "gradients" or "activations"
      layer_suffixes: "swiglu.w_2"
      layers: [0, 1, 2, 3, 4, 5]
Condition Number
The Condition Number metric indicates the sensitivity of a model component to small input changes. It computes the ratio of largest to smallest singular values.
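A minimal sketch of that computation, assuming the standard 2-norm condition number (PyTorch's torch.linalg.cond computes the same quantity by default):

import torch

def condition_number(matrix: torch.Tensor) -> float:
    # Ratio of the largest to the smallest singular value; large values
    # indicate a nearly rank-deficient, numerically unstable matrix
    s = torch.linalg.svdvals(matrix)
    return (s.max() / s.min()).item()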
Example configuration:
# Condition number metric
- metric_name: condition_number
  data_split: "val"
  components:
    - component_name: simple
      data_type: "weights" # or "activations"
      layer_suffixes: "swiglu.w_2"
      layers: [0, 1, 2, 3, 4, 5]
Representation Similarity
A key question for developmental interpretability is: How similar are activations between two sets of checkpoints? pico-analyze implements two powerful metrics for measuring representation similarity:
Centered Kernel Alignment (CKA)
CKA can reliably identify correspondences between representations in networks trained from different initializations. In pico-analyze, we specifically compute the CKA between a model's OV circuit activations and MLP activations.
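For reference, here is a minimal sketch of linear CKA between two activation matrices of shape (n_samples, n_features); pico-analyze's implementation may differ in kernel choice and batching:

import torch

def linear_cka(x: torch.Tensor, y: torch.Tensor) -> float:
    # Center each feature before comparing representations
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    hsic = (x.T @ y).norm() ** 2   # ||X^T Y||_F^2
    norm_x = (x.T @ x).norm()      # ||X^T X||_F
    norm_y = (y.T @ y).norm()      # ||Y^T Y||_F
    return (hsic / (norm_x * norm_y)).item()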
Example configuration:
# CKA metric
- metric_name: cka
  target_checkpoint: 100
  data_split: "val"
  components:
    - component_name: ov_circuit
      data_type: "activations"
      layer_suffixes:
        output_layer: "attention.o_proj"
        value_layer: "attention.v_proj"
      layers: [0, 1, 2, 3, 4, 5]
Projection Weighted Canonical Correlation Analysis (PWCCA)
PWCCA is another metric that measures activation similarity across training and uses projections to emphasize important components.
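A rough sketch of the idea, following Morcos et al. (2018): compute canonical correlations between two activation matrices, then weight each correlation by how much of the first representation its canonical direction accounts for. This is an illustrative simplification, not the pico-analyze implementation:

import torch

def pwcca(x: torch.Tensor, y: torch.Tensor) -> float:
    # x, y: activation matrices of shape (n_samples, n_features)
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    qx, _ = torch.linalg.qr(x)  # orthonormal bases for the two subspaces
    qy, _ = torch.linalg.qr(y)
    u, rho, _ = torch.linalg.svd(qx.T @ qy, full_matrices=False)
    canonical = qx @ u                           # canonical variates of x
    alpha = (canonical.T @ x).abs().sum(dim=1)   # projection weights onto x
    alpha = alpha / alpha.sum()
    return (alpha * rho).sum().item()            # weighted mean correlation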
Example configuration:
# PWCCA metric
- metric_name: pwcca
  target_checkpoint: 100
  data_split: "val"
  components:
    - component_name: ov_circuit
      data_type: "activations"
      layer_suffixes:
        output_layer: "attention.o_proj"
        value_layer: "attention.v_proj"
      layers: [0, 1, 2, 3, 4, 5]
Both of these metrics only work with data_type: "activations".
Norm Metrics
Norm metrics measure the overall magnitude of model parameters, activations, or gradients. pico-analyze includes three norm metrics:
- Nuclear Norm: Sum of singular values, indicating overall matrix magnitude
- Frobenius Norm: Square root of the sum of squared elements, a standard matrix size measure
- Infinity Norm: Maximum absolute value in the matrix, showing the largest individual element
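The sketch below computes all three quantities with PyTorch, following the definitions in the list above (in particular, the element-wise maximum for the infinity norm rather than the max-row-sum operator norm):

import torch

w = torch.randn(512, 512)  # hypothetical weight matrix

nuclear = torch.linalg.matrix_norm(w, ord="nuc")    # sum of singular values
frobenius = torch.linalg.matrix_norm(w, ord="fro")  # sqrt of sum of squared elements
infinity = w.abs().max()                            # largest absolute element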
Example configuration:
# Norm metric
- metric_name: norm
  data_split: "val"
  norm_type: "nuclear" # or "frobenius" or "inf"
  components:
    - component_name: ov_circuit
      data_type: "weights" # or "activations" or "gradients"
      layer_suffixes:
        output_layer: "attention.o_proj"
        value_layer: "attention.v_proj"
      layers: [0, 1, 2, 3, 4, 5]
Sparsity Metrics
Sparsity metrics measure how concentrated or distributed values are within model components. pico-analyze implements two sparsity metrics:
Gini Coefficient
The Gini Coefficient is a measure of inequality in a distribution, with values ranging from 0 (perfect equality) to 1 (perfect inequality). In the context of neural networks, it approximates the sparsity of a matrix by measuring how unequally its weight values are distributed.
Example configuration:
# Gini metric
- metric_name: gini
  data_split: "val"
  components:
    - component_name: simple
      data_type: "weights" # or "activations" or "gradients"
      layer_suffixes: "swiglu.w_2"
      layers: [0, 1, 2, 3, 4, 5]
The implementation in pico-analyze uses a memory-efficient algorithm that avoids creating the full pairwise difference matrix.
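One standard way to achieve this is the sorted-values formulation, which runs in O(n log n) instead of materializing an O(n²) difference matrix. The sketch below illustrates the idea (taking absolute values first, since weights can be negative); it is not the literal pico-analyze source:

import torch

def gini_coefficient(tensor: torch.Tensor) -> float:
    # Sort magnitudes ascending; the Gini coefficient then reduces to a
    # weighted sum over the sorted values
    x, _ = tensor.flatten().abs().sort()
    n = x.numel()
    idx = torch.arange(1, n + 1, dtype=x.dtype, device=x.device)
    return (((2 * idx - n - 1) * x).sum() / (n * x.sum() + 1e-12)).item()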
Hoyer Sparsity Metric
The Hoyer Sparsity Metric computes the ratio of the L2 norm to the L1 norm of a tensor. It ranges from 1 (for a fully sparse vector with only one non-zero element) down to 1/√n (for a dense vector with all equal elements).
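Given that definition, the computation is essentially a one-liner; the sketch below is illustrative:

import torch

def hoyer_sparsity(tensor: torch.Tensor) -> float:
    # L2/L1 ratio: 1 for a single non-zero element,
    # 1/sqrt(n) for a vector with all equal elements
    x = tensor.flatten()
    return (torch.linalg.vector_norm(x, 2) / (torch.linalg.vector_norm(x, 1) + 1e-12)).item()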
Example configuration:
# Hoyer metric
- metric_name: hoyer
  data_split: "train"
  components:
    - component_name: simple
      data_type: "weights" # or "activations" or "gradients"
      layer_suffixes: "swiglu.w_2"
      layers: [0, 1, 2, 3, 4, 5]
Setup
Clone Repository
Getting started with pico-analyze is straightforward. First, clone the repository from GitHub:
git clone https://github.com/pico-lm/pico-analyze.git
cd pico-analyze
This will create a local copy of the pico-analyze codebase on your machine. The repository contains all the necessary code to analyze model checkpoints generated with pico-train or other compatible frameworks.
Environment Variables
pico-analyze integrates with Weights & Biases (wandb) for visualization and experiment tracking, and optionally with Hugging Face for accessing model checkpoints. Setting up the appropriate environment variables will enable seamless logging and comparison of analysis results.
Create a .env file at the root of your pico-analyze directory:
# .env
export WANDB_API_KEY=your_wandb_key
export HF_TOKEN=your_huggingface_token # Optional, needed only for accessing private HF repos
To obtain your wandb access token, go to https://wandb.ai/authorize. For your Hugging Face token, visit https://huggingface.co/docs/hub/en/security-tokens.
Installing Dependencies
pico-analyze uses Poetry for dependency management, which ensures consistent environments across different machines. The simplest way to set up pico-analyze is to run the setup script:
source setup.sh
This script will check your environment, install necessary tools, and set up a Poetry virtual environment with all dependencies.
Getting Started
Configuration
pico-analyze uses YAML configuration files to define what analyses to run. These configuration files specify which model components to analyze, which metrics to apply, and at which training steps to perform the analysis.
The following snippet from our configs/demo.yaml template illustrates the basic configuration structure:
analysis_name: "your-analysis-name"

metrics:
  # CKA metric (Comparative)
  - metric_name: cka
    target_checkpoint: 100
    data_split: "val"
    components:
      - component_name: ov_circuit
        data_type: "activations"
        layer_suffixes:
          output_layer: "attention.o_proj"
          value_layer: "attention.v_proj"
        layers: [0, 11]

steps:
  start: 0
  end: 100
  step: 50

monitoring:
  output_dir: "analysis_results"
  save_to_wandb: true
  wandb:
    entity: "your-wandb-entity"
    project: "your-wandb-project"
The configuration file has several key sections:
- analysis_name: A unique identifier for this analysis run
- metrics: A list of metrics to compute, each with its own configuration
- steps: Which model checkpoints to analyze (by training step)
- monitoring: Settings for wandb integration
Each metric configuration specifies:
- metric_name: Which metric to compute (e.g., "gini", "per", "cka")
- data_split: Whether to use training or validation data (for activation-based metrics)
- components: A list of components to apply the metric to
Save this configuration file as analyze_config.yaml in a location of your choice. You'll reference this file when running analyses.
Running Analyses
Once you have defined a configuration file, you need to specify which model to run the analysis over. You can do so either by specifying a remote Hugging Face repo and branch, or by pointing to a local checkpoint directory:
poetry run analyze --config path/to/analyze_config.yaml --repo_id username/repo-name --branch main
poetry run analyze --config path/to/analyze_config.yaml --checkpoint_dir /path/to/checkpoints
For models trained with pico-train, the analysis tool will automatically detect the training configuration used to train that model.
Viewing Results
Analysis results are stored in two places:
- Locally: Results are saved to ./results/{analysis_name}/ by default, with each metric generating JSON and CSV files with detailed results.
- Weights & Biases: If wandb integration is enabled, results are uploaded and visualized in your wandb project dashboard.
A typical workflow involves:
- Running an analysis with wandb enabled
- Exploring results in the wandb dashboard
- Identifying interesting patterns or anomalies
- Running more targeted analyses based on initial findings
The wandb dashboard provides interactive plots showing how metrics evolve across training steps, making it easy to spot trends and compare different model components.
For quick local exploration, each analysis run generates a comprehensive report in the results directory, including summary statistics and basic visualizations.
Customization
pico-analyze is designed to be extended with custom metrics and components. Here's how you can add your own:
Adding a Custom Metric
Creating a custom metric involves defining a metric configuration class and the metric implementation in their respective files.
First, define the metric configuration class in src/config/metrics.py:
# src/config/metrics.py
from dataclasses import dataclass

from pico_analyze.config.base import BaseMetricConfig, register_metric_config


@dataclass
@register_metric_config("top_k_values")
class TopKValuesConfig(BaseMetricConfig):
    """Configuration for the top-k values metric."""

    # Default parameters
    k: int = 10  # Number of top values to analyze
    absolute: bool = True  # Whether to use absolute values
Then, implement the metric class in src/metrics/custom.py:
# src/metrics/custom.py
import torch

from pico_analyze.metrics.base import BaseMetric, register_metric


@register_metric("top_k_values")
class TopKValuesMetric(BaseMetric):
    """Compute statistics on the top-k largest values in a tensor."""

    def __init__(self, config):
        super().__init__(config)
        # Read parameters from the metric config
        self.k = config.k
        self.absolute = config.absolute

    def compute(self, component_tensor, **kwargs):
        # Flatten the tensor to find top values
        flat_tensor = component_tensor.flatten()

        # Get top-k values (by absolute value if configured)
        if self.absolute:
            values, _ = torch.topk(torch.abs(flat_tensor), min(self.k, flat_tensor.numel()))
        else:
            values, _ = torch.topk(flat_tensor, min(self.k, flat_tensor.numel()))

        # Compute statistics
        return {
            "top_k_mean": torch.mean(values).item(),
            "top_k_max": torch.max(values).item(),
            "top_1_to_k_ratio": (values[0] / (values[-1] + 1e-9)).item(),
        }
Once you've implemented both parts, include your metric in the analysis configuration:
# configs/analysis_config.yaml
analysis_name: "top_values_analysis"
steps: [0, 100, 500, 1000]

metrics:
  - metric_name: top_k_values
    k: 5  # Parameter defined in TopKValuesConfig
    absolute: false  # Parameter defined in TopKValuesConfig
    data_split: "train"
    components:
      - component_name: simple
        data_type: "weights"
        layer_suffixes: "attention.q_proj"
        layers: [0, 1, 2]
Adding a Custom Component
Components extract specific parts of your model for analysis:
# src/components/custom.py
import torch

from pico_analyze.components.base import BaseComponent, register_component


@register_component("key_value_circuit")
class KeyValueCircuitComponent(BaseComponent):
    """Extract key and value matrices and compute their interaction."""

    def __init__(self, config):
        super().__init__(config)
        # Store the required layer suffixes from the component config
        # (assumes the config exposes them as `config.layer_suffixes`)
        self.layer_suffixes = config.layer_suffixes

    def extract(self, checkpoint_state, data_type, layer_idx, **kwargs):
        # Get layer suffixes from config
        key_suffix = self.layer_suffixes["key_layer"]
        value_suffix = self.layer_suffixes["value_layer"]

        # Full layer names including index
        key_layer = f"model.layers.{layer_idx}.{key_suffix}"
        value_layer = f"model.layers.{layer_idx}.{value_suffix}"

        if data_type == "weights":
            # Extract weight matrices
            key_weights = checkpoint_state[key_layer + ".weight"]
            value_weights = checkpoint_state[value_layer + ".weight"]

            # Simple matrix multiplication in this example
            kv_circuit = torch.matmul(key_weights, value_weights.transpose(-1, -2))
            return kv_circuit
        else:
            raise ValueError(f"Unsupported data_type: {data_type}")
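Once registered, the custom component can be referenced from an analysis configuration like any built-in component. The snippet below is a hypothetical example (the key_layer and value_layer suffixes assume pico-train's attention naming and are illustrative):

metrics:
  - metric_name: norm
    norm_type: "frobenius"
    data_split: "train"
    components:
      - component_name: key_value_circuit
        data_type: "weights"
        layer_suffixes:
          key_layer: "attention.k_proj"
          value_layer: "attention.v_proj"
        layers: [0, 1, 2]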