drrik

Drrik: A Framework for Monosemantic Feature Extraction from Language Models

This framework is inspired by Anthropic's Towards Monosemanticity paper, which applies dictionary learning to transformer-based large language models: sparse autoencoders are trained on MLP activations to decompose them into interpretable, monosemantic features.

Key Components:
  • Model loading from HuggingFace Hub (with gated model support)
  • Dataset loading and inference pipeline
  • MLP activation collection using nnsight
  • Sparse Autoencoder training for feature extraction
  • SAE-based activation steering to bias model outputs
  • Visualization of feature-specific activation vectors
  • Optional Weights & Biases integration for experiment tracking
Public API:

  • ActivationExtractor: Loads models and datasets, extracts MLP activations via nnsight.
  • SparseAutoencoder: Overcomplete SAE with L1 regularization and dead neuron resampling.
  • FeatureVisualizer: Generates density histograms, training curves, and feature dashboards.
  • SAESteering: Steers language model generation by adding SAE feature directions to MLP activations during inference.
  • Config: Top-level Pydantic settings model aggregating all sub-configurations.
  • EnvironmentSettings: Loads API keys and environment variables from .env.
  • WandbConfig: Context manager for wandb experiment tracking.
  • get_settings: Returns the global EnvironmentSettings singleton.

Example Usage:
from drrik import ActivationExtractor, SparseAutoencoder, FeatureVisualizer

# Extract activations
extractor = ActivationExtractor(
    model_name="google/gemma-2b",
    dataset_name="wikitext",
    split="train",
    mlp_layers=[0, 1, 2],
    num_samples=1000
)
activations, metadata = extractor.extract()

# Train sparse autoencoder (with optional wandb logging)
sae = SparseAutoencoder(
    activation_dim=activations.shape[-1],
    hidden_dim=activations.shape[-1] * 8,  # 8x expansion
    l1_coefficient=0.01
)
sae.fit(activations, wandb_enabled=True)

# Visualize features (with optional wandb logging)
visualizer = FeatureVisualizer(sae, activations)
visualizer.plot_feature_density()
visualizer.plot_top_activating_examples()
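The SparseAutoencoder in the example above is described as an overcomplete SAE trained with an L1 sparsity penalty. As a rough, framework-independent illustration of that objective (toy NumPy code, not drrik's actual implementation; all weights and dimensions here are made up for demonstration), one forward pass and its loss:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions: 16-d activations, 8x overcomplete dictionary
activation_dim, hidden_dim, l1_coefficient = 16, 16 * 8, 0.01

# Encoder/decoder weights (decoder rows play the role of feature directions)
W_enc = rng.normal(0, 0.1, (activation_dim, hidden_dim))
b_enc = np.zeros(hidden_dim)
W_dec = rng.normal(0, 0.1, (hidden_dim, activation_dim))
b_dec = np.zeros(activation_dim)

x = rng.normal(size=(32, activation_dim))          # batch of MLP activations

features = np.maximum(0.0, x @ W_enc + b_enc)      # ReLU encoder -> sparse codes
x_hat = features @ W_dec + b_dec                   # linear decoder reconstruction

reconstruction_loss = np.mean((x - x_hat) ** 2)    # MSE term
sparsity_loss = l1_coefficient * np.abs(features).sum(axis=-1).mean()  # L1 term
loss = reconstruction_loss + sparsity_loss
```

Training (as `sae.fit` does) would minimize `loss` by gradient descent; the L1 term is what pushes most feature activations to exactly zero.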
"""
Drrik: A Framework for Monosemantic Feature Extraction from Language Models

This framework is inspired by the ``Towards Monosemanticity`` paper from
Anthropic, which applies dictionary learning to transformer-based large
language models: sparse autoencoders are trained on MLP activations to
decompose them into interpretable, monosemantic features.

Key Components:
    - **Model loading** from HuggingFace Hub (with gated model support)
    - **Dataset loading** and inference pipeline
    - **MLP activation collection** using nnsight
    - **Sparse Autoencoder training** for feature extraction
    - **SAE-based activation steering** to bias model outputs
    - **Visualization** of feature-specific activation vectors
    - Optional **Weights & Biases** integration for experiment tracking

Public API:
    ``ActivationExtractor``
        Loads models and datasets, extracts MLP activations via nnsight.
    ``SparseAutoencoder``
        Overcomplete SAE with L1 regularization and dead neuron resampling.
    ``FeatureVisualizer``
        Generates density histograms, training curves, and feature dashboards.
    ``SAESteering``
        Steers language model generation by adding SAE feature directions
        to MLP activations during inference.
    ``Config``
        Top-level Pydantic settings model aggregating all sub-configurations.
    ``EnvironmentSettings``
        Loads API keys and environment variables from ``.env``.
    ``WandbConfig``
        Context manager for wandb experiment tracking.
    ``get_settings``
        Returns the global ``EnvironmentSettings`` singleton.

Example Usage:
    ```python
    from drrik import ActivationExtractor, SparseAutoencoder, FeatureVisualizer

    # Extract activations
    extractor = ActivationExtractor(
        model_name="google/gemma-2b",
        dataset_name="wikitext",
        split="train",
        mlp_layers=[0, 1, 2],
        num_samples=1000
    )
    activations, metadata = extractor.extract()

    # Train sparse autoencoder (with optional wandb logging)
    sae = SparseAutoencoder(
        activation_dim=activations.shape[-1],
        hidden_dim=activations.shape[-1] * 8,  # 8x expansion
        l1_coefficient=0.01
    )
    sae.fit(activations, wandb_enabled=True)

    # Visualize features (with optional wandb logging)
    visualizer = FeatureVisualizer(sae, activations)
    visualizer.plot_feature_density()
    visualizer.plot_top_activating_examples()
    ```
"""

__version__ = "0.1.0"

from drrik.models import ActivationExtractor
from drrik.autoencoder import SparseAutoencoder
from drrik.visualization import FeatureVisualizer
from drrik.steering import SAESteering
from drrik.config import Config
from drrik.settings import EnvironmentSettings, WandbConfig, get_settings

__all__ = [
    "ActivationExtractor",
    "SparseAutoencoder",
    "FeatureVisualizer",
    "SAESteering",
    "Config",
    "EnvironmentSettings",
    "WandbConfig",
    "get_settings",
]
class ActivationExtractor:
class ActivationExtractor:
    """
    Extract MLP activations from language models using nnsight.

    Handles model and dataset loading from HuggingFace Hub, runs
    inference with nnsight's tracing context to capture MLP outputs,
    and provides persistence via ``.npy`` / ``.pkl`` files.

    The HuggingFace auth token is read automatically from the
    ``HUGGINGFACE_HUB_TOKEN`` environment variable (or ``.env`` file)
    via :func:`~drrik.settings.get_settings`.

    Attributes:
        config: The resolved :class:`~drrik.config.ActivationExtractorConfig`.
        model: The loaded nnsight ``LanguageModel`` (``None`` until
            :meth:`load_model` is called).
        tokenizer: The HuggingFace tokenizer (``None`` until
            :meth:`load_model` is called).
        dataset: The loaded HuggingFace ``Dataset`` (``None`` until
            :meth:`load_dataset` is called).

    Example:
        ```python
        extractor = ActivationExtractor(
            model_name="google/gemma-2b",
            dataset_name="wikitext",
            dataset_config="wikitext-2-raw-v1",
            mlp_layers=[0],
            num_samples=1000,
        )
        activations, metadata = extractor.extract()
        extractor.save_activations(activations, metadata, "./output")
        ```
    """

    def __init__(self, config: Optional[ActivationExtractorConfig] = None, **kwargs):
        """
        Initialize the ActivationExtractor.

        Accepts either a pre-built config object or flat keyword
        arguments that are automatically sorted into
        :class:`~drrik.config.ModelConfig`,
        :class:`~drrik.config.DatasetConfig`, and the remaining
        extraction-level settings.

        Args:
            config: A fully constructed
                :class:`~drrik.config.ActivationExtractorConfig`.
                If ``None``, the config is built from ``**kwargs``.
            **kwargs: Flat config overrides. Recognised keys are
                distributed to the appropriate sub-config:

                - *Model*: ``model_name``, ``revision``,
                  ``torch_dtype``, ``device_map``,
                  ``trust_remote_code``
                - *Dataset*: ``dataset_name``, ``dataset_config``,
                  ``split``, ``num_samples``, ``text_column``,
                  ``max_length``
                - *Extraction*: ``mlp_layers``, ``batch_size``,
                  ``output_dir``
        """
        if config is None:
            model_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k
                in (
                    "model_name",
                    "revision",
                    "torch_dtype",
                    "device_map",
                    "trust_remote_code",
                )
            }
            dataset_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k
                in (
                    "dataset_name",
                    "dataset_config",
                    "split",
                    "num_samples",
                    "text_column",
                    "max_length",
                )
            }
            remaining_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k not in model_kwargs and k not in dataset_kwargs
            }

            model = ModelConfig(**model_kwargs) if model_kwargs else None
            dataset = DatasetConfig(**dataset_kwargs) if dataset_kwargs else None

            config = ActivationExtractorConfig(
                model=model,
                dataset=dataset,
                **remaining_kwargs,
            )

        self.config = config
        self.model = None
        self.tokenizer = None
        self.dataset = None
        self._activations = []
        self._metadata = []

    def load_model(self) -> LanguageModel:
        """
        Load the model from HuggingFace Hub using nnsight.

        If the model is already loaded, returns the cached instance.
        The HuggingFace token is read from ``HUGGINGFACE_HUB_TOKEN``
        via :func:`~drrik.settings.get_settings`.

        Returns:
            The loaded nnsight :class:`~nnsight.LanguageModel` wrapper.

        Raises:
            RuntimeError: If model loading fails.
        """
        if self.model is not None:
            return self.model

        try:
            logger.info(f"Loading model: {self.config.model.model_name}")

            # Get HF token from settings
            settings = get_settings()
            hf_token = settings.huggingface_hub_token

            if hf_token:
                logger.info("Using HuggingFace Hub token for authentication")

            # Load tokenizer with token
            tokenizer_kwargs = {
                "revision": self.config.model.revision,
                "trust_remote_code": self.config.model.trust_remote_code,
            }
            if hf_token:
                tokenizer_kwargs["token"] = hf_token

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.model.model_name, **tokenizer_kwargs
            )

            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Determine dtype
            dtype_map = {
                "float16": torch.float16,
                "bfloat16": torch.bfloat16,
                "float32": torch.float32,
            }
            torch_dtype = dtype_map.get(self.config.model.torch_dtype, torch.float16)

            # Load model with nnsight (with token if available)
            nnsight_kwargs = {
                "revision": self.config.model.revision,
                "torch_dtype": torch_dtype,
                "trust_remote_code": self.config.model.trust_remote_code,
                "device_map": self.config.model.device_map,
            }
            if hf_token:
                nnsight_kwargs["token"] = hf_token

            self.model = LanguageModel(self.config.model.model_name, **nnsight_kwargs)

            logger.info(f"Model loaded successfully on {self.model.device}")
            return self.model

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def load_dataset(self) -> Dataset:
        """
        Load the dataset from HuggingFace Hub.

        If the dataset is already loaded, returns the cached instance.
        The HuggingFace token is forwarded for gated datasets when
        available.

        Returns:
            The loaded HuggingFace :class:`~datasets.Dataset`.

        Raises:
            RuntimeError: If dataset loading fails.
        """
        if self.dataset is not None:
            return self.dataset

        try:
            logger.info(
                f"Loading dataset: {self.config.dataset.dataset_name} "
                f"({self.config.dataset.split} split)"
            )

            # Get HF token from settings (for gated datasets)
            settings = get_settings()
            hf_token = settings.huggingface_hub_token

            # Load dataset with token if available
            load_kwargs = {
                "path": self.config.dataset.dataset_name,
                "name": self.config.dataset.dataset_config,
                "split": self.config.dataset.split,
            }
            if hf_token:
                load_kwargs["token"] = hf_token

            self.dataset = load_dataset(**load_kwargs)

            logger.info(f"Dataset loaded with {len(self.dataset)} examples")
            return self.dataset

        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise RuntimeError(f"Dataset loading failed: {e}") from e

    def _get_mlp_layer_name(self, layer_idx: int) -> str:
        """
        Get the nnsight module path to an MLP layer.

        Auto-detects the correct path pattern based on the model name
        in the config.  For unknown architectures a warning is logged
        and the default Gemma/Llama pattern is returned.

        Args:
            layer_idx: Zero-indexed transformer layer number.

        Returns:
            A dotted/bracket module path string, e.g.
            ``"model.layers[3].mlp"``.
        """
        model_name_lower = self.config.model.model_name.lower()

        # Gemma/GPT-style: model.layers.N.mlp
        if any(name in model_name_lower for name in ["gemma", "gpt-2", "pythia"]):
            return f"model.layers[{layer_idx}].mlp"

        # Llama-style: model.layers.N.mlp
        if "llama" in model_name_lower:
            return f"model.layers[{layer_idx}].mlp"

        # Phi-style: model.layers.N.mlp
        if "phi" in model_name_lower:
            return f"model.layers[{layer_idx}].mlp"

        # BERT-style: bert.encoder.layer.N.output
        if "bert" in model_name_lower:
            return f"bert.encoder.layer[{layer_idx}].output"

        # Default: fall back to the most common pattern
        common_patterns = [
            f"model.layers[{layer_idx}].mlp",
            f"model.layers[{layer_idx}].ffn",
            f"transformer.h[{layer_idx}].mlp",
            f"layers[{layer_idx}].mlp",
        ]

        logger.warning(
            f"Unknown model architecture '{self.config.model.model_name}', "
            f"trying common patterns"
        )
        return common_patterns[0]

    def _resolve_layer_path(self, path: str):
        """Resolve a dotted/bracket path to the actual nnsight module.

        Delegates to the shared :func:`~drrik.steering.resolve_module_path`
        utility so that path-resolution logic is maintained in one place.

        Args:
            path: A dot-separated module path with optional bracket
                indexing, e.g., ``model.layers[2].mlp``.

        Returns:
            The resolved nnsight module object corresponding to the
            given path.
        """
        return resolve_module_path(self.model, path)

    def extract(
        self,
        num_samples: Optional[int] = None,
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Extract MLP activations from the model.

        Loads the model and dataset (if not already loaded), tokenizes
        the data, and runs batched forward passes using nnsight's
        tracing context.  For each sample, the **last-token** activation
        from each requested MLP layer is collected and concatenated.

        Args:
            num_samples: Override for the number of samples to process.
                If ``None``, uses ``config.dataset.num_samples``.

        Returns:
            A tuple of:

            - **activations** (:class:`numpy.ndarray`) — shape
              ``(n_samples, activation_dim * n_layers)``.  When a
              single layer is requested the shape simplifies to
              ``(n_samples, activation_dim)``.
            - **metadata** (:class:`dict`) — contains the serialised
              config, sample count, activation dimension, layer paths,
              and per-sample metadata (index, truncated text, input
              ids).

        Raises:
            RuntimeError: If extraction fails at any stage.
        """
        try:
            # Load model and dataset
            self.load_model()
            self.load_dataset()

            n_samples = num_samples or self.config.dataset.num_samples
            logger.info(
                f"Extracting activations from {len(self.config.mlp_layers)} MLP layers "
                f"for {n_samples} samples"
            )

            # Prepare dataset
            dataset = self.dataset.select(range(min(n_samples, len(self.dataset))))

            # Tokenize dataset
            def tokenize_function(examples):
                return self.tokenizer(
                    examples[self.config.dataset.text_column],
                    padding="max_length",
                    truncation=True,
                    max_length=self.config.dataset.max_length,
                    return_tensors="pt",
                )

            tokenized = dataset.map(
                tokenize_function,
                batched=True,
                remove_columns=dataset.column_names,
                desc="Tokenizing",
            )

            # Get MLP layer names
            layer_paths = [
                self._get_mlp_layer_name(layer_idx)
                for layer_idx in self.config.mlp_layers
            ]

            logger.info(f"Extracting from layers: {layer_paths}")

            # Collect activations using nnsight
            self._activations = []
            self._metadata = []

            batch_size = self.config.batch_size
            n_batches = (len(tokenized) + batch_size - 1) // batch_size

            with torch.no_grad():
                for batch_idx in tqdm(range(n_batches), desc="Extracting activations"):
                    start_idx = batch_idx * batch_size
                    end_idx = min(start_idx + batch_size, len(tokenized))

                    batch = tokenized[start_idx:end_idx]
                    input_ids = torch.tensor(batch["input_ids"])
                    attention_mask = torch.tensor(batch["attention_mask"])

                    layer_outputs = []

                    # Use nnsight to extract activations
                    with self.model.trace(input_ids, attention_mask=attention_mask):
                        for layer_path in layer_paths:
                            module = self._resolve_layer_path(layer_path)
                            output = module.output.save()
                            layer_outputs.append(output)

                    # Process outputs
                    batch_input_ids = input_ids.cpu().numpy()
                    for sample_idx in range(len(input_ids)):
                        sample_activations = []
                        for layer_output in layer_outputs:
                            # Shape: (seq_len, hidden_dim) or (batch, seq_len, hidden_dim)
                            act = layer_output

                            if act.dim() == 3:
                                act = act[sample_idx]  # (seq_len, hidden_dim)
                            # dim() == 2 is already (seq_len, hidden_dim)

                            # Use the last token's activation (common practice)
                            act = act[-1]  # (hidden_dim,)

                            sample_activations.append(act.cpu().numpy())

                        # Concatenate activations from all layers
                        self._activations.append(np.concatenate(sample_activations))

                        # Store metadata
                        self._metadata.append(
                            {
                                "sample_idx": start_idx + sample_idx,
                                "text": dataset[start_idx + sample_idx][
                                    self.config.dataset.text_column
                                ][:200],
                                "input_ids": batch_input_ids[sample_idx],
                            }
                        )

            activations = np.array(self._activations)
            logger.info(f"Extracted activations shape: {activations.shape}")

            metadata = {
                "config": self.config.model_dump(),
                "n_samples": len(self._activations),
                "activation_dim": activations.shape[-1],
                "layer_paths": layer_paths,
                "samples_metadata": self._metadata,
            }

            return activations, metadata

        except Exception as e:
            logger.error(f"Failed to extract activations: {e}")
            raise RuntimeError(f"Activation extraction failed: {e}") from e

    def save_activations(
        self,
        activations: np.ndarray,
        metadata: Dict[str, Any],
        output_dir: Optional[Union[str, Path]] = None,
    ) -> Path:
        """
        Save extracted activations to disk.

        Writes two files to *output_dir*:

        - ``activations.npy`` — the activation array.
        - ``metadata.pkl`` — the metadata dictionary.

        Args:
            activations: Activation array of shape
                ``(n_samples, dim)``.
            metadata: Metadata dictionary (as returned by
                :meth:`extract`).
            output_dir: Target directory.  If ``None``, uses
                ``config.output_dir``.

        Returns:
            Path to the saved ``activations.npy`` file.

        Raises:
            ValueError: If *output_dir* is ``None`` and no output
                directory is configured.
        """
        if output_dir is None:
            output_dir = self.config.output_dir

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        activations_path = output_dir / "activations.npy"
        np.save(str(activations_path), activations)

        metadata_path = output_dir / "metadata.pkl"
        with open(metadata_path, "wb") as f:
            pickle.dump(metadata, f)

        logger.info(f"Saved activations to {activations_path}")
        return activations_path

    def load_activations(
        self,
        filepath: Union[str, Path],
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Load saved activations from disk.

        Supports two formats:

        - ``.npy`` — loads the array directly and looks for a
          companion ``metadata.pkl`` in the same directory.
        - ``.pkl`` (legacy) — loads a dict with ``"activations"``
          and ``"metadata"`` keys.

        Args:
            filepath: Path to an ``.npy`` or ``.pkl`` file.

        Returns:
            A tuple of (activations array, metadata dict).

        Raises:
            FileNotFoundError: If the file (or companion metadata) does
                not exist.
        """
        filepath = Path(filepath)

        if filepath.suffix == ".npy":
            activations = np.load(str(filepath))
            metadata_path = filepath.parent / "metadata.pkl"
            with open(metadata_path, "rb") as f:
                metadata = pickle.load(f)
        else:
            with open(filepath, "rb") as f:
                data = pickle.load(f)
            activations = data["activations"]
            metadata = data["metadata"]

        logger.info(f"Loaded activations from {filepath}")
        return activations, metadata
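The listing above delegates path resolution to a shared resolve_module_path utility from drrik.steering, whose source is not shown here. Its documented behavior (dot-separated segments with optional bracket indices, e.g. model.layers[2].mlp) can be sketched roughly as follows; this is an illustrative stand-in, not drrik's actual implementation:

```python
import re
from types import SimpleNamespace

def resolve_module_path(root, path: str):
    """Walk a dotted/bracket path like 'model.layers[2].mlp' from `root`."""
    obj = root
    for part in path.split("."):
        match = re.fullmatch(r"(\w+)(?:\[(\d+)\])?", part)
        if match is None:
            raise ValueError(f"Malformed path segment: {part!r}")
        name, index = match.groups()
        obj = getattr(obj, name)          # attribute access for the name
        if index is not None:
            obj = obj[int(index)]         # then optional integer indexing
    return obj

# Tiny stand-in module tree mirroring the 'model.layers[N].mlp' layout
mlp0, mlp1 = SimpleNamespace(tag="mlp0"), SimpleNamespace(tag="mlp1")
model = SimpleNamespace(
    model=SimpleNamespace(layers=[SimpleNamespace(mlp=mlp0), SimpleNamespace(mlp=mlp1)])
)

resolved = resolve_module_path(model, "model.layers[1].mlp")
```

Keeping this logic in one shared helper is what lets both ActivationExtractor and SAESteering address modules with the same path strings.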

Extract MLP activations from language models using nnsight.

Handles model and dataset loading from HuggingFace Hub, runs inference with nnsight's tracing context to capture MLP outputs, and provides persistence via .npy / .pkl files.

The HuggingFace auth token is read automatically from the HUGGINGFACE_HUB_TOKEN environment variable (or .env file) via drrik.settings.get_settings().

Attributes:
  • config: The resolved drrik.config.ActivationExtractorConfig.
  • model: The loaded nnsight LanguageModel (None until load_model() is called).
  • tokenizer: The HuggingFace tokenizer (None until load_model() is called).
  • dataset: The loaded HuggingFace Dataset (None until load_dataset() is called).
Example:
extractor = ActivationExtractor(
    model_name="google/gemma-2b",
    dataset_name="wikitext",
    dataset_config="wikitext-2-raw-v1",
    mlp_layers=[0],
    num_samples=1000,
)
activations, metadata = extractor.extract()
extractor.save_activations(activations, metadata, "./output")
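The save_activations call at the end of the example writes the documented pair of files: activations.npy plus a companion metadata.pkl. The on-disk round trip can be reproduced with NumPy and pickle alone, assuming only that file layout (no drrik imports; array sizes here are arbitrary):

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

activations = np.random.rand(100, 2048).astype(np.float32)
metadata = {"n_samples": 100, "activation_dim": 2048}

with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)

    # Write the documented pair: activations.npy + metadata.pkl
    np.save(out / "activations.npy", activations)
    with open(out / "metadata.pkl", "wb") as f:
        pickle.dump(metadata, f)

    # Load them back, mirroring load_activations' .npy branch
    loaded = np.load(out / "activations.npy")
    with open(out / "metadata.pkl", "rb") as f:
        loaded_meta = pickle.load(f)
```

Splitting the array (.npy) from the Python-object metadata (.pkl) keeps the large activation matrix memory-mappable and language-agnostic, while the pickle carries the config and per-sample records.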
ActivationExtractor(config: Optional[drrik.config.ActivationExtractorConfig] = None, **kwargs)

Initialize the ActivationExtractor.

Accepts either a pre-built config object or flat keyword arguments that are automatically sorted into drrik.config.ModelConfig, drrik.config.DatasetConfig, and the remaining extraction-level settings.

Arguments:
  • config: A fully constructed drrik.config.ActivationExtractorConfig. If None, the config is built from **kwargs.
  • **kwargs: Flat config overrides. Recognised keys are distributed to the appropriate sub-config:

    • Model: model_name, revision, torch_dtype, device_map, trust_remote_code
    • Dataset: dataset_name, dataset_config, split, num_samples, text_column, max_length
    • Extraction: mlp_layers, batch_size, output_dir
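The routing of flat keyword arguments into sub-configs can be illustrated on its own. A standalone sketch of the same key-partitioning logic, using plain dicts in place of the Pydantic models (the helper name split_kwargs is hypothetical, not part of drrik's API):

```python
MODEL_KEYS = ("model_name", "revision", "torch_dtype", "device_map", "trust_remote_code")
DATASET_KEYS = ("dataset_name", "dataset_config", "split", "num_samples",
                "text_column", "max_length")

def split_kwargs(kwargs: dict) -> tuple[dict, dict, dict]:
    """Partition flat kwargs into (model, dataset, extraction) groups."""
    model = {k: v for k, v in kwargs.items() if k in MODEL_KEYS}
    dataset = {k: v for k, v in kwargs.items() if k in DATASET_KEYS}
    remaining = {k: v for k, v in kwargs.items() if k not in model and k not in dataset}
    return model, dataset, remaining

model_kw, dataset_kw, extraction_kw = split_kwargs({
    "model_name": "google/gemma-2b",
    "dataset_name": "wikitext",
    "split": "train",
    "num_samples": 1000,
    "mlp_layers": [0, 1, 2],
})
# model_kw      -> {"model_name": "google/gemma-2b"}
# dataset_kw    -> {"dataset_name": "wikitext", "split": "train", "num_samples": 1000}
# extraction_kw -> {"mlp_layers": [0, 1, 2]}
```

Unrecognized keys fall through to the extraction-level config, so a typo such as dataset_split would reach ActivationExtractorConfig rather than DatasetConfig.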
config
model
tokenizer
dataset
def load_model(self) -> nnsight.modeling.language.LanguageModel:

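The kwargs-building pattern in `load_model` is easy to reuse elsewhere: assemble the HuggingFace Hub arguments first and attach `token` only when a credential is actually configured, so public models still load without authentication. A minimal sketch (the helper name `build_hub_kwargs` is hypothetical, not part of drrik's API):

```python
from typing import Optional


def build_hub_kwargs(revision: Optional[str],
                     trust_remote_code: bool,
                     hf_token: Optional[str]) -> dict:
    # Base kwargs shared by tokenizer and model loading
    kwargs = {"revision": revision, "trust_remote_code": trust_remote_code}
    # Attach the token only if one is configured (gated models)
    if hf_token:
        kwargs["token"] = hf_token
    return kwargs
```

With no token the dict stays credential-free, so `from_pretrained` behaves exactly as it would for an anonymous download.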
def load_dataset(self) -> datasets.arrow_dataset.Dataset:
    def load_dataset(self) -> Dataset:
        """
        Load the dataset from HuggingFace Hub.

        If the dataset is already loaded, returns the cached instance.
        The HuggingFace token is forwarded for gated datasets when
        available.

        Returns:
            The loaded HuggingFace :class:`~datasets.Dataset`.

        Raises:
            RuntimeError: If dataset loading fails.
        """
        if self.dataset is not None:
            return self.dataset

        try:
            logger.info(
                f"Loading dataset: {self.config.dataset.dataset_name} "
                f"({self.config.dataset.split} split)"
            )

            # Get HF token from settings (for gated datasets)
            settings = get_settings()
            hf_token = settings.huggingface_hub_token

            # Load dataset with token if available
            load_kwargs = {
                "path": self.config.dataset.dataset_name,
                "name": self.config.dataset.dataset_config,
                "split": self.config.dataset.split,
            }
            if hf_token:
                load_kwargs["token"] = hf_token

            self.dataset = load_dataset(**load_kwargs)

            logger.info(f"Dataset loaded with {len(self.dataset)} examples")
            return self.dataset

        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise RuntimeError(f"Dataset loading failed: {e}") from e

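Both `load_model` and `load_dataset` follow the same load-once caching discipline: the first call performs the expensive download, and every later call returns the cached object. A minimal standalone sketch of that pattern (the `CachedLoader` class is illustrative, not part of drrik):

```python
class CachedLoader:
    def __init__(self):
        self.dataset = None
        self.load_count = 0  # counts how often the expensive path runs

    def load_dataset(self):
        # Cached-instance fast path, as in ActivationExtractor
        if self.dataset is not None:
            return self.dataset
        # Stand-in for the HuggingFace Hub download
        self.load_count += 1
        self.dataset = ["sample text one", "sample text two"]
        return self.dataset
```

Repeated calls return the identical object, so downstream code can call `load_dataset()` defensively without paying for a second download.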
def extract( self, num_samples: Optional[int] = None) -> Tuple[numpy.ndarray, Dict[str, Any]]:
    def extract(
        self,
        num_samples: Optional[int] = None,
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Extract MLP activations from the model.

        Loads the model and dataset (if not already loaded), tokenizes
        the data, and runs batched forward passes using nnsight's
        tracing context.  For each sample, the **last-token** activation
        from each requested MLP layer is collected and concatenated.

        Args:
            num_samples: Override for the number of samples to process.
                If ``None``, uses ``config.dataset.num_samples``.

        Returns:
            A tuple of:

            - **activations** (:class:`numpy.ndarray`) — shape
              ``(n_samples, activation_dim * n_layers)``.  When a
              single layer is requested the shape simplifies to
              ``(n_samples, activation_dim)``.
            - **metadata** (:class:`dict`) — contains the serialised
              config, sample count, activation dimension, layer paths,
              and per-sample metadata (index, truncated text, input
              ids).

        Raises:
            RuntimeError: If extraction fails at any stage.
        """
        try:
            # Load model and dataset
            self.load_model()
            self.load_dataset()

            n_samples = num_samples or self.config.dataset.num_samples
            logger.info(
                f"Extracting activations from {len(self.config.mlp_layers)} MLP layers "
                f"for {n_samples} samples"
            )

            # Prepare dataset
            dataset = self.dataset.select(range(min(n_samples, len(self.dataset))))

            # Tokenize dataset
            def tokenize_function(examples):
                return self.tokenizer(
                    examples[self.config.dataset.text_column],
                    padding="max_length",
                    truncation=True,
                    max_length=self.config.dataset.max_length,
                    return_tensors="pt",
                )

            tokenized = dataset.map(
                tokenize_function,
                batched=True,
                remove_columns=dataset.column_names,
                desc="Tokenizing",
            )

            # Get MLP layer names
            layer_paths = [
                self._get_mlp_layer_name(layer_idx)
                for layer_idx in self.config.mlp_layers
            ]

            logger.info(f"Extracting from layers: {layer_paths}")

            # Collect activations using nnsight
            self._activations = []
            self._metadata = []

            batch_size = self.config.batch_size
            n_batches = (len(tokenized) + batch_size - 1) // batch_size

            with torch.no_grad():
                for batch_idx in tqdm(range(n_batches), desc="Extracting activations"):
                    start_idx = batch_idx * batch_size
                    end_idx = min(start_idx + batch_size, len(tokenized))

                    batch = tokenized[start_idx:end_idx]
                    input_ids = torch.tensor(batch["input_ids"])
                    attention_mask = torch.tensor(batch["attention_mask"])

                    layer_outputs = []

                    # Use nnsight to extract activations
                    with self.model.trace(input_ids, attention_mask=attention_mask):
                        for layer_path in layer_paths:
                            module = self._resolve_layer_path(layer_path)
                            output = module.output.save()
                            layer_outputs.append(output)

                    # Process outputs
                    batch_input_ids = input_ids.cpu().numpy()
                    for sample_idx in range(len(input_ids)):
                        sample_activations = []
                        for layer_output in layer_outputs:
                            # Get activation after non-linearity (if applicable)
                            # Shape: (seq_len, hidden_dim) or (batch, seq_len, hidden_dim)
                            act = layer_output

                            if act.dim() == 3:
                                act = act[sample_idx]  # (seq_len, hidden_dim)
                            # A 2-D output is already (seq_len, hidden_dim)

                            # Use the last token's activation (common practice)
                            act = act[-1]  # (hidden_dim,)

                            sample_activations.append(act.cpu().numpy())

                        # Concatenate activations from all layers
                        self._activations.append(np.concatenate(sample_activations))

                        # Store metadata
                        self._metadata.append(
                            {
                                "sample_idx": start_idx + sample_idx,
                                "text": dataset[start_idx + sample_idx][
                                    self.config.dataset.text_column
                                ][:200],
                                "input_ids": batch_input_ids[sample_idx],
                            }
                        )

            activations = np.array(self._activations)
            logger.info(f"Extracted activations shape: {activations.shape}")

            metadata = {
                "config": self.config.model_dump(),
                "n_samples": len(self._activations),
                "activation_dim": activations.shape[-1],
                "layer_paths": layer_paths,
                "samples_metadata": self._metadata,
            }

            return activations, metadata

        except Exception as e:
            logger.error(f"Failed to extract activations: {e}")
            raise RuntimeError(f"Activation extraction failed: {e}") from e

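The output layout is worth making concrete: each row of the returned array is the last-token activation from every requested layer, concatenated along the feature axis. A small illustration with fake data (no model involved) showing how the `(n_samples, activation_dim * n_layers)` shape arises:

```python
import numpy as np

# Fake per-layer activations of shape (n_samples, seq_len, hidden_dim),
# standing in for what the nnsight trace would produce.
n_samples, seq_len, hidden_dim, n_layers = 4, 16, 8, 3
rng = np.random.default_rng(0)
layer_acts = [rng.normal(size=(n_samples, seq_len, hidden_dim))
              for _ in range(n_layers)]

rows = []
for i in range(n_samples):
    # Last-token slice from each layer, concatenated as in extract()
    rows.append(np.concatenate([acts[i, -1] for acts in layer_acts]))
activations = np.array(rows)
# activations.shape == (n_samples, hidden_dim * n_layers)
```

With a single layer, the concatenation is a no-op and the shape reduces to `(n_samples, activation_dim)`, matching the docstring.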
def save_activations( self, activations: numpy.ndarray, metadata: Dict[str, Any], output_dir: Union[str, pathlib.Path, NoneType] = None) -> pathlib.Path:
    def save_activations(
        self,
        activations: np.ndarray,
        metadata: Dict[str, Any],
        output_dir: Optional[Union[str, Path]] = None,
    ) -> Path:
        """
        Save extracted activations to disk.

        Writes two files to *output_dir*:

        - ``activations.npy`` — the activation array.
        - ``metadata.pkl`` — the metadata dictionary.

        Args:
            activations: Activation array of shape
                ``(n_samples, dim)``.
            metadata: Metadata dictionary (as returned by
                :meth:`extract`).
            output_dir: Target directory.  If ``None``, uses
                ``config.output_dir``.

        Returns:
            Path to the saved ``activations.npy`` file.

        Raises:
            ValueError: If *output_dir* is ``None`` and no output
                directory is configured.
        """
        if output_dir is None:
            output_dir = self.config.output_dir
        if output_dir is None:
            raise ValueError(
                "output_dir is None and no output directory is configured"
            )

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        activations_path = output_dir / "activations.npy"
        np.save(str(activations_path), activations)

        metadata_path = output_dir / "metadata.pkl"
        with open(metadata_path, "wb") as f:
            pickle.dump(metadata, f)

        logger.info(f"Saved activations to {activations_path}")
        return activations_path

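A self-contained sketch of the on-disk layout `save_activations` produces, writable to a temporary directory (the helper name `save_pair` is hypothetical):

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


def save_pair(activations: np.ndarray, metadata: dict, output_dir: Path) -> Path:
    # Mirror the layout: activations.npy with metadata.pkl beside it
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    activations_path = output_dir / "activations.npy"
    np.save(str(activations_path), activations)
    with open(output_dir / "metadata.pkl", "wb") as f:
        pickle.dump(metadata, f)
    return activations_path


with tempfile.TemporaryDirectory() as tmp:
    path = save_pair(np.zeros((2, 3)), {"n_samples": 2}, Path(tmp))
    saved_names = sorted(p.name for p in Path(tmp).iterdir())
```

Keeping the metadata in a sibling file (rather than one pickle bundle) lets the array be memory-mapped or inspected with plain `np.load` without unpickling anything.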
def load_activations( self, filepath: Union[str, pathlib.Path]) -> Tuple[numpy.ndarray, Dict[str, Any]]:
    def load_activations(
        self,
        filepath: Union[str, Path],
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Load saved activations from disk.

        Supports two formats:

        - ``.npy`` — loads the array directly and looks for a
          companion ``metadata.pkl`` in the same directory.
        - ``.pkl`` (legacy) — loads a dict with ``"activations"``
          and ``"metadata"`` keys.

        Args:
            filepath: Path to an ``.npy`` or ``.pkl`` file.

        Returns:
            A tuple of (activations array, metadata dict).

        Raises:
            FileNotFoundError: If the file (or companion metadata) does
                not exist.
        """
        filepath = Path(filepath)

        if filepath.suffix == ".npy":
            activations = np.load(str(filepath))
            metadata_path = filepath.parent / "metadata.pkl"
            with open(metadata_path, "rb") as f:
                metadata = pickle.load(f)
        else:
            with open(filepath, "rb") as f:
                data = pickle.load(f)
            activations = data["activations"]
            metadata = data["metadata"]

        logger.info(f"Loaded activations from {filepath}")
        return activations, metadata

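The two supported formats can be sketched in a standalone round trip (the `load_pair` helper is illustrative, not drrik's actual implementation):

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


def load_pair(filepath):
    filepath = Path(filepath)
    if filepath.suffix == ".npy":
        # Array plus companion metadata.pkl in the same directory
        activations = np.load(str(filepath))
        with open(filepath.parent / "metadata.pkl", "rb") as f:
            metadata = pickle.load(f)
    else:
        # Legacy pickle bundling both keys in one dict
        with open(filepath, "rb") as f:
            data = pickle.load(f)
        activations, metadata = data["activations"], data["metadata"]
    return activations, metadata


with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    np.save(str(tmp / "activations.npy"), np.ones((2, 4)))
    with open(tmp / "metadata.pkl", "wb") as f:
        pickle.dump({"n_samples": 2}, f)
    acts, meta = load_pair(tmp / "activations.npy")
```

Missing files surface naturally as `FileNotFoundError` from `open`/`np.load`, which is why the method does not need its own existence checks.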
class SparseAutoencoder(torch.nn.modules.module.Module):
class SparseAutoencoder(nn.Module):
    """
    Sparse Autoencoder for extracting interpretable features from MLP activations.

    Architecture:
        Input (activation_dim) -> [bias] -> Encoder (hidden_dim) -> ReLU -> [bias]
        -> Decoder (activation_dim) -> Output

    The decoder weights are normalized to unit norm to prevent scaling collapse.
    L1 regularization on hidden activations encourages sparsity.

    Based on the architecture described in:
    "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning"
    https://transformer-circuits.pub/2023/monosemantic-features/index.html

    Attributes:
        encoder: Linear layer from activation_dim to hidden_dim
        decoder: Linear layer from hidden_dim to activation_dim
        pre_encoder_bias: Learnable bias subtracted from input
        post_decoder_bias: Learnable bias added to output (tied to pre_encoder_bias)

    Example:
        ```python
        sae = SparseAutoencoder(
            activation_dim=2048,
            hidden_dim=4096,  # 2x expansion
            l1_coefficient=0.01
        )
        sae.fit(activations, num_epochs=100)
        features = sae.encode(activations)
        reconstructed = sae.decode(features)
        ```
    """

    def __init__(
        self,
        activation_dim: int,
        hidden_dim: int,
        l1_coefficient: float = 0.01,
        normalize_decoder: bool = True,
        pre_encoder_bias: bool = True,
    ):
        """
        Initialize the Sparse Autoencoder.

        Args:
            activation_dim: Dimension of input MLP activations
            hidden_dim: Dimension of sparse hidden layer (overcomplete)
            l1_coefficient: L1 regularization strength
            normalize_decoder: Whether to normalize decoder columns to unit norm
            pre_encoder_bias: Whether to use pre-encoder bias (as in the paper)
        """
        super().__init__()

        self.activation_dim = activation_dim
        self.hidden_dim = hidden_dim
        self.l1_coefficient = l1_coefficient
        self.normalize_decoder = normalize_decoder
        self.pre_encoder_bias = pre_encoder_bias

        # Encoder: input -> hidden
        self.encoder = nn.Linear(activation_dim, hidden_dim, bias=False)

        # Decoder: hidden -> output
        self.decoder = nn.Linear(hidden_dim, activation_dim, bias=False)

        # Kaiming-uniform init, then rescale decoder columns to unit norm
        nn.init.kaiming_uniform_(self.decoder.weight, a=np.sqrt(5))
        with torch.no_grad():
            self.decoder.weight.copy_(
                self.decoder.weight / self.decoder.weight.norm(dim=0, keepdim=True)
            )

        # Initialize encoder weights
        nn.init.kaiming_uniform_(self.encoder.weight, a=np.sqrt(5))

        # Biases
        if pre_encoder_bias:
            # Pre-encoder bias (subtracted from input, added to output)
            # Initialize to geometric median of data (will be set during fit)
            self.bias = nn.Parameter(torch.zeros(activation_dim))
        else:
            self.register_parameter("bias", None)

        # Encoder bias
        self.encoder_bias = nn.Parameter(torch.zeros(hidden_dim))

        # Training metrics
        self.training_losses = []
        self.training_l0_norms = []

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode input activations to sparse features.

        Args:
            x: Input tensor of shape (batch_size, activation_dim)

        Returns:
            Sparse feature activations of shape (batch_size, hidden_dim)
        """
        # Subtract bias if using pre-encoder bias
        if self.bias is not None:
            x = x - self.bias

        # Apply encoder
        hidden = F.linear(x, self.encoder.weight, self.encoder_bias)

        # ReLU activation for sparsity
        features = F.relu(hidden)

        return features

    def decode(self, features: torch.Tensor) -> torch.Tensor:
        """
        Decode sparse features back to activation space.

        Args:
            features: Sparse feature tensor of shape (batch_size, hidden_dim)

        Returns:
            Reconstructed activations of shape (batch_size, activation_dim)
        """
        # Apply decoder (weights are normalized to unit norm)
        output = F.linear(features, self.decoder.weight)

        # Add bias if using
        if self.bias is not None:
            output = output + self.bias

        return output

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the autoencoder.

        Args:
            x: Input tensor of shape (batch_size, activation_dim)

        Returns:
            Tuple of (reconstructed, features)
        """
        features = self.encode(x)
        reconstructed = self.decode(features)
        return reconstructed, features

    def loss(
        self, x: torch.Tensor, reconstructed: torch.Tensor, features: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the loss with L1 regularization.

        Loss = MSE(reconstructed, x) + lambda * L1(features)

        Args:
            x: Input tensor
            reconstructed: Reconstructed tensor
            features: Sparse feature activations

        Returns:
            Total loss
        """
        mse_loss = F.mse_loss(reconstructed, x)
        l1_loss = self.l1_coefficient * features.abs().sum(dim=1).mean()
        return mse_loss + l1_loss

    def normalize_decoder_weights(self) -> None:
        """Normalize decoder weight columns to unit L2 norm.

        This prevents scaling collapse where the optimizer could
        trivially reduce the loss by shrinking decoder weights and
        scaling up encoder weights. Applied after each optimizer step
        when ``normalize_decoder`` is ``True``.
        """
        if self.normalize_decoder:
            with torch.no_grad():
                self.decoder.weight.copy_(
                    self.decoder.weight
                    / self.decoder.weight.norm(dim=0, keepdim=True).clamp(min=1e-8)
                )

    def resample_dead_neurons(
        self,
        activations: torch.Tensor,
        dead_threshold: float = 1e-8,
        dead_mask: Optional[torch.Tensor] = None,
    ) -> int:
        """
        Resample dead neurons that haven't fired recently.

        This follows the procedure from the paper:
        1. Identify neurons that haven't fired above threshold
        2. Compute loss on a batch of data
        3. Resample dead neurons to fit poorly reconstructed examples

        Args:
            activations: Batch of activations to use for resampling
            dead_threshold: Activation threshold below which a neuron is considered dead
            dead_mask: Pre-computed boolean mask of dead neurons. If None, computed
                       from the current batch. When provided, uses a sliding window
                       approach for more robust dead neuron detection.

        Returns:
            Number of neurons resampled
        """
        with torch.no_grad():
            # Get feature activations
            features = self.encode(activations)

            # Find dead neurons
            if dead_mask is None:
                neuron_activity = features.abs().sum(dim=0)
                dead_mask = neuron_activity < dead_threshold
            n_dead = dead_mask.sum().item()

            if n_dead == 0:
                return 0

            logger.info(f"Resampling {n_dead} dead neurons")

            # Compute reconstruction loss for each sample
            reconstructed, _ = self.forward(activations)
            sample_losses = F.mse_loss(
                reconstructed, activations, reduction="none"
            ).sum(dim=1)

            # Sample based on loss (higher loss = more likely to be resampled)
            probs = sample_losses**2
            probs = probs / probs.sum()

            # For each dead neuron, resample from a high-loss example
            for dead_idx in torch.where(dead_mask)[0]:
                # Sample an example
                sample_idx = torch.multinomial(probs, 1).item()
                example = activations[sample_idx]

                # Normalize example
                example_normalized = example / (example.norm() + 1e-8)

                # Set decoder weight
                self.decoder.weight[:, dead_idx] = example_normalized

                # Set encoder weight (smaller scale to prevent immediate firing)
                avg_encoder_norm = (
                    self.encoder.weight[~dead_mask].norm(dim=1).mean().item()
                    if (~dead_mask).any()
                    else 0.1
                )
                self.encoder.weight[dead_idx, :] = (
                    example_normalized * avg_encoder_norm * 0.2
                )

                # Reset encoder bias
                self.encoder_bias[dead_idx] = 0

            # Re-normalize decoder
            self.normalize_decoder_weights()

        return n_dead

    def fit(
        self,
        activations: np.ndarray,
        batch_size: int = 256,
        num_epochs: int = 100,
        learning_rate: float = 1e-4,
        validation_split: float = 0.1,
        resample_dead_neurons: bool = True,
        resample_interval: int = 10000,
        dead_threshold: float = 1e-8,
        window_size: int = 100,
        device: Optional[str] = None,
        verbose: bool = True,
        wandb_config: Optional[WandbConfig] = None,
        wandb_enabled: bool = True,
    ) -> "SparseAutoencoder":
        """
        Train the sparse autoencoder on MLP activations.

        Args:
            activations: Training data of shape (n_samples, activation_dim)
            batch_size: Batch size for training
            num_epochs: Number of training epochs
            learning_rate: Learning rate for Adam optimizer
            validation_split: Fraction of data to use for validation
            resample_dead_neurons: Whether to resample dead neurons during training
            resample_interval: Steps between resampling checks
            dead_threshold: Activation threshold below which a neuron is considered dead
            window_size: Number of recent batches to track for dead neuron detection
            device: Device to train on (cuda/cpu). If None, auto-detect
            verbose: Whether to show progress bars
            wandb_config: Optional WandbConfig for experiment tracking
            wandb_enabled: If True and wandb_config is None, create default WandbConfig

        Returns:
            self (trained model)
        """
        # Setup wandb if enabled
        wandb_logger = None
        _owns_wandb = False
        if wandb_enabled:
            if wandb_config is None:
                wandb_logger = WandbConfig(
                    config={
                        "activation_dim": self.activation_dim,
                        "hidden_dim": self.hidden_dim,
                        "expansion_factor": self.hidden_dim / self.activation_dim,
                        "l1_coefficient": self.l1_coefficient,
                        "batch_size": batch_size,
                        "learning_rate": learning_rate,
                        "num_epochs": num_epochs,
                    }
                )
                wandb_logger.initialize()
                _owns_wandb = True
            elif not wandb_config._initialized:
                wandb_logger = wandb_config
                wandb_logger.initialize()
                _owns_wandb = True
            else:
                wandb_logger = wandb_config

        # Determine device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.to(device)

        # Convert to tensor
        if isinstance(activations, np.ndarray):
            activations = torch.from_numpy(activations).float()

        n_samples = len(activations)
        n_val = int(n_samples * validation_split)
        n_train = n_samples - n_val

        # Split data
        train_data = activations[:n_train]
        val_data = activations[n_train:]

        # Create data loaders
        train_dataset = TensorDataset(train_data)
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
        )

        # Initialize bias to median if using
        if self.bias is not None and self.pre_encoder_bias:
            with torch.no_grad():
                # Approximate geometric median using median for simplicity
                median = train_data.to(device).median(dim=0).values
                self.bias.data = median

        # Optimizer
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=learning_rate,
        )

        # Training loop
        self.training_losses = []
        self.training_l0_norms = []

        n_steps_since_resample = 0
        global_step = 0

        # Sliding window for tracking neuron activity across batches
        neuron_activity_window = []

        if verbose:
            pbar = tqdm(total=num_epochs, desc="Training SAE")
        else:
            pbar = None

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            epoch_l0 = 0.0

            for batch in train_loader:
                batch = batch[0].to(device)

                # Forward pass
                reconstructed, features = self.forward(batch)

                # Track neuron activity for sliding window dead detection
                # (detached so the window does not retain the autograd graph)
                neuron_activity = features.abs().sum(dim=0).detach()
                neuron_activity_window.append(neuron_activity)
                if len(neuron_activity_window) > window_size:
                    neuron_activity_window.pop(0)

                # Compute loss
                loss = self.loss(batch, reconstructed, features)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()

                # Remove gradients parallel to decoder weights (Adam optimization trick from paper)
                if self.normalize_decoder:
                    with torch.no_grad():
                        # Project gradients to be orthogonal to decoder weights
                        decoder_w = self.decoder.weight
                        decoder_grad = self.decoder.weight.grad

                        # Compute projection
                        projection = (decoder_grad * decoder_w).sum(
                            dim=0, keepdim=True
                        ) * decoder_w
                        self.decoder.weight.grad = decoder_grad - projection

                optimizer.step()

                # Normalize decoder weights
                self.normalize_decoder_weights()

                # Track metrics
                epoch_loss += loss.item()
                epoch_l0 += (features > 0).sum(dim=1).float().mean().item()

                n_steps_since_resample += 1
                global_step += 1

                # Resample dead neurons using sliding window
                if (
                    resample_dead_neurons
                    and n_steps_since_resample >= resample_interval
                ):
                    if len(neuron_activity_window) > 0:
                        avg_activity = torch.stack(neuron_activity_window).mean(dim=0)
                        window_dead_mask = avg_activity < dead_threshold
                    else:
                        window_dead_mask = None

                    n_resampled = self.resample_dead_neurons(
                        batch, dead_threshold=dead_threshold, dead_mask=window_dead_mask
                    )
                    if n_resampled > 0:
                        neuron_activity_window.clear()
                    n_steps_since_resample = 0

            # Average metrics
            avg_loss = epoch_loss / len(train_loader)
            avg_l0 = epoch_l0 / len(train_loader)

            self.training_losses.append(avg_loss)
            self.training_l0_norms.append(avg_l0)

            # Validation
            with torch.no_grad():
                val_data_tensor = val_data.to(device)
                val_reconstructed, val_features = self.forward(val_data_tensor)
                val_loss = F.mse_loss(val_reconstructed, val_data_tensor).item()
                val_l0 = (val_features > 0).sum(dim=1).float().mean().item()

            # Log to wandb
            if wandb_logger:
                wandb_logger.log_metrics(
                    {
                        "epoch": epoch + 1,
                        "train_loss": avg_loss,
                        "train_l0_norm": avg_l0,
                        "val_loss": val_loss,
                        "val_l0_norm": val_l0,
                        "learning_rate": learning_rate,
                    },
                    step=epoch + 1,
                )

            if verbose:
                if pbar:
                    pbar.update(1)
                    pbar.set_postfix(
                        {
                            "loss": f"{avg_loss:.6f}",
                            "val_loss": f"{val_loss:.6f}",
                            "L0": f"{avg_l0:.2f}",
                            "val_L0": f"{val_l0:.2f}",
                        }
                    )
                else:
                    logger.info(
                        f"Epoch {epoch + 1}/{num_epochs}: "
                        f"loss={avg_loss:.6f}, val_loss={val_loss:.6f}, "
                        f"L0={avg_l0:.2f}, val_L0={val_l0:.2f}"
                    )

        if pbar:
            pbar.close()

        # Finalize wandb and log summary metrics
        if wandb_logger:
            # Log final feature densities
            densities = self.get_feature_density(activations.cpu().numpy())
            n_dead = (densities == 0).sum()
            n_active = (densities > 0).sum()

            wandb_logger.log_metrics(
                {
                    "final_train_loss": self.training_losses[-1],
                    "final_val_loss": val_loss,
                    "final_l0_norm": self.training_l0_norms[-1],
551                    "final_val_l0_norm": val_l0,
552                    "n_dead_features": int(n_dead),
553                    "n_active_features": int(n_active),
554                    "feature_sparsity": float(n_active / len(densities)),
555                }
556            )
557
558            # Log feature density histogram
559            wandb_logger.log_histogram(densities, "feature_density_histogram")
560
561            logger.info(f"wandb run URL: {wandb_logger.get_run_url()}")
562
563            # Finalize wandb (close the run) only if we created it
564            if _owns_wandb:
565                wandb_logger.finalize()
566
567        logger.info("Training complete!")
568        return self
569
570    def get_feature_density(self, activations: np.ndarray) -> np.ndarray:
571        """
572        Compute the density (fraction of non-zero activations) for each feature.
573
574        Args:
575            activations: Data to compute density on
576
577        Returns:
578            Array of feature densities of shape (hidden_dim,)
579        """
580        self.eval()
581        device = next(self.parameters()).device
582        with torch.no_grad():
583            if isinstance(activations, np.ndarray):
584                activations = torch.from_numpy(activations).float().to(device)
585            else:
586                activations = activations.to(device)
587
588            features = self.encode(activations)
589            density = (features > 0).float().mean(dim=0).cpu().numpy()
590
591        return density
592
593    def get_top_activating_examples(
594        self,
595        activations: np.ndarray,
596        feature_idx: int,
597        k: int = 10,
598    ) -> Tuple[np.ndarray, np.ndarray]:
599        """
600        Get the top k examples that activate a given feature the most.
601
602        Args:
603            activations: Data to search through
604            feature_idx: Index of the feature to analyze
605            k: Number of top examples to return
606
607        Returns:
608            Tuple of (top activation values, top indices)
609        """
610        self.eval()
611        device = next(self.parameters()).device
612        with torch.no_grad():
613            if isinstance(activations, np.ndarray):
614                activations = torch.from_numpy(activations).float().to(device)
615            else:
616                activations = activations.to(device)
617
618            features = self.encode(activations)
619            feature_activations = features[:, feature_idx].cpu().numpy()
620
621            top_k_indices = np.argsort(feature_activations)[-k:][::-1]
622            top_k_values = feature_activations[top_k_indices]
623
624        return top_k_values, top_k_indices
625
626    def save(self, filepath: Union[str, Path]) -> None:
627        """Save model state, hyperparameters, and training history to disk.
628
629        Saves a dictionary containing the model architecture parameters,
630        state dict, and training loss/L0 history as a PyTorch ``.pt`` file.
631        The file can be loaded with ``SparseAutoencoder.load()``.
632
633        Args:
634            filepath: Path where the model file will be saved. Parent
635                directories are created automatically if they don't exist.
636        """
637        filepath = Path(filepath)
638        filepath.parent.mkdir(parents=True, exist_ok=True)
639
640        state = {
641            "activation_dim": self.activation_dim,
642            "hidden_dim": self.hidden_dim,
643            "l1_coefficient": self.l1_coefficient,
644            "state_dict": self.state_dict(),
645            "training_losses": self.training_losses,
646            "training_l0_norms": self.training_l0_norms,
647        }
648
649        torch.save(state, filepath)
650        logger.info(f"Saved model to {filepath}")
651
652    @classmethod
653    def load(cls, filepath: Union[str, Path]) -> "SparseAutoencoder":
654        """Load a saved SparseAutoencoder from disk.
655
656        Reconstructs the model from a ``.pt`` file saved by ``save()``,
657        restoring the architecture, learned weights, and training
658        history (losses and L0 norms).
659
660        Args:
661            filepath: Path to the saved ``.pt`` model file.
662
663        Returns:
664            A ``SparseAutoencoder`` instance with restored weights and
665            training metrics.
666        """
667        filepath = Path(filepath)
668        state = torch.load(filepath, weights_only=False)
669
670        model = cls(
671            activation_dim=state["activation_dim"],
672            hidden_dim=state["hidden_dim"],
673            l1_coefficient=state["l1_coefficient"],
674        )
675        model.load_state_dict(state["state_dict"])
676        model.training_losses = state.get("training_losses", [])
677        model.training_l0_norms = state.get("training_l0_norms", [])
678
679        logger.info(f"Loaded model from {filepath}")
680        return model

Sparse Autoencoder for extracting interpretable features from MLP activations.

Architecture:

Input (activation_dim) -> [- pre-encoder bias] -> Encoder -> [+ encoder bias] -> ReLU -> Features (hidden_dim) -> Decoder -> [+ pre-encoder bias] -> Output (activation_dim)

The decoder weight columns are kept at unit L2 norm to prevent scaling collapse, and L1 regularization on the hidden activations encourages sparsity.

Based on the architecture described in: "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning" https://transformer-circuits.pub/2023/monosemantic-features/index.html

Attributes:
  • encoder: Linear layer from activation_dim to hidden_dim
  • decoder: Linear layer from hidden_dim to activation_dim
  • pre_encoder_bias: Learnable bias subtracted from input
  • post_decoder_bias: Learnable bias added to output (tied to pre_encoder_bias)
Example:
sae = SparseAutoencoder(
    activation_dim=2048,
    hidden_dim=4096,  # 2x expansion
    l1_coefficient=0.01
)
sae.fit(activations, num_epochs=100)
features = sae.encode(activations)
reconstructed = sae.decode(features)
SparseAutoencoder(activation_dim: int, hidden_dim: int, l1_coefficient: float = 0.01, normalize_decoder: bool = True, pre_encoder_bias: bool = True)
 81    def __init__(
 82        self,
 83        activation_dim: int,
 84        hidden_dim: int,
 85        l1_coefficient: float = 0.01,
 86        normalize_decoder: bool = True,
 87        pre_encoder_bias: bool = True,
 88    ):
 89        """
 90        Initialize the Sparse Autoencoder.
 91
 92        Args:
 93            activation_dim: Dimension of input MLP activations
 94            hidden_dim: Dimension of sparse hidden layer (overcomplete)
 95            l1_coefficient: L1 regularization strength
 96            normalize_decoder: Whether to normalize decoder columns to unit norm
 97            pre_encoder_bias: Whether to use pre-encoder bias (as in the paper)
 98        """
 99        super().__init__()
100
101        self.activation_dim = activation_dim
102        self.hidden_dim = hidden_dim
103        self.l1_coefficient = l1_coefficient
104        self.normalize_decoder = normalize_decoder
105        self.pre_encoder_bias = pre_encoder_bias
106
107        # Encoder: input -> hidden
108        self.encoder = nn.Linear(activation_dim, hidden_dim, bias=False)
109
110        # Decoder: hidden -> output
111        # Initialize with small random weights
112        self.decoder = nn.Linear(hidden_dim, activation_dim, bias=False)
113
114        # Kaiming uniform init, then rescale decoder columns to unit norm
115        nn.init.kaiming_uniform_(self.decoder.weight, a=np.sqrt(5))
116        with torch.no_grad():
117            self.decoder.weight.copy_(
118                self.decoder.weight / self.decoder.weight.norm(dim=0, keepdim=True)
119            )
120
121        # Initialize encoder weights
122        nn.init.kaiming_uniform_(self.encoder.weight, a=np.sqrt(5))
123
124        # Biases
125        if pre_encoder_bias:
126            # Pre-encoder bias (subtracted from input, added to output)
127            # Initialize to geometric median of data (will be set during fit)
128            self.bias = nn.Parameter(torch.zeros(activation_dim))
129        else:
130            self.register_parameter("bias", None)
131
132        # Encoder bias
133        self.encoder_bias = nn.Parameter(torch.zeros(hidden_dim))
134
135        # Training metrics
136        self.training_losses = []
137        self.training_l0_norms = []

Initialize the Sparse Autoencoder.

Arguments:
  • activation_dim: Dimension of input MLP activations
  • hidden_dim: Dimension of sparse hidden layer (overcomplete)
  • l1_coefficient: L1 regularization strength
  • normalize_decoder: Whether to normalize decoder columns to unit norm
  • pre_encoder_bias: Whether to use pre-encoder bias (as in the paper)
activation_dim
hidden_dim
l1_coefficient
normalize_decoder
pre_encoder_bias
encoder
decoder
encoder_bias
training_losses
training_l0_norms
def encode(self, x: torch.Tensor) -> torch.Tensor:
139    def encode(self, x: torch.Tensor) -> torch.Tensor:
140        """
141        Encode input activations to sparse features.
142
143        Args:
144            x: Input tensor of shape (batch_size, activation_dim)
145
146        Returns:
147            Sparse feature activations of shape (batch_size, hidden_dim)
148        """
149        # Subtract bias if using pre-encoder bias
150        if self.bias is not None:
151            x = x - self.bias
152
153        # Apply encoder
154        hidden = F.linear(x, self.encoder.weight, self.encoder_bias)
155
156        # ReLU activation for sparsity
157        features = F.relu(hidden)
158
159        return features

Encode input activations to sparse features.

Arguments:
  • x: Input tensor of shape (batch_size, activation_dim)
Returns:

Sparse feature activations of shape (batch_size, hidden_dim)

def decode(self, features: torch.Tensor) -> torch.Tensor:
161    def decode(self, features: torch.Tensor) -> torch.Tensor:
162        """
163        Decode sparse features back to activation space.
164
165        Args:
166            features: Sparse feature tensor of shape (batch_size, hidden_dim)
167
168        Returns:
169            Reconstructed activations of shape (batch_size, activation_dim)
170        """
171        # Apply decoder (weights are normalized to unit norm)
172        output = F.linear(features, self.decoder.weight)
173
174        # Add bias if using
175        if self.bias is not None:
176            output = output + self.bias
177
178        return output

Decode sparse features back to activation space.

Arguments:
  • features: Sparse feature tensor of shape (batch_size, hidden_dim)
Returns:

Reconstructed activations of shape (batch_size, activation_dim)
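
The encode/decode arithmetic above can be reproduced in a few lines of plain NumPy. This is an illustrative sketch of the math only, independent of the drrik API; all names below are local to the example:

```python
import numpy as np

rng = np.random.default_rng(0)
activation_dim, hidden_dim, batch = 4, 8, 3

# Toy parameters mirroring the SAE layout
W_enc = rng.normal(size=(hidden_dim, activation_dim))
W_dec = rng.normal(size=(activation_dim, hidden_dim))
W_dec /= np.linalg.norm(W_dec, axis=0, keepdims=True)  # unit-norm columns
b = rng.normal(size=activation_dim)   # pre-encoder bias (tied)
b_enc = np.zeros(hidden_dim)          # encoder bias

x = rng.normal(size=(batch, activation_dim))

# encode: f = ReLU(W_enc (x - b) + b_enc)
f = np.maximum((x - b) @ W_enc.T + b_enc, 0.0)

# decode: x_hat = W_dec f + b
x_hat = f @ W_dec.T + b

print(f.shape, x_hat.shape)  # (3, 8) (3, 4)
```

The ReLU guarantees non-negative features, which is why "fires" can be defined as `feature > 0` elsewhere in this class.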

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
180    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
181        """
182        Forward pass through the autoencoder.
183
184        Args:
185            x: Input tensor of shape (batch_size, activation_dim)
186
187        Returns:
188            Tuple of (reconstructed, features)
189        """
190        features = self.encode(x)
191        reconstructed = self.decode(features)
192        return reconstructed, features

Forward pass through the autoencoder.

Arguments:
  • x: Input tensor of shape (batch_size, activation_dim)
Returns:

Tuple of (reconstructed, features)

def loss(self, x: torch.Tensor, reconstructed: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
194    def loss(
195        self, x: torch.Tensor, reconstructed: torch.Tensor, features: torch.Tensor
196    ) -> torch.Tensor:
197        """
198        Compute the loss with L1 regularization.
199
200        Loss = MSE(reconstructed, x) + lambda * L1(features)
201
202        Args:
203            x: Input tensor
204            reconstructed: Reconstructed tensor
205            features: Sparse feature activations
206
207        Returns:
208            Total loss
209        """
210        mse_loss = F.mse_loss(reconstructed, x)
211        l1_loss = self.l1_coefficient * features.abs().sum(dim=1).mean()
212        return mse_loss + l1_loss

Compute the loss with L1 regularization.

Loss = MSE(reconstructed, x) + lambda * L1(features)

Arguments:
  • x: Input tensor
  • reconstructed: Reconstructed tensor
  • features: Sparse feature activations
Returns:

Total loss
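
A numeric check of the formula, with hand-picked toy values (plain NumPy, illustrative only):

```python
import numpy as np

l1_coefficient = 0.01

x = np.array([[1.0, 2.0], [3.0, 4.0]])
reconstructed = np.array([[1.0, 2.0], [3.0, 3.0]])   # one element off by 1
features = np.array([[0.5, 0.0, 1.5], [0.0, 0.0, 2.0]])

# MSE averages over every element: one squared error of 1 over 4 elements
mse = ((reconstructed - x) ** 2).mean()              # 0.25

# L1 term: sum |f| per sample, then average over the batch
l1 = l1_coefficient * np.abs(features).sum(axis=1).mean()  # 0.01 * 2.0

loss = mse + l1
print(loss)  # ~0.27
```

Note the L1 term sums over features per sample before averaging over the batch, so it scales with hidden_dim rather than being a per-element mean.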

def normalize_decoder_weights(self) -> None:
214    def normalize_decoder_weights(self) -> None:
215        """Normalize decoder weight columns to unit L2 norm.
216
217        This prevents scaling collapse where the optimizer could
218        trivially reduce the loss by shrinking decoder weights and
219        scaling up encoder weights. Applied after each optimizer step
220        when ``normalize_decoder`` is ``True``.
221        """
222        if self.normalize_decoder:
223            with torch.no_grad():
224                self.decoder.weight.copy_(
225                    self.decoder.weight
226                    / self.decoder.weight.norm(dim=0, keepdim=True).clamp(min=1e-8)
227                )

Normalize decoder weight columns to unit L2 norm.

This prevents scaling collapse where the optimizer could trivially reduce the loss by shrinking decoder weights and scaling up encoder weights. Applied after each optimizer step when normalize_decoder is True.
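
The same operation on a plain array, as an illustration (not the drrik method itself):

```python
import numpy as np

W = np.array([[3.0, 0.0],
              [4.0, 2.0]])   # column L2 norms: 5 and 2

# Divide each column by its norm, clamped away from zero as in the method
norms = np.linalg.norm(W, axis=0, keepdims=True).clip(min=1e-8)
W_unit = W / norms

print(np.linalg.norm(W_unit, axis=0))  # [1. 1.]
```

The clamp only matters for all-zero columns, where dividing by the raw norm would produce NaNs.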

def resample_dead_neurons(self, activations: torch.Tensor, dead_threshold: float = 1e-08, dead_mask: Optional[torch.Tensor] = None) -> int:
229    def resample_dead_neurons(
230        self,
231        activations: torch.Tensor,
232        dead_threshold: float = 1e-8,
233        dead_mask: Optional[torch.Tensor] = None,
234    ) -> int:
235        """
236        Resample dead neurons that haven't fired recently.
237
238        This follows the procedure from the paper:
239        1. Identify neurons that haven't fired above threshold
240        2. Compute loss on a batch of data
241        3. Resample dead neurons to fit poorly reconstructed examples
242
243        Args:
244            activations: Batch of activations to use for resampling
245            dead_threshold: Activation threshold below which a neuron is considered dead
246            dead_mask: Pre-computed boolean mask of dead neurons. If None, computed
247                       from the current batch. When provided, uses a sliding window
248                       approach for more robust dead neuron detection.
249
250        Returns:
251            Number of neurons resampled
252        """
253        with torch.no_grad():
254            # Get feature activations
255            features = self.encode(activations)
256
257            # Find dead neurons
258            if dead_mask is None:
259                neuron_activity = features.abs().sum(dim=0)
260                dead_mask = neuron_activity < dead_threshold
261            n_dead = dead_mask.sum().item()
262
263            if n_dead == 0:
264                return 0
265
266            logger.info(f"Resampling {n_dead} dead neurons")
267
268            # Compute reconstruction loss for each sample
269            reconstructed, _ = self.forward(activations)
270            sample_losses = F.mse_loss(
271                reconstructed, activations, reduction="none"
272            ).sum(dim=1)
273
274            # Sample based on loss (higher loss = more likely to be resampled)
275            probs = sample_losses**2
276            probs = probs / probs.sum()
277
278            # For each dead neuron, resample from a high-loss example
279            for dead_idx in torch.where(dead_mask)[0]:
280                # Sample an example
281                sample_idx = torch.multinomial(probs, 1).item()
282                example = activations[sample_idx]
283
284                # Normalize example
285                example_normalized = example / (example.norm() + 1e-8)
286
287                # Set decoder weight
288                self.decoder.weight[:, dead_idx] = example_normalized
289
290                # Set encoder weight (smaller scale to prevent immediate firing)
291                avg_encoder_norm = (
292                    self.encoder.weight[~dead_mask].norm(dim=1).mean().item()
293                    if (~dead_mask).any()
294                    else 0.1
295                )
296                self.encoder.weight[dead_idx, :] = (
297                    example_normalized * avg_encoder_norm * 0.2
298                )
299
300                # Reset encoder bias
301                self.encoder_bias[dead_idx] = 0
302
303            # Re-normalize decoder
304            self.normalize_decoder_weights()
305
306        return n_dead

Resample dead neurons that haven't fired recently.

This follows the procedure from the paper:

  1. Identify neurons that haven't fired above threshold
  2. Compute loss on a batch of data
  3. Resample dead neurons to fit poorly reconstructed examples
Arguments:
  • activations: Batch of activations to use for resampling
  • dead_threshold: Activation threshold below which a neuron is considered dead
  • dead_mask: Pre-computed boolean mask of dead neurons. If None, computed from the current batch. When provided, uses a sliding window approach for more robust dead neuron detection.
Returns:

Number of neurons resampled
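
Step 3's loss-weighted sampling can be sketched in NumPy with toy numbers (the real method uses `torch.multinomial` over the batch's per-sample losses):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy per-sample reconstruction losses: the last example is reconstructed worst
sample_losses = np.array([0.1, 0.1, 0.1, 1.0])

# Probability proportional to the squared loss, as in resample_dead_neurons
probs = sample_losses ** 2
probs = probs / probs.sum()
print(probs.round(3))  # the high-loss example gets ~97% of the mass

# Each dead neuron would be re-initialized from one sampled example
idx = rng.choice(len(sample_losses), p=probs)
```

Squaring the losses sharpens the distribution, so replacement directions come almost entirely from examples the current dictionary reconstructs poorly.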

def fit(self, activations: numpy.ndarray, batch_size: int = 256, num_epochs: int = 100, learning_rate: float = 0.0001, validation_split: float = 0.1, resample_dead_neurons: bool = True, resample_interval: int = 10000, dead_threshold: float = 1e-08, window_size: int = 100, device: Optional[str] = None, verbose: bool = True, wandb_config: Optional[WandbConfig] = None, wandb_enabled: bool = True) -> SparseAutoencoder:
308    def fit(
309        self,
310        activations: np.ndarray,
311        batch_size: int = 256,
312        num_epochs: int = 100,
313        learning_rate: float = 1e-4,
314        validation_split: float = 0.1,
315        resample_dead_neurons: bool = True,
316        resample_interval: int = 10000,
317        dead_threshold: float = 1e-8,
318        window_size: int = 100,
319        device: Optional[str] = None,
320        verbose: bool = True,
321        wandb_config: Optional[WandbConfig] = None,
322        wandb_enabled: bool = True,
323    ) -> "SparseAutoencoder":
324        """
325        Train the sparse autoencoder on MLP activations.
326
327        Args:
328            activations: Training data of shape (n_samples, activation_dim)
329            batch_size: Batch size for training
330            num_epochs: Number of training epochs
331            learning_rate: Learning rate for Adam optimizer
332            validation_split: Fraction of data to use for validation
333            resample_dead_neurons: Whether to resample dead neurons during training
334            resample_interval: Steps between resampling checks
335            dead_threshold: Activation threshold below which a neuron is considered dead
336            window_size: Number of recent batches to track for dead neuron detection
337            device: Device to train on (cuda/cpu). If None, auto-detect
338            verbose: Whether to show progress bars
339            wandb_config: Optional WandbConfig for experiment tracking
340            wandb_enabled: If True and wandb_config is None, create default WandbConfig
341
342        Returns:
343            self (trained model)
344        """
345        # Setup wandb if enabled
346        wandb_logger = None
347        _owns_wandb = False
348        if wandb_enabled:
349            if wandb_config is None:
350                wandb_logger = WandbConfig(
351                    config={
352                        "activation_dim": self.activation_dim,
353                        "hidden_dim": self.hidden_dim,
354                        "expansion_factor": self.hidden_dim / self.activation_dim,
355                        "l1_coefficient": self.l1_coefficient,
356                        "batch_size": batch_size,
357                        "learning_rate": learning_rate,
358                        "num_epochs": num_epochs,
359                    }
360                )
361                wandb_logger.initialize()
362                _owns_wandb = True
363            elif not wandb_config._initialized:
364                wandb_logger = wandb_config
365                wandb_logger.initialize()
366                _owns_wandb = True
367            else:
368                wandb_logger = wandb_config
369
370        # Determine device
371        if device is None:
372            device = "cuda" if torch.cuda.is_available() else "cpu"
373
374        self.to(device)
375
376        # Convert to tensor
377        if isinstance(activations, np.ndarray):
378            activations = torch.from_numpy(activations).float()
379
380        n_samples = len(activations)
381        n_val = int(n_samples * validation_split)
382        n_train = n_samples - n_val
383
384        # Split data
385        train_data = activations[:n_train]
386        val_data = activations[n_train:]
387
388        # Create data loaders
389        train_dataset = TensorDataset(train_data)
390        train_loader = DataLoader(
391            train_dataset,
392            batch_size=batch_size,
393            shuffle=True,
394        )
395
396        # Initialize bias to median if using
397        if self.bias is not None and self.pre_encoder_bias:
398            with torch.no_grad():
399                # Approximate the geometric median with the coordinate-wise median
400                median = train_data.to(device).median(dim=0).values
401                self.bias.data = median
402
403        # Optimizer
404        optimizer = torch.optim.Adam(
405            self.parameters(),
406            lr=learning_rate,
407        )
408
409        # Training loop
410        self.training_losses = []
411        self.training_l0_norms = []
412
413        n_steps_since_resample = 0
414        global_step = 0
415
416        # Sliding window for tracking neuron activity across batches
417        neuron_activity_window = []
418
419        if verbose:
420            pbar = tqdm(total=num_epochs, desc="Training SAE")
421        else:
422            pbar = None
423
424        for epoch in range(num_epochs):
425            epoch_loss = 0.0
426            epoch_l0 = 0.0
427
428            for batch in train_loader:
429                batch = batch[0].to(device)
430
431                # Forward pass
432                reconstructed, features = self.forward(batch)
433
434                # Track neuron activity for sliding window dead detection
435                neuron_activity = features.abs().sum(dim=0)
436                neuron_activity_window.append(neuron_activity)
437                if len(neuron_activity_window) > window_size:
438                    neuron_activity_window.pop(0)
439
440                # Compute loss
441                loss = self.loss(batch, reconstructed, features)
442
443                # Backward pass
444                optimizer.zero_grad()
445                loss.backward()
446
447                # Remove gradients parallel to decoder weights (Adam optimization trick from paper)
448                if self.normalize_decoder:
449                    with torch.no_grad():
450                        # Project gradients to be orthogonal to decoder weights
451                        decoder_w = self.decoder.weight
452                        decoder_grad = self.decoder.weight.grad
453
454                        # Compute projection
455                        projection = (decoder_grad * decoder_w).sum(
456                            dim=0, keepdim=True
457                        ) * decoder_w
458                        self.decoder.weight.grad = decoder_grad - projection
459
460                optimizer.step()
461
462                # Normalize decoder weights
463                self.normalize_decoder_weights()
464
465                # Track metrics
466                epoch_loss += loss.item()
467                epoch_l0 += (features > 0).sum(dim=1).float().mean().item()
468
469                n_steps_since_resample += 1
470                global_step += 1
471
472                # Resample dead neurons using sliding window
473                if (
474                    resample_dead_neurons
475                    and n_steps_since_resample >= resample_interval
476                ):
477                    if len(neuron_activity_window) > 0:
478                        avg_activity = torch.stack(neuron_activity_window).mean(dim=0)
479                        window_dead_mask = avg_activity < dead_threshold
480                    else:
481                        window_dead_mask = None
482
483                    n_resampled = self.resample_dead_neurons(
484                        batch, dead_threshold=dead_threshold, dead_mask=window_dead_mask
485                    )
486                    if n_resampled > 0:
487                        neuron_activity_window.clear()
488                    n_steps_since_resample = 0
489
490            # Average metrics
491            avg_loss = epoch_loss / len(train_loader)
492            avg_l0 = epoch_l0 / len(train_loader)
493
494            self.training_losses.append(avg_loss)
495            self.training_l0_norms.append(avg_l0)
496
497            # Validation
498            with torch.no_grad():
499                val_data_tensor = val_data.to(device)
500                val_reconstructed, val_features = self.forward(val_data_tensor)
501                val_loss = F.mse_loss(val_reconstructed, val_data_tensor).item()
502                val_l0 = (val_features > 0).sum(dim=1).float().mean().item()
503
504            # Log to wandb
505            if wandb_logger:
506                wandb_logger.log_metrics(
507                    {
508                        "epoch": epoch + 1,
509                        "train_loss": avg_loss,
510                        "train_l0_norm": avg_l0,
511                        "val_loss": val_loss,
512                        "val_l0_norm": val_l0,
513                        "learning_rate": learning_rate,
514                    },
515                    step=epoch + 1,
516                )
517
518            if verbose:
519                if pbar:
520                    pbar.update(1)
521                    pbar.set_postfix(
522                        {
523                            "loss": f"{avg_loss:.6f}",
524                            "val_loss": f"{val_loss:.6f}",
525                            "L0": f"{avg_l0:.2f}",
526                            "val_L0": f"{val_l0:.2f}",
527                        }
528                    )
529                else:
530                    logger.info(
531                        f"Epoch {epoch + 1}/{num_epochs}: "
532                        f"loss={avg_loss:.6f}, val_loss={val_loss:.6f}, "
533                        f"L0={avg_l0:.2f}, val_L0={val_l0:.2f}"
534                    )
535
536        if pbar:
537            pbar.close()
538
539        # Finalize wandb and log summary metrics
540        if wandb_logger:
541            # Log final feature densities
542            densities = self.get_feature_density(activations.cpu().numpy())
543            n_dead = (densities == 0).sum()
544            n_active = (densities > 0).sum()
545
546            wandb_logger.log_metrics(
547                {
548                    "final_train_loss": self.training_losses[-1],
549                    "final_val_loss": val_loss,
550                    "final_l0_norm": self.training_l0_norms[-1],
551                    "final_val_l0_norm": val_l0,
552                    "n_dead_features": int(n_dead),
553                    "n_active_features": int(n_active),
554                    "feature_sparsity": float(n_active / len(densities)),
555                }
556            )
557
558            # Log feature density histogram
559            wandb_logger.log_histogram(densities, "feature_density_histogram")
560
561            logger.info(f"wandb run URL: {wandb_logger.get_run_url()}")
562
563            # Finalize wandb (close the run) only if we created it
564            if _owns_wandb:
565                wandb_logger.finalize()
566
567        logger.info("Training complete!")
568        return self

Train the sparse autoencoder on MLP activations.

Arguments:
  • activations: Training data of shape (n_samples, activation_dim)
  • batch_size: Batch size for training
  • num_epochs: Number of training epochs
  • learning_rate: Learning rate for Adam optimizer
  • validation_split: Fraction of data to use for validation
  • resample_dead_neurons: Whether to resample dead neurons during training
  • resample_interval: Steps between resampling checks
  • dead_threshold: Activation threshold below which a neuron is considered dead
  • window_size: Number of recent batches to track for dead neuron detection
  • device: Device to train on (cuda/cpu). If None, auto-detect
  • verbose: Whether to show progress bars
  • wandb_config: Optional WandbConfig for experiment tracking
  • wandb_enabled: If True and wandb_config is None, create default WandbConfig
Returns:

self (trained model)
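
One detail worth calling out: when normalize_decoder is set, fit() removes the gradient component parallel to each decoder column before the optimizer step, so the post-step renormalization does not fight the update. The projection can be checked in isolation with NumPy (an illustrative sketch):

```python
import numpy as np

rng = np.random.default_rng(0)
activation_dim, hidden_dim = 4, 8

# Unit-norm decoder columns and an arbitrary gradient of the same shape
W = rng.normal(size=(activation_dim, hidden_dim))
W /= np.linalg.norm(W, axis=0, keepdims=True)
grad = rng.normal(size=(activation_dim, hidden_dim))

# Same projection as in fit(): subtract the component parallel to each column
projection = (grad * W).sum(axis=0, keepdims=True) * W
grad_perp = grad - projection

# Each projected column is now orthogonal to its decoder column
print(np.abs((grad_perp * W).sum(axis=0)).max())  # ~0 (float round-off)
```

After this projection the update only rotates decoder directions, leaving their norms (up to round-off) for normalize_decoder_weights() to clean up.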

def get_feature_density(self, activations: numpy.ndarray) -> numpy.ndarray:
570    def get_feature_density(self, activations: np.ndarray) -> np.ndarray:
571        """
572        Compute the density (fraction of non-zero activations) for each feature.
573
574        Args:
575            activations: Data to compute density on
576
577        Returns:
578            Array of feature densities of shape (hidden_dim,)
579        """
580        self.eval()
581        device = next(self.parameters()).device
582        with torch.no_grad():
583            if isinstance(activations, np.ndarray):
584                activations = torch.from_numpy(activations).float().to(device)
585            else:
586                activations = activations.to(device)
587
588            features = self.encode(activations)
589            density = (features > 0).float().mean(dim=0).cpu().numpy()
590
591        return density

Compute the density (fraction of non-zero activations) for each feature.

Arguments:
  • activations: Data to compute density on
Returns:

Array of feature densities of shape (hidden_dim,)
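The density computation itself is a one-liner over the encoded features, as in the implementation above: the fraction of rows in which each column is strictly positive. A self-contained numpy illustration (the `features` matrix here is made-up data standing in for SAE encodings):

```python
import numpy as np

# Feature density = fraction of examples on which each feature fires,
# i.e. the column-wise mean of (features > 0).
features = np.array([
    [0.0, 1.2, 0.0],
    [0.3, 0.0, 0.0],
    [0.9, 2.1, 0.0],
    [0.0, 0.4, 0.0],
])
density = (features > 0).astype(float).mean(axis=0)
# density -> [0.5, 0.75, 0.0]; feature 2 is "dead" on this batch
```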

def get_top_activating_examples( self, activations: numpy.ndarray, feature_idx: int, k: int = 10) -> Tuple[numpy.ndarray, numpy.ndarray]:
593    def get_top_activating_examples(
594        self,
595        activations: np.ndarray,
596        feature_idx: int,
597        k: int = 10,
598    ) -> Tuple[np.ndarray, np.ndarray]:
599        """
600        Get the top k examples that activate a given feature the most.
601
602        Args:
603            activations: Data to search through
604            feature_idx: Index of the feature to analyze
605            k: Number of top examples to return
606
607        Returns:
608            Tuple of (top activation values, top indices)
609        """
610        self.eval()
611        device = next(self.parameters()).device
612        with torch.no_grad():
613            if isinstance(activations, np.ndarray):
614                activations = torch.from_numpy(activations).float().to(device)
615            else:
616                activations = activations.to(device)
617
618            features = self.encode(activations)
619            feature_activations = features[:, feature_idx].cpu().numpy()
620
621            top_k_indices = np.argsort(feature_activations)[-k:][::-1]
622            top_k_values = feature_activations[top_k_indices]
623
624        return top_k_values, top_k_indices

Get the top k examples that activate a given feature the most.

Arguments:
  • activations: Data to search through
  • feature_idx: Index of the feature to analyze
  • k: Number of top examples to return
Returns:

Tuple of (top activation values, top indices)
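The selection logic follows the argsort idiom in the source above: sort ascending, slice the last `k` indices, and reverse so the strongest example comes first. A numpy illustration with made-up activations for one feature:

```python
import numpy as np

# Top-k selection as in get_top_activating_examples: ascending argsort,
# take the last k indices, reverse to descending order.
feature_activations = np.array([0.1, 3.0, 0.0, 2.5, 0.7])
k = 3
top_k_indices = np.argsort(feature_activations)[-k:][::-1]
top_k_values = feature_activations[top_k_indices]
# top_k_indices -> [1, 3, 4]; top_k_values -> [3.0, 2.5, 0.7]
```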

def save(self, filepath: Union[str, pathlib.Path]) -> None:
626    def save(self, filepath: Union[str, Path]) -> None:
627        """Save model state, hyperparameters, and training history to disk.
628
629        Saves a dictionary containing the model architecture parameters,
630        state dict, and training loss/L0 history as a PyTorch ``.pt`` file.
631        The file can be loaded with ``SparseAutoencoder.load()``.
632
633        Args:
634            filepath: Path where the model file will be saved. Parent
635                directories are created automatically if they don't exist.
636        """
637        filepath = Path(filepath)
638        filepath.parent.mkdir(parents=True, exist_ok=True)
639
640        state = {
641            "activation_dim": self.activation_dim,
642            "hidden_dim": self.hidden_dim,
643            "l1_coefficient": self.l1_coefficient,
644            "state_dict": self.state_dict(),
645            "training_losses": self.training_losses,
646            "training_l0_norms": self.training_l0_norms,
647        }
648
649        torch.save(state, filepath)
650        logger.info(f"Saved model to {filepath}")

Save model state, hyperparameters, and training history to disk.

Saves a dictionary containing the model architecture parameters, state dict, and training loss/L0 history as a PyTorch .pt file. The file can be loaded with SparseAutoencoder.load().

Arguments:
  • filepath: Path where the model file will be saved. Parent directories are created automatically if they don't exist.
@classmethod
def load( cls, filepath: Union[str, pathlib.Path]) -> SparseAutoencoder:
652    @classmethod
653    def load(cls, filepath: Union[str, Path]) -> "SparseAutoencoder":
654        """Load a saved SparseAutoencoder from disk.
655
656        Reconstructs the model from a ``.pt`` file saved by ``save()``,
657        restoring the architecture, learned weights, and training
658        history (losses and L0 norms).
659
660        Args:
661            filepath: Path to the saved ``.pt`` model file.
662
663        Returns:
664            A ``SparseAutoencoder`` instance with restored weights and
665            training metrics.
666        """
667        filepath = Path(filepath)
668        state = torch.load(filepath, weights_only=False)
669
670        model = cls(
671            activation_dim=state["activation_dim"],
672            hidden_dim=state["hidden_dim"],
673            l1_coefficient=state["l1_coefficient"],
674        )
675        model.load_state_dict(state["state_dict"])
676        model.training_losses = state.get("training_losses", [])
677        model.training_l0_norms = state.get("training_l0_norms", [])
678
679        logger.info(f"Loaded model from {filepath}")
680        return model

Load a saved SparseAutoencoder from disk.

Reconstructs the model from a .pt file saved by save(), restoring the architecture, learned weights, and training history (losses and L0 norms).

Arguments:
  • filepath: Path to the saved .pt model file.
Returns:

A SparseAutoencoder instance with restored weights and training metrics.
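`save()` and `load()` round-trip a single checkpoint dictionary holding hyperparameters, weights, and training history. The sketch below mirrors that pattern with stdlib `pickle` so it runs without torch; the real methods call `torch.save`/`torch.load` on a `.pt` file, and the tensor weights are omitted here.

```python
import os
import pickle
import tempfile

# Checkpoint-dictionary pattern used by save()/load(), minus the torch
# tensors (the real code serializes a state_dict as well).
state = {
    "activation_dim": 512,
    "hidden_dim": 4096,          # 8x expansion
    "l1_coefficient": 0.01,
    "training_losses": [0.9, 0.4, 0.2],
}

path = os.path.join(tempfile.mkdtemp(), "sae_checkpoint.pkl")
with open(path, "wb") as f:
    pickle.dump(state, f)
with open(path, "rb") as f:
    restored = pickle.load(f)

# load() tolerates missing history keys via state.get(..., [])
l0_history = restored.get("training_l0_norms", [])
```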

class FeatureVisualizer:
 41class FeatureVisualizer:
 42    """
 43    Visualize sparse autoencoder features and activations.
 44
 45    This class provides methods to create various plots that help understand
 46    the learned features, similar to the visualizations in the
 47    "Towards Monosemanticity" paper.
 48
 49    Example:
 50        ```python
 51        visualizer = FeatureVisualizer(sae, activations, metadata)
 52        visualizer.plot_feature_density()
 53        visualizer.plot_top_features(n_features=10)
 54        visualizer.plot_feature_examples(feature_idx=0)
 55        ```
 56    """
 57
 58    def __init__(
 59        self,
 60        sae: SparseAutoencoder,
 61        activations: np.ndarray,
 62        metadata: Optional[Dict[str, Any]] = None,
 63        output_dir: Union[str, Path] = "./visualizations",
 64        style: str = "whitegrid",
 65        dpi: int = 150,
 66        wandb_config: Optional[WandbConfig] = None,
 67        log_to_wandb: bool = True,
 68    ):
 69        """
 70        Initialize the visualizer.
 71
 72        Args:
 73            sae: Trained sparse autoencoder
 74            activations: The MLP activations used to train the SAE
 75            metadata: Optional metadata dictionary (e.g., from ActivationExtractor)
 76            output_dir: Directory to save plots
 77            style: Seaborn style to use
 78            dpi: DPI for saved figures
 79            wandb_config: Optional WandbConfig for logging visualizations
 80            log_to_wandb: If True, log plots to wandb when wandb_config is provided
 81        """
 82        self.sae = sae
 83        self.activations = activations
 84        self.metadata = metadata or {}
 85        self.output_dir = Path(output_dir)
 86        self.dpi = dpi
 87        self.wandb_config = wandb_config
 88        self.log_to_wandb = log_to_wandb
 89
 90        # Set style
 91        sns.set_style(style)
 92        sns.set_palette("husl")
 93
 94        # Create output directory
 95        self.output_dir.mkdir(parents=True, exist_ok=True)
 96
 97        # Compute feature activations
 98        self._compute_features()
 99
100    def _log_figure_to_wandb(self, fig: plt.Figure, name: str) -> None:
101        """
102        Log a matplotlib figure to wandb.
103
104        Args:
105            fig: The matplotlib figure
106            name: Name for the plot in wandb
107        """
108        if self.wandb_config and self.log_to_wandb:
109            try:
110                import wandb
111
112                wandb.log({name: wandb.Image(fig)})
113                logger.debug(f"Logged {name} to wandb")
114            except Exception as e:
115                logger.warning(f"Failed to log {name} to wandb: {e}")
116
117    def _compute_features(self) -> None:
118        """Compute sparse feature activations from the SAE.
119
120        Runs the encoder forward pass on the stored activations to
121        produce a matrix of feature activations. The result is stored
122        in ``self.features`` as a numpy array of shape
123        ``(n_samples, hidden_dim)`` and is used by all plotting methods.
124        """
125        import torch
126
127        self.sae.eval()
128        device = next(self.sae.parameters()).device
129        with torch.no_grad():
130            activations_tensor = torch.from_numpy(self.activations).float().to(device)
131            self.features = self.sae.encode(activations_tensor).cpu().numpy()
132
133        logger.info(f"Computed features with shape: {self.features.shape}")
134
135    def plot_feature_density(
136        self,
137        bins: int = 50,
138        log_scale: bool = True,
139        save_path: Optional[str] = None,
140    ) -> plt.Figure:
141        """
142        Plot the distribution of feature densities.
143
144        Feature density is the fraction of examples on which a feature fires.
145        This histogram helps understand the sparsity distribution.
146
147        Args:
148            bins: Number of histogram bins
149            log_scale: Whether to use log scale on x-axis
150            save_path: Optional path to save the figure
151
152        Returns:
153            The matplotlib Figure object
154        """
155        # Compute feature densities
156        densities = (self.features > 0).mean(axis=0)
157
158        fig, ax = plt.subplots(figsize=(10, 6))
159
160        # Remove zero-density features (dead neurons)
161        active_densities = densities[densities > 0]
162
163        # Plot histogram on log scale
164        if log_scale:
165            bins_log = np.logspace(-8, 0, bins)
166            ax.hist(active_densities, bins=bins_log, edgecolor="black", alpha=0.7)
167            ax.set_xscale("log")
168        else:
169            ax.hist(active_densities, bins=bins, edgecolor="black", alpha=0.7)
170
171        ax.set_xlabel("Feature Density (fraction of examples)", fontsize=12)
172        ax.set_ylabel("Number of Features", fontsize=12)
173        ax.set_title("Distribution of Feature Densities", fontsize=14)
174
175        # Add statistics
176        n_dead = (densities == 0).sum()
177        n_active = len(active_densities)
178
179        stats_text = f"Dead features: {n_dead} ({n_dead / len(densities) * 100:.1f}%)\n"
180        stats_text += (
181            f"Active features: {n_active} ({n_active / len(densities) * 100:.1f}%)\n"
182        )
183        stats_text += f"Median density: {np.median(active_densities):.2e}"
184
185        ax.text(
186            0.95,
187            0.95,
188            stats_text,
189            transform=ax.transAxes,
190            verticalalignment="top",
191            horizontalalignment="right",
192            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
193        )
194
195        plt.tight_layout()
196
197        if save_path or self.output_dir:
198            path = save_path or self.output_dir / "feature_density.png"
199            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
200            logger.info(f"Saved feature density plot to {path}")
201
202        # Log to wandb
203        self._log_figure_to_wandb(fig, "feature_density")
204
205        return fig
206
207    def plot_activation_histogram(
208        self,
209        feature_idx: int,
210        bins: int = 100,
211        log_y: bool = True,
212        save_path: Optional[str] = None,
213    ) -> plt.Figure:
214        """
215        Plot activation histogram for a specific feature.
216
217        Args:
218            feature_idx: Index of the feature to visualize
219            bins: Number of histogram bins
220            log_y: Whether to use log scale on y-axis
221            save_path: Optional path to save the figure
222
223        Returns:
224            The matplotlib Figure object
225        """
226        feature_acts = self.features[:, feature_idx]
227
228        fig, ax = plt.subplots(figsize=(10, 6))
229
230        # Plot histogram
231        counts, bin_edges, patches = ax.hist(
232            feature_acts[feature_acts > 0],  # Only non-zero activations
233            bins=bins,
234            edgecolor="black",
235            alpha=0.7,
236        )
237
238        if log_y:
239            ax.set_yscale("log")
240
241        ax.set_xlabel("Activation Value", fontsize=12)
242        ax.set_ylabel("Count (log scale)" if log_y else "Count", fontsize=12)
243        ax.set_title(f"Activation Histogram for Feature {feature_idx}", fontsize=14)
244
245        # Add statistics
246        density = (feature_acts > 0).mean()
247        max_act = feature_acts.max()
248        mean_act = feature_acts[feature_acts > 0].mean() if density > 0 else 0
249
250        stats_text = f"Density: {density:.2e}\n"
251        stats_text += f"Max activation: {max_act:.4f}\n"
252        stats_text += f"Mean (active): {mean_act:.4f}"
253
254        ax.text(
255            0.95,
256            0.95,
257            stats_text,
258            transform=ax.transAxes,
259            verticalalignment="top",
260            horizontalalignment="right",
261            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.5),
262        )
263
264        plt.tight_layout()
265
266        if save_path or self.output_dir:
267            path = (
268                save_path
269                or self.output_dir / f"activation_histogram_feature_{feature_idx}.png"
270            )
271            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
272            logger.info(f"Saved activation histogram to {path}")
273
274        return fig
275
276    def plot_training_curves(
277        self,
278        save_path: Optional[str] = None,
279    ) -> Optional[plt.Figure]:
280        """
281        Plot training loss and L0 norm curves.
282
283        Args:
284            save_path: Optional path to save the figure
285
286        Returns:
287            The matplotlib Figure object, or None if no training history exists
288        """
289        if not self.sae.training_losses:
290            logger.warning("No training history available; call fit() first")
291            return None
292
293        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
294
295        epochs = range(1, len(self.sae.training_losses) + 1)
296
297        # Loss curve
298        ax1.plot(epochs, self.sae.training_losses, "b-", linewidth=2)
299        ax1.set_xlabel("Epoch", fontsize=12)
300        ax1.set_ylabel("Loss", fontsize=12)
301        ax1.set_title("Training Loss", fontsize=14)
302        ax1.grid(True, alpha=0.3)
303
304        # L0 norm curve
305        ax2.plot(epochs, self.sae.training_l0_norms, "r-", linewidth=2)
306        ax2.set_xlabel("Epoch", fontsize=12)
307        ax2.set_ylabel("L0 Norm (avg # active features)", fontsize=12)
308        ax2.set_title("Sparsity (L0 Norm)", fontsize=14)
309        ax2.grid(True, alpha=0.3)
310
311        plt.tight_layout()
312
313        if save_path or self.output_dir:
314            path = save_path or self.output_dir / "training_curves.png"
315            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
316            logger.info(f"Saved training curves to {path}")
317
318        return fig
319
320    def plot_top_features(
321        self,
322        n_features: int = 10,
323        by: str = "density",
324        save_path: Optional[str] = None,
325    ) -> plt.Figure:
326        """
327        Plot top N features by various metrics.
328
329        Args:
330            n_features: Number of top features to show
331            by: Metric to sort by ("density", "max_activation", "mean_activation")
332            save_path: Optional path to save the figure
333
334        Returns:
335            The matplotlib Figure object
336        """
337        fig, ax = plt.subplots(figsize=(12, 6))
338
339        # Compute metric for each feature
340        if by == "density":
341            values = (self.features > 0).mean(axis=0)
342            ylabel = "Feature Density"
343            title = f"Top {n_features} Features by Density"
344        elif by == "max_activation":
345            values = self.features.max(axis=0)
346            ylabel = "Max Activation"
347            title = f"Top {n_features} Features by Max Activation"
348        elif by == "mean_activation":
349            values = self.features.mean(axis=0)
350            ylabel = "Mean Activation"
351            title = f"Top {n_features} Features by Mean Activation"
352        else:
353            raise ValueError(f"Unknown metric: {by}")
354
355        # Get top features
356        top_indices = np.argsort(values)[-n_features:][::-1]
357        top_values = values[top_indices]
358
359        # Plot bar chart
360        _ = ax.barh(range(n_features), top_values[::-1])
361        ax.set_yticks(range(n_features))
362        ax.set_yticklabels([f"Feature {i}" for i in top_indices[::-1]])
363        ax.set_xlabel(ylabel, fontsize=12)
364        ax.set_title(title, fontsize=14)
365        ax.grid(True, axis="x", alpha=0.3)
366
367        plt.tight_layout()
368
369        if save_path or self.output_dir:
370            path = save_path or self.output_dir / f"top_features_{by}.png"
371            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
372            logger.info(f"Saved top features plot to {path}")
373
374        return fig
375
376    def plot_feature_examples(
377        self,
378        feature_idx: int,
379        k: int = 10,
380        show_text: bool = True,
381        max_text_length: int = 200,
382        save_path: Optional[str] = None,
383    ) -> plt.Figure:
384        """
385        Plot top k activating examples for a specific feature.
386
387        Args:
388            feature_idx: Index of the feature to visualize
389            k: Number of examples to show
390            show_text: Whether to show the text examples
391            max_text_length: Maximum length of text to display
392            save_path: Optional path to save the figure
393
394        Returns:
395            The matplotlib Figure object
396        """
397        # Get top activating examples
398        top_values, top_indices = self.sae.get_top_activating_examples(
399            self.activations, feature_idx, k
400        )
401
402        fig, ax = plt.subplots(figsize=(12, k * 0.5))
403
404        # Plot bar chart
405        _ = ax.barh(range(k), top_values[::-1])
406        ax.set_yticks(range(k))
407        ax.set_xlabel("Activation Value", fontsize=12)
408        ax.set_title(
409            f"Top {k} Activating Examples for Feature {feature_idx}", fontsize=14
410        )
411        ax.grid(True, axis="x", alpha=0.3)
412
413        # Add text examples if available
414        if show_text and "samples_metadata" in self.metadata:
415            samples_metadata = self.metadata["samples_metadata"]
416
417            labels = []
418            for i, idx in enumerate(top_indices[::-1]):
419                if idx < len(samples_metadata):
420                    text = samples_metadata[idx].get("text", "")
421                    # Clean up text
422                    text = text.replace("\n", " ").strip()
423                    if len(text) > max_text_length:
424                        text = text[:max_text_length] + "..."
425                    labels.append(f"{i + 1}. {text}")
426                else:
427                    labels.append(f"{i + 1}. (No metadata)")
428
429            ax.set_yticklabels(labels, fontsize=9)
430
431        plt.tight_layout()
432
433        if save_path or self.output_dir:
434            path = save_path or self.output_dir / f"feature_{feature_idx}_examples.png"
435            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
436            logger.info(f"Saved feature examples to {path}")
437
438        return fig
439
440    def plot_decoder_weights(
441        self,
442        feature_indices: Optional[List[int]] = None,
443        n_features: int = 10,
444        save_path: Optional[str] = None,
445    ) -> plt.Figure:
446        """
447        Visualize decoder weights for specific features.
448
449        This shows how each feature maps back to the MLP activation space.
450
451        Args:
452            feature_indices: Specific feature indices to plot
453            n_features: Number of features to plot (if feature_indices is None)
454            save_path: Optional path to save the figure
455
456        Returns:
457            The matplotlib Figure object
458        """
459        import torch
460
461        if feature_indices is None:
462            # Get features by density
463            densities = (self.features > 0).mean(axis=0)
464            feature_indices = np.argsort(densities)[-n_features:][::-1]
465
466        n_features_plot = len(feature_indices)
467        fig, axes = plt.subplots(1, n_features_plot, figsize=(3 * n_features_plot, 4))
468
469        if n_features_plot == 1:
470            axes = [axes]
471
472        with torch.no_grad():
473            for i, feat_idx in enumerate(feature_indices):
474                decoder_weight = self.sae.decoder.weight[:, feat_idx].cpu().numpy()
475
476                axes[i].hist(decoder_weight, bins=50, edgecolor="black", alpha=0.7)
477                axes[i].set_title(f"Feature {feat_idx}", fontsize=12)
478                axes[i].set_xlabel("Weight Value", fontsize=10)
479                axes[i].set_ylabel("Count", fontsize=10)
480                axes[i].grid(True, alpha=0.3)
481
482        plt.suptitle("Decoder Weight Distributions", fontsize=14)
483        plt.tight_layout()
484
485        if save_path or self.output_dir:
486            path = save_path or self.output_dir / "decoder_weights.png"
487            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
488            logger.info(f"Saved decoder weights plot to {path}")
489
490        return fig
491
492    def create_feature_dashboard(
493        self,
494        feature_idx: int,
495        save_path: Optional[str] = None,
496    ) -> plt.Figure:
497        """
498        Create a comprehensive dashboard for a single feature.
499
500        Includes activation histogram, top examples, and decoder weights.
501
502        Args:
503            feature_idx: Index of the feature to visualize
504            save_path: Optional path to save the figure
505
506        Returns:
507            The matplotlib Figure object
508        """
509        fig = plt.figure(figsize=(16, 10))
510        gs = GridSpec(2, 2, figure=fig, hspace=0.3, wspace=0.3)
511
512        # 1. Activation histogram
513        ax1 = fig.add_subplot(gs[0, 0])
514        feature_acts = self.features[:, feature_idx]
515        ax1.hist(feature_acts[feature_acts > 0], bins=50, edgecolor="black", alpha=0.7)
516        ax1.set_xlabel("Activation Value", fontsize=11)
517        ax1.set_ylabel("Count", fontsize=11)
518        ax1.set_title(f"Activation Distribution (Feature {feature_idx})", fontsize=12)
519        ax1.set_yscale("log")
520
521        # 2. Top activating examples
522        ax2 = fig.add_subplot(gs[0, 1])
523        top_values, top_indices = self.sae.get_top_activating_examples(
524            self.activations, feature_idx, 10
525        )
526        ax2.barh(range(10), top_values[::-1])
527        ax2.set_xlabel("Activation Value", fontsize=11)
528        ax2.set_title("Top 10 Activating Examples", fontsize=12)
529        ax2.set_yticks(range(10))
530        ax2.grid(True, axis="x", alpha=0.3)
531
532        # 3. Decoder weights
533        ax3 = fig.add_subplot(gs[1, :])
534        import torch
535
536        with torch.no_grad():
537            decoder_weight = self.sae.decoder.weight[:, feature_idx].cpu().numpy()
538
539        ax3.hist(decoder_weight, bins=100, edgecolor="black", alpha=0.7)
540        ax3.set_xlabel("Weight Value", fontsize=11)
541        ax3.set_ylabel("Count", fontsize=11)
542        ax3.set_title("Decoder Weight Distribution", fontsize=12)
543        ax3.grid(True, alpha=0.3)
544
545        # Add statistics box
546        density = (feature_acts > 0).mean()
547        max_act = feature_acts.max()
548
549        stats_text = f"Feature {feature_idx} Statistics:\n"
550        stats_text += f"Density: {density:.2e}\n"
551        stats_text += f"Max Activation: {max_act:.4f}\n"
552        stats_text += f"Decoder Norm: {np.linalg.norm(decoder_weight):.4f}"
553
554        fig.text(
555            0.5,
556            0.02,
557            stats_text,
558            ha="center",
559            fontsize=12,
560            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
561        )
562
563        plt.suptitle(f"Feature {feature_idx} Dashboard", fontsize=16)
564
565        if save_path or self.output_dir:
566            path = save_path or self.output_dir / f"feature_{feature_idx}_dashboard.png"
567            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
568            logger.info(f"Saved feature dashboard to {path}")
569
570        return fig
571
572    def save_all(self, n_features: int = 10) -> None:
573        """
574        Generate and save all standard visualizations.
575
576        Args:
577            n_features: Number of features to include in detailed plots
578        """
579        logger.info("Generating all visualizations...")
580
581        self.plot_feature_density()
582        self.plot_training_curves()
583
584        for metric in ["density", "max_activation"]:
585            self.plot_top_features(n_features=n_features, by=metric)
586
587        # Create dashboards for top features by density
588        densities = (self.features > 0).mean(axis=0)
589        top_features = np.argsort(densities)[-n_features:]
590
591        for feat_idx in top_features:
592            self.create_feature_dashboard(feat_idx)
593
594        logger.info(f"All visualizations saved to {self.output_dir}")

Visualize sparse autoencoder features and activations.

This class provides methods to create various plots that help understand the learned features, similar to the visualizations in the "Towards Monosemanticity" paper.

Example:
visualizer = FeatureVisualizer(sae, activations, metadata)
visualizer.plot_feature_density()
visualizer.plot_top_features(n_features=10)
visualizer.plot_feature_examples(feature_idx=0)
FeatureVisualizer( sae: SparseAutoencoder, activations: numpy.ndarray, metadata: Optional[Dict[str, Any]] = None, output_dir: Union[str, pathlib.Path] = './visualizations', style: str = 'whitegrid', dpi: int = 150, wandb_config: Optional[WandbConfig] = None, log_to_wandb: bool = True)
58    def __init__(
59        self,
60        sae: SparseAutoencoder,
61        activations: np.ndarray,
62        metadata: Optional[Dict[str, Any]] = None,
63        output_dir: Union[str, Path] = "./visualizations",
64        style: str = "whitegrid",
65        dpi: int = 150,
66        wandb_config: Optional[WandbConfig] = None,
67        log_to_wandb: bool = True,
68    ):
69        """
70        Initialize the visualizer.
71
72        Args:
73            sae: Trained sparse autoencoder
74            activations: The MLP activations used to train the SAE
75            metadata: Optional metadata dictionary (e.g., from ActivationExtractor)
76            output_dir: Directory to save plots
77            style: Seaborn style to use
78            dpi: DPI for saved figures
79            wandb_config: Optional WandbConfig for logging visualizations
80            log_to_wandb: If True, log plots to wandb when wandb_config is provided
81        """
82        self.sae = sae
83        self.activations = activations
84        self.metadata = metadata or {}
85        self.output_dir = Path(output_dir)
86        self.dpi = dpi
87        self.wandb_config = wandb_config
88        self.log_to_wandb = log_to_wandb
89
90        # Set style
91        sns.set_style(style)
92        sns.set_palette("husl")
93
94        # Create output directory
95        self.output_dir.mkdir(parents=True, exist_ok=True)
96
97        # Compute feature activations
98        self._compute_features()

Initialize the visualizer.

Arguments:
  • sae: Trained sparse autoencoder
  • activations: The MLP activations used to train the SAE
  • metadata: Optional metadata dictionary (e.g., from ActivationExtractor)
  • output_dir: Directory to save plots
  • style: Seaborn style to use
  • dpi: DPI for saved figures
  • wandb_config: Optional WandbConfig for logging visualizations
  • log_to_wandb: If True, log plots to wandb when wandb_config is provided
sae
activations
metadata
output_dir
dpi
wandb_config
log_to_wandb
def plot_feature_density( self, bins: int = 50, log_scale: bool = True, save_path: Optional[str] = None) -> matplotlib.figure.Figure:
135    def plot_feature_density(
136        self,
137        bins: int = 50,
138        log_scale: bool = True,
139        save_path: Optional[str] = None,
140    ) -> plt.Figure:
141        """
142        Plot the distribution of feature densities.
143
144        Feature density is the fraction of examples on which a feature fires.
145        This histogram helps understand the sparsity distribution.
146
147        Args:
148            bins: Number of histogram bins
149            log_scale: Whether to use log scale on x-axis
150            save_path: Optional path to save the figure
151
152        Returns:
153            The matplotlib Figure object
154        """
155        # Compute feature densities
156        densities = (self.features > 0).mean(axis=0)
157
158        fig, ax = plt.subplots(figsize=(10, 6))
159
160        # Remove zero-density features (dead neurons)
161        active_densities = densities[densities > 0]
162
163        # Plot histogram on log scale
164        if log_scale:
165            bins_log = np.logspace(-8, 0, bins)
166            ax.hist(active_densities, bins=bins_log, edgecolor="black", alpha=0.7)
167            ax.set_xscale("log")
168        else:
169            ax.hist(active_densities, bins=bins, edgecolor="black", alpha=0.7)
170
171        ax.set_xlabel("Feature Density (fraction of examples)", fontsize=12)
172        ax.set_ylabel("Number of Features", fontsize=12)
173        ax.set_title("Distribution of Feature Densities", fontsize=14)
174
175        # Add statistics
176        n_dead = (densities == 0).sum()
177        n_active = len(active_densities)
178
179        stats_text = f"Dead features: {n_dead} ({n_dead / len(densities) * 100:.1f}%)\n"
180        stats_text += (
181            f"Active features: {n_active} ({n_active / len(densities) * 100:.1f}%)\n"
182        )
183        stats_text += f"Median density: {np.median(active_densities):.2e}"
184
185        ax.text(
186            0.95,
187            0.95,
188            stats_text,
189            transform=ax.transAxes,
190            verticalalignment="top",
191            horizontalalignment="right",
192            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
193        )
194
195        plt.tight_layout()
196
197        if save_path or self.output_dir:
198            path = save_path or self.output_dir / "feature_density.png"
199            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
200            logger.info(f"Saved feature density plot to {path}")
201
202        # Log to wandb
203        self._log_figure_to_wandb(fig, "feature_density")
204
205        return fig

Plot the distribution of feature densities.

Feature density is the fraction of examples on which a feature fires. This histogram helps understand the sparsity distribution.

Arguments:
  • bins: Number of histogram bins
  • log_scale: Whether to use log scale on x-axis
  • save_path: Optional path to save the figure
Returns:

The matplotlib Figure object
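The log-scale branch above builds its bin edges with `np.logspace(-8, 0, bins)`, spanning densities from 1e-8 up to 1. A quick numpy check of that binning on made-up densities:

```python
import numpy as np

bins = 50
bin_edges = np.logspace(-8, 0, bins)  # 50 edges -> 49 log-spaced bins

densities = np.array([1e-7, 1e-6, 5e-3, 5e-3, 0.2])
counts, edges = np.histogram(densities, bins=bin_edges)
# every density lies inside [1e-8, 1], so the counts sum to 5
```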

def plot_activation_histogram( self, feature_idx: int, bins: int = 100, log_y: bool = True, save_path: Optional[str] = None) -> matplotlib.figure.Figure:
207    def plot_activation_histogram(
208        self,
209        feature_idx: int,
210        bins: int = 100,
211        log_y: bool = True,
212        save_path: Optional[str] = None,
213    ) -> plt.Figure:
214        """
215        Plot activation histogram for a specific feature.
216
217        Args:
218            feature_idx: Index of the feature to visualize
219            bins: Number of histogram bins
220            log_y: Whether to use log scale on y-axis
221            save_path: Optional path to save the figure
222
223        Returns:
224            The matplotlib Figure object
225        """
226        feature_acts = self.features[:, feature_idx]
227
228        fig, ax = plt.subplots(figsize=(10, 6))
229
230        # Plot histogram
231        counts, bin_edges, patches = ax.hist(
232            feature_acts[feature_acts > 0],  # Only non-zero activations
233            bins=bins,
234            edgecolor="black",
235            alpha=0.7,
236        )
237
238        if log_y:
239            ax.set_yscale("log")
240
241        ax.set_xlabel("Activation Value", fontsize=12)
242        ax.set_ylabel("Count (log scale)" if log_y else "Count", fontsize=12)
243        ax.set_title(f"Activation Histogram for Feature {feature_idx}", fontsize=14)
244
245        # Add statistics
246        density = (feature_acts > 0).mean()
247        max_act = feature_acts.max()
248        mean_act = feature_acts[feature_acts > 0].mean() if density > 0 else 0
249
250        stats_text = f"Density: {density:.2e}\n"
251        stats_text += f"Max activation: {max_act:.4f}\n"
252        stats_text += f"Mean (active): {mean_act:.4f}"
253
254        ax.text(
255            0.95,
256            0.95,
257            stats_text,
258            transform=ax.transAxes,
259            verticalalignment="top",
260            horizontalalignment="right",
261            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.5),
262        )
263
264        plt.tight_layout()
265
266        if save_path or self.output_dir:
267            path = (
268                save_path
269                or self.output_dir / f"activation_histogram_feature_{feature_idx}.png"
270            )
271            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
272            logger.info(f"Saved activation histogram to {path}")
273
274        return fig

Plot activation histogram for a specific feature.

Arguments:
  • feature_idx: Index of the feature to visualize
  • bins: Number of histogram bins
  • log_y: Whether to use log scale on y-axis
  • save_path: Optional path to save the figure
Returns:
  The matplotlib Figure object

def plot_training_curves(self, save_path: Optional[str] = None) -> Optional[matplotlib.figure.Figure]:
276    def plot_training_curves(
277        self,
278        save_path: Optional[str] = None,
279    ) -> Optional[plt.Figure]:
280        """
281        Plot training loss and L0 norm curves.
282
283        Args:
284            save_path: Optional path to save the figure
285
286        Returns:
287            The matplotlib Figure object, or None if no training history exists
288        """
289        if not self.sae.training_losses:
290            logger.warning("No training data available")
291            return None
292
293        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
294
295        epochs = range(1, len(self.sae.training_losses) + 1)
296
297        # Loss curve
298        ax1.plot(epochs, self.sae.training_losses, "b-", linewidth=2)
299        ax1.set_xlabel("Epoch", fontsize=12)
300        ax1.set_ylabel("Loss", fontsize=12)
301        ax1.set_title("Training Loss", fontsize=14)
302        ax1.grid(True, alpha=0.3)
303
304        # L0 norm curve
305        ax2.plot(epochs, self.sae.training_l0_norms, "r-", linewidth=2)
306        ax2.set_xlabel("Epoch", fontsize=12)
307        ax2.set_ylabel("L0 Norm (avg # active features)", fontsize=12)
308        ax2.set_title("Sparsity (L0 Norm)", fontsize=14)
309        ax2.grid(True, alpha=0.3)
310
311        plt.tight_layout()
312
313        if save_path or self.output_dir:
314            path = save_path or self.output_dir / "training_curves.png"
315            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
316            logger.info(f"Saved training curves to {path}")
317
318        return fig

Plot training loss and L0 norm curves.

Arguments:
  • save_path: Optional path to save the figure
Returns:
  The matplotlib Figure object, or None if no training history is available
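When these plots are generated in a script or CI job without a display, matplotlib's Agg backend renders them headlessly. A minimal sketch with toy history values standing in for `sae.training_losses` and `sae.training_l0_norms` (the file name is hypothetical):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend: no display required
import matplotlib.pyplot as plt

# Toy training history (stand-ins for the SAE's recorded metrics).
losses = [1.00, 0.62, 0.41, 0.33, 0.29]
l0_norms = [130.0, 88.0, 61.0, 47.0, 40.0]
epochs = range(1, len(losses) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
ax1.plot(epochs, losses, "b-", linewidth=2)
ax1.set_title("Training Loss")
ax2.plot(epochs, l0_norms, "r-", linewidth=2)
ax2.set_title("Sparsity (L0 Norm)")
plt.tight_layout()
fig.savefig("training_curves.png", dpi=150, bbox_inches="tight")
```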

def plot_top_features( self, n_features: int = 10, by: str = 'density', save_path: Optional[str] = None) -> matplotlib.figure.Figure:
320    def plot_top_features(
321        self,
322        n_features: int = 10,
323        by: str = "density",
324        save_path: Optional[str] = None,
325    ) -> plt.Figure:
326        """
327        Plot top N features by various metrics.
328
329        Args:
330            n_features: Number of top features to show
331            by: Metric to sort by ("density", "max_activation", "mean_activation")
332            save_path: Optional path to save the figure
333
334        Returns:
335            The matplotlib Figure object
336        """
337        fig, ax = plt.subplots(figsize=(12, 6))
338
339        # Compute metric for each feature
340        if by == "density":
341            values = (self.features > 0).mean(axis=0)
342            ylabel = "Feature Density"
343            title = f"Top {n_features} Features by Density"
344        elif by == "max_activation":
345            values = self.features.max(axis=0)
346            ylabel = "Max Activation"
347            title = f"Top {n_features} Features by Max Activation"
348        elif by == "mean_activation":
349            values = self.features.mean(axis=0)
350            ylabel = "Mean Activation"
351            title = f"Top {n_features} Features by Mean Activation"
352        else:
353            raise ValueError(f"Unknown metric: {by}")
354
355        # Get top features
356        top_indices = np.argsort(values)[-n_features:][::-1]
357        top_values = values[top_indices]
358
359        # Plot bar chart
360        _ = ax.barh(range(n_features), top_values[::-1])
361        ax.set_yticks(range(n_features))
362        ax.set_yticklabels([f"Feature {i}" for i in top_indices[::-1]])
363        ax.set_xlabel(ylabel, fontsize=12)
364        ax.set_title(title, fontsize=14)
365        ax.grid(True, axis="x", alpha=0.3)
366
367        plt.tight_layout()
368
369        if save_path or self.output_dir:
370            path = save_path or self.output_dir / f"top_features_{by}.png"
371            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
372            logger.info(f"Saved top features plot to {path}")
373
374        return fig

Plot top N features by various metrics.

Arguments:
  • n_features: Number of top features to show
  • by: Metric to sort by ("density", "max_activation", "mean_activation")
  • save_path: Optional path to save the figure
Returns:
  The matplotlib Figure object
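The top-k selection relies on the `np.argsort(values)[-n:][::-1]` idiom: `argsort` sorts ascending, so the last `n` indices belong to the largest values, and the final reversal puts them in descending order. A tiny worked example:

```python
import numpy as np

values = np.array([0.01, 0.30, 0.05, 0.22, 0.90])
n_features = 3

# Indices of the 3 largest values, largest first.
top_indices = np.argsort(values)[-n_features:][::-1]
print(top_indices)  # → [4 1 3]
print(values[top_indices])
```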

def plot_feature_examples( self, feature_idx: int, k: int = 10, show_text: bool = True, max_text_length: int = 200, save_path: Optional[str] = None) -> matplotlib.figure.Figure:
376    def plot_feature_examples(
377        self,
378        feature_idx: int,
379        k: int = 10,
380        show_text: bool = True,
381        max_text_length: int = 200,
382        save_path: Optional[str] = None,
383    ) -> plt.Figure:
384        """
385        Plot top k activating examples for a specific feature.
386
387        Args:
388            feature_idx: Index of the feature to visualize
389            k: Number of examples to show
390            show_text: Whether to show the text examples
391            max_text_length: Maximum length of text to display
392            save_path: Optional path to save the figure
393
394        Returns:
395            The matplotlib Figure object
396        """
397        # Get top activating examples
398        top_values, top_indices = self.sae.get_top_activating_examples(
399            self.activations, feature_idx, k
400        )
401
402        fig, ax = plt.subplots(figsize=(12, k * 0.5))
403
404        # Plot bar chart
405        _ = ax.barh(range(k), top_values[::-1])
406        ax.set_yticks(range(k))
407        ax.set_xlabel("Activation Value", fontsize=12)
408        ax.set_title(
409            f"Top {k} Activating Examples for Feature {feature_idx}", fontsize=14
410        )
411        ax.grid(True, axis="x", alpha=0.3)
412
413        # Add text examples if available
414        if show_text and "samples_metadata" in self.metadata:
415            samples_metadata = self.metadata["samples_metadata"]
416
417            labels = []
418            for i, idx in enumerate(top_indices[::-1]):
419                if idx < len(samples_metadata):
420                    text = samples_metadata[idx].get("text", "")
421                    # Clean up text
422                    text = text.replace("\n", " ").strip()
423                    if len(text) > max_text_length:
424                        text = text[:max_text_length] + "..."
425                    labels.append(f"{i + 1}. {text}")
426                else:
427                    labels.append(f"{i + 1}. (No metadata)")
428
429            ax.set_yticklabels(labels, fontsize=9)
430
431        plt.tight_layout()
432
433        if save_path or self.output_dir:
434            path = save_path or self.output_dir / f"feature_{feature_idx}_examples.png"
435            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
436            logger.info(f"Saved feature examples to {path}")
437
438        return fig

Plot top k activating examples for a specific feature.

Arguments:
  • feature_idx: Index of the feature to visualize
  • k: Number of examples to show
  • show_text: Whether to show the text examples
  • max_text_length: Maximum length of text to display
  • save_path: Optional path to save the figure
Returns:
  The matplotlib Figure object
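`sae.get_top_activating_examples` is not shown in this section; judging from the call site, it is conceptually a `torch.topk` over one feature's activation column. A minimal stand-in (the toy matrix and the method's exact behavior are assumptions):

```python
import torch

# Toy feature activations: 4 examples x 2 features.
features = torch.tensor([[0.0, 0.2],
                         [0.9, 0.0],
                         [0.1, 0.7],
                         [0.4, 0.0]])

feature_idx, k = 0, 2
# Top-k activating examples for one feature = topk over its column.
top_values, top_indices = torch.topk(features[:, feature_idx], k)
print(top_values)   # tensor([0.9000, 0.4000])
print(top_indices)  # tensor([1, 3])
```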

def plot_decoder_weights( self, feature_indices: Optional[List[int]] = None, n_features: int = 10, save_path: Optional[str] = None) -> matplotlib.figure.Figure:
440    def plot_decoder_weights(
441        self,
442        feature_indices: Optional[List[int]] = None,
443        n_features: int = 10,
444        save_path: Optional[str] = None,
445    ) -> plt.Figure:
446        """
447        Visualize decoder weights for specific features.
448
449        This shows how each feature maps back to the MLP activation space.
450
451        Args:
452            feature_indices: Specific feature indices to plot
453            n_features: Number of features to plot (if feature_indices is None)
454            save_path: Optional path to save the figure
455
456        Returns:
457            The matplotlib Figure object
458        """
459        import torch
460
461        if feature_indices is None:
462            # Get features by density
463            densities = (self.features > 0).mean(axis=0)
464            feature_indices = np.argsort(densities)[-n_features:][::-1]
465
466        n_features_plot = len(feature_indices)
467        fig, axes = plt.subplots(1, n_features_plot, figsize=(3 * n_features_plot, 4))
468
469        if n_features_plot == 1:
470            axes = [axes]
471
472        with torch.no_grad():
473            for i, feat_idx in enumerate(feature_indices):
474                decoder_weight = self.sae.decoder.weight[:, feat_idx].cpu().numpy()
475
476                axes[i].hist(decoder_weight, bins=50, edgecolor="black", alpha=0.7)
477                axes[i].set_title(f"Feature {feat_idx}", fontsize=12)
478                axes[i].set_xlabel("Weight Value", fontsize=10)
479                axes[i].set_ylabel("Count", fontsize=10)
480                axes[i].grid(True, alpha=0.3)
481
482        plt.suptitle("Decoder Weight Distributions", fontsize=14)
483        plt.tight_layout()
484
485        if save_path or self.output_dir:
486            path = save_path or self.output_dir / "decoder_weights.png"
487            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
488            logger.info(f"Saved decoder weights plot to {path}")
489
490        return fig

Visualize decoder weights for specific features.

This shows how each feature maps back to the MLP activation space.

Arguments:
  • feature_indices: Specific feature indices to plot
  • n_features: Number of features to plot (if feature_indices is None)
  • save_path: Optional path to save the figure
Returns:
  The matplotlib Figure object
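Each feature's decoder column is its direction in MLP activation space: the vector whose histogram this method draws and whose norm the dashboard reports. A toy illustration (dimensions are illustrative; note that `nn.Linear` stores `weight` as `(out_features, in_features)`, so features index the columns):

```python
import torch

# Toy decoder mirroring an overcomplete SAE: hidden_dim=512 -> activation_dim=64.
decoder = torch.nn.Linear(512, 64, bias=False)

feat_idx = 7
with torch.no_grad():
    # Column feat_idx of the (64, 512) weight matrix is the direction this
    # feature writes back into MLP activation space.
    direction = decoder.weight[:, feat_idx].clone()

print(direction.shape)  # torch.Size([64])
print(float(direction.norm()))
```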

def create_feature_dashboard( self, feature_idx: int, save_path: Optional[str] = None) -> matplotlib.figure.Figure:
492    def create_feature_dashboard(
493        self,
494        feature_idx: int,
495        save_path: Optional[str] = None,
496    ) -> plt.Figure:
497        """
498        Create a comprehensive dashboard for a single feature.
499
500        Includes activation histogram, top examples, and decoder weights.
501
502        Args:
503            feature_idx: Index of the feature to visualize
504            save_path: Optional path to save the figure
505
506        Returns:
507            The matplotlib Figure object
508        """
509        fig = plt.figure(figsize=(16, 10))
510        gs = GridSpec(2, 2, figure=fig, hspace=0.3, wspace=0.3)
511
512        # 1. Activation histogram
513        ax1 = fig.add_subplot(gs[0, 0])
514        feature_acts = self.features[:, feature_idx]
515        ax1.hist(feature_acts[feature_acts > 0], bins=50, edgecolor="black", alpha=0.7)
516        ax1.set_xlabel("Activation Value", fontsize=11)
517        ax1.set_ylabel("Count", fontsize=11)
518        ax1.set_title(f"Activation Distribution (Feature {feature_idx})", fontsize=12)
519        ax1.set_yscale("log")
520
521        # 2. Top activating examples
522        ax2 = fig.add_subplot(gs[0, 1])
523        top_values, top_indices = self.sae.get_top_activating_examples(
524            self.activations, feature_idx, 10
525        )
526        ax2.barh(range(10), top_values[::-1])
527        ax2.set_xlabel("Activation Value", fontsize=11)
528        ax2.set_title("Top 10 Activating Examples", fontsize=12)
529        ax2.set_yticks(range(10))
530        ax2.grid(True, axis="x", alpha=0.3)
531
532        # 3. Decoder weights
533        ax3 = fig.add_subplot(gs[1, :])
534        import torch
535
536        with torch.no_grad():
537            decoder_weight = self.sae.decoder.weight[:, feature_idx].cpu().numpy()
538
539        ax3.hist(decoder_weight, bins=100, edgecolor="black", alpha=0.7)
540        ax3.set_xlabel("Weight Value", fontsize=11)
541        ax3.set_ylabel("Count", fontsize=11)
542        ax3.set_title("Decoder Weight Distribution", fontsize=12)
543        ax3.grid(True, alpha=0.3)
544
545        # Add statistics box
546        density = (feature_acts > 0).mean()
547        max_act = feature_acts.max()
548
549        stats_text = f"Feature {feature_idx} Statistics:\n"
550        stats_text += f"Density: {density:.2e}\n"
551        stats_text += f"Max Activation: {max_act:.4f}\n"
552        stats_text += f"Decoder Norm: {np.linalg.norm(decoder_weight):.4f}"
553
554        fig.text(
555            0.5,
556            0.02,
557            stats_text,
558            ha="center",
559            fontsize=12,
560            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
561        )
562
563        plt.suptitle(f"Feature {feature_idx} Dashboard", fontsize=16)
564
565        if save_path or self.output_dir:
566            path = save_path or self.output_dir / f"feature_{feature_idx}_dashboard.png"
567            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
568            logger.info(f"Saved feature dashboard to {path}")
569
570        return fig

Create a comprehensive dashboard for a single feature.

Includes activation histogram, top examples, and decoder weights.

Arguments:
  • feature_idx: Index of the feature to visualize
  • save_path: Optional path to save the figure
Returns:
  The matplotlib Figure object

def save_all(self, n_features: int = 10) -> None:
572    def save_all(self, n_features: int = 10) -> None:
573        """
574        Generate and save all standard visualizations.
575
576        Args:
577            n_features: Number of features to include in detailed plots
578        """
579        logger.info("Generating all visualizations...")
580
581        self.plot_feature_density()
582        self.plot_training_curves()
583
584        for metric in ["density", "max_activation"]:
585            self.plot_top_features(n_features=n_features, by=metric)
586
587        # Create dashboards for top features by density
588        densities = (self.features > 0).mean(axis=0)
589        top_features = np.argsort(densities)[-n_features:]
590
591        for feat_idx in top_features:
592            self.create_feature_dashboard(feat_idx)
593
594        logger.info(f"All visualizations saved to {self.output_dir}")

Generate and save all standard visualizations.

Arguments:
  • n_features: Number of features to include in detailed plots
class SAESteering:
135class SAESteering:
136    """
137    Steer language model generation using trained SAE features.
138
139    Wraps a HuggingFace language model (via nnsight) and applies SAE
140    feature interventions during generation.  At each generation step,
141    a forward hook on the target MLP module adds a scaled decoder
142    weight vector, biasing the residual stream toward the chosen
143    feature direction.
144
145    The HuggingFace auth token is read automatically from
146    ``HUGGINGFACE_HUB_TOKEN`` (via :func:`~drrik.settings.get_settings`)
147    when the *token* parameter is not explicitly provided.
148
149    Attributes:
150        sae: The trained :class:`~drrik.autoencoder.SparseAutoencoder`
151            providing feature directions.
152        model: The nnsight :class:`~nnsight.LanguageModel` wrapper.
153        tokenizer: The HuggingFace tokenizer for the model.
154        target_layer: The layer index where interventions are applied.
155        device: The device the SAE parameters live on.
156
157    Example:
158        ```python
159        sae = SparseAutoencoder.load("sae_model.pt")
160        steering = SAESteering(sae, model_name="google/gemma-2b", layer=5)
161
162        # Steered generation
163        result = steering.generate(
164            "The sky is",
165            feature_idx=128,
166            strength=2.5,
167            max_new_tokens=50,
168        )
169
170        # Baseline (no steering)
171        baseline = steering.generate("The sky is", max_new_tokens=50)
172
173        # Multi-feature steering
174        result = steering.generate(
175            "The sky is",
176            feature_indices=[10, 128],
177            strengths=[1.0, 2.5],
178        )
179        ```
180    """
181
182    def __init__(
183        self,
184        sae: SparseAutoencoder,
185        model_name: str,
186        layer: int,
187        revision: str = "main",
188        torch_dtype: str = "float16",
189        device_map: str = "auto",
190        trust_remote_code: bool = True,
191        token: Optional[str] = None,
192    ):
193        """
194        Initialize the SAE steering controller.
195
196        If ``token`` is not provided, it is read from the environment via
197        ``get_settings()`` (i.e. ``HUGGINGFACE_HUB_TOKEN`` in ``.env``).
198
199        Args:
200            sae: A trained SparseAutoencoder whose decoder weights provide
201                 steering directions.
202            model_name: HuggingFace model identifier for generation (e.g.,
203                       ``"google/gemma-2b"``).
204            layer: The transformer layer index where MLP activations are
205                   intercepted and modified.
206            revision: Model revision to load.
207            torch_dtype: Weight dtype for model loading.
208            device_map: Device mapping strategy.
209            trust_remote_code: Whether to trust remote code from the repo.
210            token: HuggingFace token for gated models. Falls back to
211                   ``HUGGINGFACE_HUB_TOKEN`` from ``.env`` when ``None``.
212        """
213        if token is None:
214            settings = get_settings()
215            token = settings.huggingface_hub_token
216
217        self.sae = sae
218        self.target_layer = layer
219        self.device = next(sae.parameters()).device
220
221        logger.info(f"Loading model '{model_name}' for steering at layer {layer}")
222
223        dtype_map = {
224            "float16": torch.float16,
225            "bfloat16": torch.bfloat16,
226            "float32": torch.float32,
227        }
228        load_dtype = dtype_map.get(torch_dtype, torch.float16)
229
230        model_kwargs = {
231            "revision": revision,
232            "torch_dtype": load_dtype,
233            "device_map": device_map,
234            "trust_remote_code": trust_remote_code,
235        }
236        if token:
237            model_kwargs["token"] = token
238
239        self.model = LanguageModel(model_name, **model_kwargs)
240        self.tokenizer = AutoTokenizer.from_pretrained(
241            model_name,
242            revision=revision,
243            trust_remote_code=trust_remote_code,
244            token=token,
245        )
246
247        if self.tokenizer.pad_token is None:
248            self.tokenizer.pad_token = self.tokenizer.eos_token
249
250        logger.info(f"Model loaded on {self.model.device}")
251
252    def _tokenize(self, text: str) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
253        """Tokenize text and move tensors to the model device.
254
255        Centralizes tokenization so that :meth:`generate` and
256        :meth:`_generate_baseline` don't duplicate this work.
257
258        Args:
259            text: Raw prompt string.
260
261        Returns:
262            A tuple of ``(input_ids, attention_mask)`` tensors on
263            the model's device.  ``attention_mask`` may be ``None``
264            if the tokenizer does not produce one.
265        """
266        inputs = self.tokenizer(
267            text, return_tensors="pt", padding=True, truncation=True
268        )
269        input_ids = inputs["input_ids"].to(self.model.device)
270        attention_mask = inputs.get("attention_mask", None)
271        if attention_mask is not None:
272            attention_mask = attention_mask.to(self.model.device)
273        return input_ids, attention_mask
274
275    def _generate_with_hooks(
276        self,
277        input_ids: torch.Tensor,
278        attention_mask: Optional[torch.Tensor],
279        steering_direction: torch.Tensor,
280        strength: float,
281        max_new_tokens: int,
282        temperature: float,
283        top_p: float,
284    ) -> str:
285        """
286        Generate text using forward-hook-based MLP intervention.
287
288        Registers a forward hook on the target MLP module that adds
289        ``strength * steering_direction`` to the output at every
290        forward pass.  The hook is removed in a ``finally`` block to
291        guarantee cleanup even on error.
292
293        Generation is performed token-by-token using the underlying
294        HuggingFace model (``self.model._module``) with top-p
295        (nucleus) sampling.
296
297        Args:
298            input_ids: Tokenized input tensor of shape ``(1, seq_len)``.
299            attention_mask: Optional attention mask tensor.
300            steering_direction: The combined SAE decoder weight vector
301                of shape ``(activation_dim,)``.
302            strength: Multiplicative steering magnitude (typically ``1.0``
303                because strengths are baked into *steering_direction* by
304                the caller).
305            max_new_tokens: Maximum tokens to generate.
306            temperature: Sampling temperature.
307            top_p: Nucleus sampling threshold.
308
309        Returns:
310            The decoded text string (special tokens stripped).
311        """
312        generated_tokens = input_ids.clone()
313        steering_direction = steering_direction.to(generated_tokens.device)
314        target_module = resolve_module_path(
315            self.model, f"model.layers[{self.target_layer}].mlp"
316        )
317        hook_handle = None
318
319        try:
320
321            def steering_hook(module, input_args, output):
322                if isinstance(output, torch.Tensor):
323                    output = output + strength * steering_direction
324                elif isinstance(output, tuple):
325                    output = tuple(
326                        o + strength * steering_direction
327                        if isinstance(o, torch.Tensor)
328                        else o
329                        for o in output
330                    )
331                return output
332
333            hook_handle = target_module.register_forward_hook(steering_hook)
334
335            base_model = self.model._module
336            device = generated_tokens.device
337
338            for _ in tqdm(range(max_new_tokens), desc="Steering generation"):
339                if generated_tokens.shape[-1] >= self.tokenizer.model_max_length:
340                    break
341
342                with torch.no_grad():  # inference only, as in _generate_baseline
343                    outputs = base_model(
344                        input_ids=generated_tokens,
345                        attention_mask=attention_mask,
346                        use_cache=True)
347
348                logits = outputs.logits[:, -1, :]
349                next_token = _sample_next_token(logits, temperature, top_p)
350
351                if next_token.item() == self.tokenizer.eos_token_id:
352                    break
353
354                generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
355                if attention_mask is not None:
356                    attention_mask = torch.cat(
357                        [attention_mask, torch.ones(1, 1, device=device)], dim=-1
358                    )
359
360        finally:
361            if hook_handle is not None:
362                hook_handle.remove()
363
364        return self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
365
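`_sample_next_token` is defined elsewhere in the module. The sketch below shows one common way to implement temperature plus top-p (nucleus) sampling; the signature matches the call sites above, but the body is an assumption, not the module's actual code:

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
    """Temperature + nucleus (top-p) sampling over last-position logits."""
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens outside the nucleus; the top token is always kept because
    # its preceding cumulative mass is zero.
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice)  # shape (batch, 1)

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
token = sample_next_token(logits, temperature=0.8, top_p=0.9)
print(token.shape)  # torch.Size([1, 1])
```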
366    def generate(
367        self,
368        text: str,
369        feature_idx: Optional[int] = None,
370        strength: float = 1.0,
371        max_new_tokens: int = 50,
372        temperature: float = 0.8,
373        top_p: float = 0.9,
374        feature_indices: Optional[List[int]] = None,
375        strengths: Optional[List[float]] = None,
376    ) -> str:
377        """
378        Generate text with optional SAE feature steering.
379
380        This is the main entry point for generation.
381
382        - **No features** → calls :meth:`_generate_baseline` (plain
383          autoregressive generation).
384        - **Single feature** → pass ``feature_idx`` and ``strength``.
385        - **Multiple features** → pass ``feature_indices`` and
386          ``strengths`` as equal-length lists; the weighted decoder
387          directions are summed into a single combined vector.
388
389        Args:
390            text: The prompt text to generate from.
391            feature_idx: Index of a single SAE feature to steer with.
392            strength: Steering magnitude for a single feature
393                (default ``1.0``).
394            max_new_tokens: Maximum number of tokens to generate.
395            temperature: Sampling temperature.
396            top_p: Nucleus sampling threshold.
397            feature_indices: List of feature indices for multi-feature
398                steering.
399            strengths: List of steering magnitudes, one per entry in
400                ``feature_indices``.  Defaults to ``[1.0, …]`` if
401                ``None``.
402
403        Returns:
404            The generated text string (special tokens stripped).
405
406        Raises:
407            ValueError: If ``feature_indices`` and ``strengths`` have
408                mismatched lengths, or a feature index is out of range.
409
410        Example:
411            ```python
412            # Single feature
413            steering.generate("Hello", feature_idx=42, strength=2.0)
414
415            # Multi-feature
416            steering.generate(
417                "Hello",
418                feature_indices=[10, 42],
419                strengths=[1.0, 3.0],
420            )
421
422            # Baseline
423            steering.generate("Hello", max_new_tokens=100)
424            ```
425        """
426        if feature_idx is None and feature_indices is None:
427            return self._generate_baseline(text, max_new_tokens, temperature, top_p)
428
429        if feature_idx is not None:
430            feature_indices = [feature_idx]
431            strengths = [strength]
432        elif strengths is None:
433            strengths = [1.0] * len(feature_indices)
434
435        if len(feature_indices) != len(strengths):
436            raise ValueError("feature_indices and strengths must have the same length")
437
438        combined_direction = torch.zeros(self.sae.activation_dim, device=self.device)
439        for fid, s in zip(feature_indices, strengths):
440            if fid < 0 or fid >= self.sae.hidden_dim:
441                raise ValueError(
442                    f"feature_idx {fid} out of range [0, {self.sae.hidden_dim})"
443                )
444            combined_direction += s * self.sae.decoder.weight[:, fid]
445
446        input_ids, attention_mask = self._tokenize(text)
447
448        return self._generate_with_hooks(
449            input_ids,
450            attention_mask,
451            combined_direction,
452            1.0,
453            max_new_tokens,
454            temperature,
455            top_p,
456        )
457
458    def _generate_baseline(
459        self,
460        text: str,
461        max_new_tokens: int,
462        temperature: float,
463        top_p: float,
464    ) -> str:
465        """
466        Generate text without any SAE steering (baseline).
467
468        Runs the same token-by-token autoregressive loop as
469        :meth:`_generate_with_hooks` but without registering a
470        steering hook.
471
472        Args:
473            text: The prompt text.
474            max_new_tokens: Maximum tokens to generate.
475            temperature: Sampling temperature.
476            top_p: Nucleus sampling threshold.
477
478        Returns:
479            The generated text string (special tokens stripped).
480        """
481        input_ids, attention_mask = self._tokenize(text)
482
483        base_model = self.model._module
484        generated_tokens = input_ids.clone()
485
486        with torch.no_grad():
487            for _ in tqdm(range(max_new_tokens), desc="Baseline generation"):
488                if generated_tokens.shape[-1] >= self.tokenizer.model_max_length:
489                    break
490
491                outputs = base_model(
492                    input_ids=generated_tokens,
493                    attention_mask=attention_mask,
494                    use_cache=True,
495                )
496
497                logits = outputs.logits[:, -1, :]
498                next_token = _sample_next_token(logits, temperature, top_p)
499
500                if next_token.item() == self.tokenizer.eos_token_id:
501                    break
502
503                generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
504                if attention_mask is not None:
505                    attention_mask = torch.cat(
506                        [
507                            attention_mask,
508                            torch.ones(1, 1, device=generated_tokens.device),
509                        ],
510                        dim=-1,
511                    )
512
513        return self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
514
515    def compare_steering(
516        self,
517        text: str,
518        feature_idx: int,
519        strengths: Optional[List[float]] = None,
520        max_new_tokens: int = 50,
521        temperature: float = 0.8,
522        top_p: float = 0.9,
523    ) -> Dict[str, str]:
524        """
525        Compare baseline generation against steered generations at different strengths.
526
527        Runs one generation per strength value (including ``0.0`` for a
528        no-steering baseline) and returns a dict mapping descriptive
529        labels to the generated text.
530
531        Args:
532            text: The prompt text.
533            feature_idx: SAE feature index to steer with.
534            strengths: List of strength values to test.  Defaults to
535                ``[0.0, 1.0, 2.0, 3.0, 5.0]``.
536            max_new_tokens: Maximum tokens to generate per run.
537            temperature: Sampling temperature.
538            top_p: Nucleus sampling threshold.
539
540        Returns:
541            Dictionary mapping strength labels to generated text.
542            Keys are ``"baseline"`` (for strength ``0.0``) and
543            ``"strength_1.0"``, ``"strength_2.0"``, etc.
544
545        Example:
546            ```python
547            results = steering.compare_steering(
548                "The sky is",
549                feature_idx=42,
550                strengths=[0.0, 1.0, 3.0],
551            )
552            for label, text in results.items():
553                print(f"{label}: {text}")
554            ```
555        """
556        if strengths is None:
557            strengths = [0.0, 1.0, 2.0, 3.0, 5.0]
558
559        results = {}
560
561        for s in strengths:
562            if s == 0.0:
563                label = "baseline"
564            else:
565                label = f"strength_{s}"
566
567            logger.info(f"Generating with {label} (strength={s})")
568            generated = self.generate(
569                text,
570                feature_idx=feature_idx,
571                strength=s,
572                max_new_tokens=max_new_tokens,
573                temperature=temperature,
574                top_p=top_p,
575            )
576            results[label] = generated
577
578        return results
579
580    def find_steering_features(
581        self,
582        text: str,
583        activations: np.ndarray,
584        top_k: int = 20,
585        min_activation: float = 0.0,
586    ) -> List[Dict[str, Any]]:
587        """
588        Find the top-k SAE features that most activate on given activations.
589
590        Encodes the provided activations through the SAE and returns the
591        features with the highest mean activation values, sorted by
592        magnitude.  Useful for identifying which features are relevant
593        to a given input before steering.
594
595        Args:
596            text: The prompt text (used for logging only).
597            activations: MLP activations array of shape
598                ``(n_samples, activation_dim)``.
599            top_k: Number of top features to return.
600            min_activation: Minimum mean activation threshold to
601                consider a feature.
602
603        Returns:
604            List of dicts sorted by descending activation, each
605            containing:
606
607            - ``feature_idx`` (``int``): The feature index.
608            - ``activation`` (``float``): Mean activation value.
609            - ``normalized_activation`` (``float``): Activation
610              divided by the decoder weight L2 norm.
611            - ``decoder_weight_norm`` (``float``): L2 norm of the
612              decoder weight vector.
613        """
614        self.sae.eval()
615        device = self.device
616
617        with torch.no_grad():
618            if isinstance(activations, np.ndarray):
619                acts_tensor = torch.from_numpy(activations).float().to(device)
620            else:
621                acts_tensor = activations.to(device)
622
623            features = self.sae.encode(acts_tensor)
624
625            avg_activations = features.mean(dim=0).cpu().numpy()
626            decoder_norms = self.sae.decoder.weight.norm(dim=0).cpu().numpy()
627
628            active_mask = avg_activations >= min_activation
629            active_indices = np.where(active_mask)[0]
630
631            sorted_indices = active_indices[
632                np.argsort(avg_activations[active_indices])[::-1]
633            ][:top_k]
634
635            results = []
636            for idx in sorted_indices:
637                results.append(
638                    {
639                        "feature_idx": int(idx),
640                        "activation": float(avg_activations[idx]),
641                        "normalized_activation": float(
642                            avg_activations[idx] / (decoder_norms[idx] + 1e-8)
643                        ),
644                        "decoder_weight_norm": float(decoder_norms[idx]),
645                    }
646                )
647
648        logger.info(f"Found {len(results)} active features for text: {text[:100]}")
649        return results
650
651    def get_steering_direction(
652        self,
653        feature_idx: int,
654        normalize: bool = True,
655    ) -> np.ndarray:
656        """
657        Get the steering direction vector for a given feature.
658
659        The steering direction is the SAE decoder weight column for the
660        specified feature.  This vector represents the direction in MLP
661        activation space that the feature encodes.
662
663        Args:
664            feature_idx: Index of the SAE feature.
665            normalize: If ``True``, L2-normalize the direction vector.
666
667        Returns:
668            NumPy array of shape ``(activation_dim,)`` containing the
669            steering direction.
670
671        Raises:
672            ValueError: If ``feature_idx`` is out of range
673                ``[0, sae.hidden_dim)``.
674        """
675        if feature_idx < 0 or feature_idx >= self.sae.hidden_dim:
676            raise ValueError(
677                f"feature_idx {feature_idx} out of range [0, {self.sae.hidden_dim})"
678            )
679
680        direction = self.sae.decoder.weight[:, feature_idx].detach().cpu().numpy()
681
682        if normalize:
683            norm = np.linalg.norm(direction)
684            if norm > 0:
685                direction = direction / norm
686
687        return direction
688
689    def save_steering_analysis(
690        self,
691        results: Dict[str, str],
692        output_dir: Union[str, Path],
693        prompt: str = "",
694        feature_label: str = "",
695    ) -> Path:
696        """
697        Save steering comparison results to disk.
698
699        Writes a text file containing the prompt and all generated
700        outputs (baseline and steered) for later review.  The
701        filename includes a timestamp (and optional feature label)
702        to prevent overwriting previous analyses.
703
704        Args:
705            results: Dictionary mapping labels to generated text
706                (as returned by :meth:`compare_steering`).
707            output_dir: Directory to save the analysis file.
708            prompt: The original prompt text.
709            feature_label: Optional label included in the filename
710                (e.g. ``"feature_42"``).
711
712        Returns:
713            Path to the saved analysis file.
714        """
715        output_dir = Path(output_dir)
716        output_dir.mkdir(parents=True, exist_ok=True)
717
718        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
719        suffix = f"_{feature_label}" if feature_label else ""
720        filepath = output_dir / f"steering_analysis{suffix}_{timestamp}.txt"
721
722        with open(filepath, "w") as f:
723            f.write(f"Prompt: {prompt}\n")
724            f.write("=" * 80 + "\n\n")
725
726            for label, text in results.items():
727                f.write(f"--- {label} ---\n")
728                f.write(f"{text}\n\n")
729                f.write("-" * 40 + "\n")
730
731        logger.info(f"Saved steering analysis to {filepath}")
732        return filepath

Steer language model generation using trained SAE features.

Wraps a HuggingFace language model (via nnsight) and applies SAE feature interventions during generation. At each generation step, a forward hook on the target MLP module adds a scaled decoder weight vector, biasing the residual stream toward the chosen feature direction.

The HuggingFace auth token is read automatically from HUGGINGFACE_HUB_TOKEN (via drrik.settings.get_settings()) when the token parameter is not explicitly provided.

Attributes:
  • sae: The trained drrik.autoencoder.SparseAutoencoder providing feature directions.
  • model: The nnsight LanguageModel wrapper.
  • tokenizer: The HuggingFace tokenizer for the model.
  • target_layer: The layer index where interventions are applied.
  • device: The device the SAE parameters live on.
Example:
sae = SparseAutoencoder.load("sae_model.pt")
steering = SAESteering(sae, model_name="google/gemma-2b", layer=5)

# Steered generation
result = steering.generate(
    "The sky is",
    feature_idx=128,
    strength=2.5,
    max_new_tokens=50,
)

# Baseline (no steering)
baseline = steering.generate("The sky is", max_new_tokens=50)

# Multi-feature steering
result = steering.generate(
    "The sky is",
    feature_indices=[10, 128],
    strengths=[1.0, 2.5],
)
SAESteering( sae: SparseAutoencoder, model_name: str, layer: int, revision: str = 'main', torch_dtype: str = 'float16', device_map: str = 'auto', trust_remote_code: bool = True, token: Optional[str] = None)
182    def __init__(
183        self,
184        sae: SparseAutoencoder,
185        model_name: str,
186        layer: int,
187        revision: str = "main",
188        torch_dtype: str = "float16",
189        device_map: str = "auto",
190        trust_remote_code: bool = True,
191        token: Optional[str] = None,
192    ):
193        """
194        Initialize the SAE steering controller.
195
196        If ``token`` is not provided, it is read from the environment via
197        ``get_settings()`` (i.e. ``HUGGINGFACE_HUB_TOKEN`` in ``.env``).
198
199        Args:
200            sae: A trained SparseAutoencoder whose decoder weights provide
201                 steering directions.
202            model_name: HuggingFace model identifier for generation (e.g.,
203                       ``"google/gemma-2b"``).
204            layer: The transformer layer index where MLP activations are
205                   intercepted and modified.
206            revision: Model revision to load.
207            torch_dtype: Weight dtype for model loading.
208            device_map: Device mapping strategy.
209            trust_remote_code: Whether to trust remote code from the repo.
210            token: HuggingFace token for gated models. Falls back to
211                   ``HUGGINGFACE_HUB_TOKEN`` from ``.env`` when ``None``.
212        """
213        if token is None:
214            settings = get_settings()
215            token = settings.huggingface_hub_token
216
217        self.sae = sae
218        self.target_layer = layer
219        self.device = next(sae.parameters()).device
220
221        logger.info(f"Loading model '{model_name}' for steering at layer {layer}")
222
223        dtype_map = {
224            "float16": torch.float16,
225            "bfloat16": torch.bfloat16,
226            "float32": torch.float32,
227        }
228        load_dtype = dtype_map.get(torch_dtype, torch.float16)  # unknown dtype names fall back to float16
229
230        model_kwargs = {
231            "revision": revision,
232            "torch_dtype": load_dtype,
233            "device_map": device_map,
234            "trust_remote_code": trust_remote_code,
235        }
236        if token:
237            model_kwargs["token"] = token
238
239        self.model = LanguageModel(model_name, **model_kwargs)
240        self.tokenizer = AutoTokenizer.from_pretrained(
241            model_name,
242            revision=revision,
243            trust_remote_code=trust_remote_code,
244            token=token,
245        )
246
247        if self.tokenizer.pad_token is None:
248            self.tokenizer.pad_token = self.tokenizer.eos_token
249
250        logger.info(f"Model loaded on {self.model.device}")

Initialize the SAE steering controller.

If token is not provided, it is read from the environment via get_settings() (i.e. HUGGINGFACE_HUB_TOKEN in .env).

Arguments:
  • sae: A trained SparseAutoencoder whose decoder weights provide steering directions.
  • model_name: HuggingFace model identifier for generation (e.g., "google/gemma-2b").
  • layer: The transformer layer index where MLP activations are intercepted and modified.
  • revision: Model revision to load.
  • torch_dtype: Weight dtype for model loading.
  • device_map: Device mapping strategy.
  • trust_remote_code: Whether to trust remote code from the repo.
  • token: HuggingFace token for gated models. Falls back to HUGGINGFACE_HUB_TOKEN from .env when None.
sae
target_layer
device
model
tokenizer
def generate( self, text: str, feature_idx: Optional[int] = None, strength: float = 1.0, max_new_tokens: int = 50, temperature: float = 0.8, top_p: float = 0.9, feature_indices: Optional[List[int]] = None, strengths: Optional[List[float]] = None) -> str:
366    def generate(
367        self,
368        text: str,
369        feature_idx: Optional[int] = None,
370        strength: float = 1.0,
371        max_new_tokens: int = 50,
372        temperature: float = 0.8,
373        top_p: float = 0.9,
374        feature_indices: Optional[List[int]] = None,
375        strengths: Optional[List[float]] = None,
376    ) -> str:
377        """
378        Generate text with optional SAE feature steering.
379
380        This is the main entry point for generation.
381
382        - **No features** → calls :meth:`_generate_baseline` (plain
383          autoregressive generation).
384        - **Single feature** → pass ``feature_idx`` and ``strength``.
385        - **Multiple features** → pass ``feature_indices`` and
386          ``strengths`` as equal-length lists; the weighted decoder
387          directions are summed into a single combined vector.
388
389        Args:
390            text: The prompt text to generate from.
391            feature_idx: Index of a single SAE feature to steer with.
392            strength: Steering magnitude for a single feature
393                (default ``1.0``).
394            max_new_tokens: Maximum number of tokens to generate.
395            temperature: Sampling temperature.
396            top_p: Nucleus sampling threshold.
397            feature_indices: List of feature indices for multi-feature
398                steering.
399            strengths: List of steering magnitudes, one per entry in
400                ``feature_indices``.  Defaults to ``[1.0, …]`` if
401                ``None``.
402
403        Returns:
404            The generated text string (special tokens stripped).
405
406        Raises:
407            ValueError: If ``feature_indices`` and ``strengths`` have
408                mismatched lengths, or a feature index is out of range.
409
410        Example:
411            ```python
412            # Single feature
413            steering.generate("Hello", feature_idx=42, strength=2.0)
414
415            # Multi-feature
416            steering.generate(
417                "Hello",
418                feature_indices=[10, 42],
419                strengths=[1.0, 3.0],
420            )
421
422            # Baseline
423            steering.generate("Hello", max_new_tokens=100)
424            ```
425        """
426        if feature_idx is None and feature_indices is None:
427            return self._generate_baseline(text, max_new_tokens, temperature, top_p)
428
429        if feature_idx is not None:
430            feature_indices = [feature_idx]
431            strengths = [strength]
432        elif strengths is None:
433            strengths = [1.0] * len(feature_indices)
434
435        if len(feature_indices) != len(strengths):
436            raise ValueError("feature_indices and strengths must have the same length")
437
438        combined_direction = torch.zeros(self.sae.activation_dim, device=self.device)
439        for fid, s in zip(feature_indices, strengths):
440            if fid < 0 or fid >= self.sae.hidden_dim:
441                raise ValueError(
442                    f"feature_idx {fid} out of range [0, {self.sae.hidden_dim})"
443                )
444            combined_direction += s * self.sae.decoder.weight[:, fid]
445
446        input_ids, attention_mask = self._tokenize(text)
447
448        return self._generate_with_hooks(
449            input_ids,
450            attention_mask,
451            combined_direction,
452            1.0,  # strength 1.0: scaling is already folded into combined_direction
453            max_new_tokens,
454            temperature,
455            top_p,
456        )

Generate text with optional SAE feature steering.

This is the main entry point for generation.

  • No features → calls _generate_baseline() (plain autoregressive generation).
  • Single feature → pass feature_idx and strength.
  • Multiple features → pass feature_indices and strengths as equal-length lists; the weighted decoder directions are summed into a single combined vector.
Arguments:
  • text: The prompt text to generate from.
  • feature_idx: Index of a single SAE feature to steer with.
  • strength: Steering magnitude for a single feature (default 1.0).
  • max_new_tokens: Maximum number of tokens to generate.
  • temperature: Sampling temperature.
  • top_p: Nucleus sampling threshold.
  • feature_indices: List of feature indices for multi-feature steering.
  • strengths: List of steering magnitudes, one per entry in feature_indices. Defaults to [1.0, …] if None.
Returns:

The generated text string (special tokens stripped).

Raises:
  • ValueError: If feature_indices and strengths have mismatched lengths, or a feature index is out of range.
Example:
# Single feature
steering.generate("Hello", feature_idx=42, strength=2.0)

# Multi-feature
steering.generate(
    "Hello",
    feature_indices=[10, 42],
    strengths=[1.0, 3.0],
)

# Baseline
steering.generate("Hello", max_new_tokens=100)
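The multi-feature path folds all requested features into a single vector before generation, which is why the hook receives a fixed strength of 1.0. A minimal NumPy sketch of that combination step (toy shapes; `W_dec` is a random stand-in for the SAE decoder weight):

```python
import numpy as np

rng = np.random.default_rng(0)
activation_dim, hidden_dim = 8, 32                          # toy sizes, not a real model's
W_dec = rng.standard_normal((activation_dim, hidden_dim))   # stand-in decoder weight

feature_indices = [10, 20]                                  # hypothetical feature ids
strengths = [1.0, 3.0]

# Sum each feature's decoder column, scaled by its strength, into one
# direction that would be added to the MLP activations at every step.
combined = np.zeros(activation_dim)
for fid, s in zip(feature_indices, strengths):
    if fid < 0 or fid >= hidden_dim:
        raise ValueError(f"feature_idx {fid} out of range [0, {hidden_dim})")
    combined += s * W_dec[:, fid]
```

Because scaling happens here, the downstream generation loop only needs the combined vector, not the individual strengths.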
def compare_steering( self, text: str, feature_idx: int, strengths: Optional[List[float]] = None, max_new_tokens: int = 50, temperature: float = 0.8, top_p: float = 0.9) -> Dict[str, str]:
515    def compare_steering(
516        self,
517        text: str,
518        feature_idx: int,
519        strengths: Optional[List[float]] = None,
520        max_new_tokens: int = 50,
521        temperature: float = 0.8,
522        top_p: float = 0.9,
523    ) -> Dict[str, str]:
524        """
525        Compare baseline generation against steered generations at different strengths.
526
527        Runs one generation per strength value (including ``0.0`` for a
528        no-steering baseline) and returns a dict mapping descriptive
529        labels to the generated text.
530
531        Args:
532            text: The prompt text.
533            feature_idx: SAE feature index to steer with.
534            strengths: List of strength values to test.  Defaults to
535                ``[0.0, 1.0, 2.0, 3.0, 5.0]``.
536            max_new_tokens: Maximum tokens to generate per run.
537            temperature: Sampling temperature.
538            top_p: Nucleus sampling threshold.
539
540        Returns:
541            Dictionary mapping strength labels to generated text.
542            Keys are ``"baseline"`` (for strength ``0.0``) and
543            ``"strength_1.0"``, ``"strength_2.0"``, etc.
544
545        Example:
546            ```python
547            results = steering.compare_steering(
548                "The sky is",
549                feature_idx=42,
550                strengths=[0.0, 1.0, 3.0],
551            )
552            for label, text in results.items():
553                print(f"{label}: {text}")
554            ```
555        """
556        if strengths is None:
557            strengths = [0.0, 1.0, 2.0, 3.0, 5.0]
558
559        results = {}
560
561        for s in strengths:
562            if s == 0.0:
563                label = "baseline"
564            else:
565                label = f"strength_{s}"
566
567            logger.info(f"Generating with {label} (strength={s})")
568            generated = self.generate(
569                text,
570                feature_idx=feature_idx,
571                strength=s,
572                max_new_tokens=max_new_tokens,
573                temperature=temperature,
574                top_p=top_p,
575            )
576            results[label] = generated
577
578        return results

Compare baseline generation against steered generations at different strengths.

Runs one generation per strength value (including 0.0 for a no-steering baseline) and returns a dict mapping descriptive labels to the generated text.

Arguments:
  • text: The prompt text.
  • feature_idx: SAE feature index to steer with.
  • strengths: List of strength values to test. Defaults to [0.0, 1.0, 2.0, 3.0, 5.0].
  • max_new_tokens: Maximum tokens to generate per run.
  • temperature: Sampling temperature.
  • top_p: Nucleus sampling threshold.
Returns:

Dictionary mapping strength labels to generated text. Keys are "baseline" (for strength 0.0) and "strength_1.0", "strength_2.0", etc.

Example:
results = steering.compare_steering(
    "The sky is",
    feature_idx=42,
    strengths=[0.0, 1.0, 3.0],
)
for label, text in results.items():
    print(f"{label}: {text}")
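The result keys follow a simple convention: strength 0.0 becomes "baseline", everything else "strength_{value}". The mapping can be sketched as:

```python
strengths = [0.0, 1.0, 3.0]
# 0.0 is treated as the no-steering baseline; other values are labelled by magnitude.
labels = ["baseline" if s == 0.0 else f"strength_{s}" for s in strengths]
print(labels)  # ['baseline', 'strength_1.0', 'strength_3.0']
```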
def find_steering_features( self, text: str, activations: numpy.ndarray, top_k: int = 20, min_activation: float = 0.0) -> List[Dict[str, Any]]:
580    def find_steering_features(
581        self,
582        text: str,
583        activations: np.ndarray,
584        top_k: int = 20,
585        min_activation: float = 0.0,
586    ) -> List[Dict[str, Any]]:
587        """
588        Find the top-k SAE features that most activate on given activations.
589
590        Encodes the provided activations through the SAE and returns the
591        features with the highest mean activation values, sorted by
592        magnitude.  Useful for identifying which features are relevant
593        to a given input before steering.
594
595        Args:
596            text: The prompt text (used for logging only).
597            activations: MLP activations array of shape
598                ``(n_samples, activation_dim)``.
599            top_k: Number of top features to return.
600            min_activation: Minimum mean activation threshold to
601                consider a feature.
602
603        Returns:
604            List of dicts sorted by descending activation, each
605            containing:
606
607            - ``feature_idx`` (``int``): The feature index.
608            - ``activation`` (``float``): Mean activation value.
609            - ``normalized_activation`` (``float``): Activation
610              divided by the decoder weight L2 norm.
611            - ``decoder_weight_norm`` (``float``): L2 norm of the
612              decoder weight vector.
613        """
614        self.sae.eval()
615        device = self.device
616
617        with torch.no_grad():
618            if isinstance(activations, np.ndarray):
619                acts_tensor = torch.from_numpy(activations).float().to(device)
620            else:
621                acts_tensor = activations.to(device)
622
623            features = self.sae.encode(acts_tensor)
624
625            avg_activations = features.mean(dim=0).cpu().numpy()
626            decoder_norms = self.sae.decoder.weight.norm(dim=0).cpu().numpy()
627
628            active_mask = avg_activations >= min_activation
629            active_indices = np.where(active_mask)[0]
630
631            sorted_indices = active_indices[
632                np.argsort(avg_activations[active_indices])[::-1]
633            ][:top_k]
634
635            results = []
636            for idx in sorted_indices:
637                results.append(
638                    {
639                        "feature_idx": int(idx),
640                        "activation": float(avg_activations[idx]),
641                        "normalized_activation": float(
642                            avg_activations[idx] / (decoder_norms[idx] + 1e-8)
643                        ),
644                        "decoder_weight_norm": float(decoder_norms[idx]),
645                    }
646                )
647
648        logger.info(f"Found {len(results)} active features for text: {text[:100]}")
649        return results

Find the top-k SAE features that most activate on given activations.

Encodes the provided activations through the SAE and returns the features with the highest mean activation values, sorted by magnitude. Useful for identifying which features are relevant to a given input before steering.

Arguments:
  • text: The prompt text (used for logging only).
  • activations: MLP activations array of shape (n_samples, activation_dim).
  • top_k: Number of top features to return.
  • min_activation: Minimum mean activation threshold to consider a feature.
Returns:

List of dicts sorted by descending activation, each containing:

  • feature_idx (int): The feature index.
  • activation (float): Mean activation value.
  • normalized_activation (float): Activation divided by the decoder weight L2 norm.
  • decoder_weight_norm (float): L2 norm of the decoder weight vector.
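The ranking logic can be reproduced standalone. This NumPy sketch uses random stand-ins for the SAE encodings and decoder weight (toy sizes, hypothetical data; `features` plays the role of `sae.encode(...)` output):

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples, activation_dim, hidden_dim = 100, 16, 64
features = np.maximum(rng.standard_normal((n_samples, hidden_dim)), 0.0)  # ReLU-like encodings
W_dec = rng.standard_normal((activation_dim, hidden_dim))                 # stand-in decoder weight

avg_activations = features.mean(axis=0)
decoder_norms = np.linalg.norm(W_dec, axis=0)

top_k, min_activation = 5, 0.0
active = np.where(avg_activations >= min_activation)[0]
# Sort the surviving features by descending mean activation, keep the top k.
ranked = active[np.argsort(avg_activations[active])[::-1]][:top_k]

results = [
    {
        "feature_idx": int(i),
        "activation": float(avg_activations[i]),
        "normalized_activation": float(avg_activations[i] / (decoder_norms[i] + 1e-8)),
        "decoder_weight_norm": float(decoder_norms[i]),
    }
    for i in ranked
]
```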
def get_steering_direction(self, feature_idx: int, normalize: bool = True) -> numpy.ndarray:
651    def get_steering_direction(
652        self,
653        feature_idx: int,
654        normalize: bool = True,
655    ) -> np.ndarray:
656        """
657        Get the steering direction vector for a given feature.
658
659        The steering direction is the SAE decoder weight column for the
660        specified feature.  This vector represents the direction in MLP
661        activation space that the feature encodes.
662
663        Args:
664            feature_idx: Index of the SAE feature.
665            normalize: If ``True``, L2-normalize the direction vector.
666
667        Returns:
668            NumPy array of shape ``(activation_dim,)`` containing the
669            steering direction.
670
671        Raises:
672            ValueError: If ``feature_idx`` is out of range
673                ``[0, sae.hidden_dim)``.
674        """
675        if feature_idx < 0 or feature_idx >= self.sae.hidden_dim:
676            raise ValueError(
677                f"feature_idx {feature_idx} out of range [0, {self.sae.hidden_dim})"
678            )
679
680        direction = self.sae.decoder.weight[:, feature_idx].detach().cpu().numpy()
681
682        if normalize:
683            norm = np.linalg.norm(direction)
684            if norm > 0:
685                direction = direction / norm
686
687        return direction

Get the steering direction vector for a given feature.

The steering direction is the SAE decoder weight column for the specified feature. This vector represents the direction in MLP activation space that the feature encodes.

Arguments:
  • feature_idx: Index of the SAE feature.
  • normalize: If True, L2-normalize the direction vector.
Returns:

NumPy array of shape (activation_dim,) containing the steering direction.

Raises:
  • ValueError: If feature_idx is out of range [0, sae.hidden_dim).
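A standalone sketch of the column extraction and optional normalization, with a random toy decoder weight in place of the trained SAE's:

```python
import numpy as np

rng = np.random.default_rng(2)
W_dec = rng.standard_normal((16, 64))   # toy (activation_dim, hidden_dim) decoder weight

feature_idx, normalize = 7, True
direction = W_dec[:, feature_idx].copy()   # one decoder column = one feature direction

if normalize:
    norm = np.linalg.norm(direction)
    if norm > 0:
        direction = direction / norm       # unit-length steering direction

print(direction.shape)  # (16,)
```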
def save_steering_analysis( self, results: Dict[str, str], output_dir: Union[str, pathlib.Path], prompt: str = '', feature_label: str = '') -> pathlib.Path:
689    def save_steering_analysis(
690        self,
691        results: Dict[str, str],
692        output_dir: Union[str, Path],
693        prompt: str = "",
694        feature_label: str = "",
695    ) -> Path:
696        """
697        Save steering comparison results to disk.
698
699        Writes a text file containing the prompt and all generated
700        outputs (baseline and steered) for later review.  The
701        filename includes a timestamp (and optional feature label)
702        to prevent overwriting previous analyses.
703
704        Args:
705            results: Dictionary mapping labels to generated text
706                (as returned by :meth:`compare_steering`).
707            output_dir: Directory to save the analysis file.
708            prompt: The original prompt text.
709            feature_label: Optional label included in the filename
710                (e.g. ``"feature_42"``).
711
712        Returns:
713            Path to the saved analysis file.
714        """
715        output_dir = Path(output_dir)
716        output_dir.mkdir(parents=True, exist_ok=True)
717
718        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
719        suffix = f"_{feature_label}" if feature_label else ""
720        filepath = output_dir / f"steering_analysis{suffix}_{timestamp}.txt"
721
722        with open(filepath, "w") as f:
723            f.write(f"Prompt: {prompt}\n")
724            f.write("=" * 80 + "\n\n")
725
726            for label, text in results.items():
727                f.write(f"--- {label} ---\n")
728                f.write(f"{text}\n\n")
729                f.write("-" * 40 + "\n")
730
731        logger.info(f"Saved steering analysis to {filepath}")
732        return filepath

Save steering comparison results to disk.

Writes a text file containing the prompt and all generated outputs (baseline and steered) for later review. The filename includes a timestamp (and optional feature label) to prevent overwriting previous analyses.

Arguments:
  • results: Dictionary mapping labels to generated text (as returned by compare_steering()).
  • output_dir: Directory to save the analysis file.
  • prompt: The original prompt text.
  • feature_label: Optional label included in the filename (e.g. "feature_42").
Returns:

Path to the saved analysis file.
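The on-disk format is plain text. A stdlib-only sketch of what a saved file looks like (hypothetical results, written to a temp directory so nothing in the project tree is touched):

```python
import tempfile
from datetime import datetime
from pathlib import Path

results = {"baseline": "The sky is blue.", "strength_2.0": "The sky is golden."}
prompt, feature_label = "The sky is", "feature_42"

output_dir = Path(tempfile.mkdtemp())
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
suffix = f"_{feature_label}" if feature_label else ""
filepath = output_dir / f"steering_analysis{suffix}_{timestamp}.txt"

with open(filepath, "w") as f:
    f.write(f"Prompt: {prompt}\n")        # header: the original prompt
    f.write("=" * 80 + "\n\n")
    for label, text in results.items():   # one section per generation
        f.write(f"--- {label} ---\n")
        f.write(f"{text}\n\n")
        f.write("-" * 40 + "\n")
```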

class Config(pydantic_settings.main.BaseSettings):
276class Config(BaseSettings):
277    """Main configuration class for the Drrik framework.
278
279    Aggregates all sub-configurations (extractor, autoencoder, visualization)
280    into a single object. Settings can be loaded from environment variables
281    (with ``DRIK_`` prefix) or from a YAML config file via the CLI.
282
283    Attributes:
284        extractor: Activation extraction configuration.
285        autoencoder: Sparse autoencoder configuration.
286        visualization: Feature visualization configuration.
287        random_seed: Global random seed for reproducibility. Applied
288            to numpy, torch, and random at pipeline start.
289        log_level: Logging verbosity. One of ``DEBUG``, ``INFO``,
290            ``WARNING``, ``ERROR``, ``CRITICAL``.
291
292    Example:
293        Load from environment variables:
294
295        ```bash
296        export DRIK_RANDOM_SEED=42
297        export DRIK_AUTOENCODER__HIDDEN_DIM=16384
298        # then, in Python: Config() picks these up
299        ```
300
301        Or construct programmatically:
302
303        ```python
304        config = Config(
305            extractor=ActivationExtractorConfig(...),
306            autoencoder=SparseAutoencoderConfig(activation_dim=2048),
307        )
308        ```
309    """
310
311    extractor: ActivationExtractorConfig = Field(
312        default_factory=ActivationExtractorConfig
313    )
314    autoencoder: SparseAutoencoderConfig = Field(
315        default_factory=SparseAutoencoderConfig
316    )
317    visualization: VisualizationConfig = Field(default_factory=VisualizationConfig)
318
319    random_seed: int = Field(default=42, description="Random seed for reproducibility")
320    log_level: str = Field(
321        default="INFO", description="Logging level (DEBUG, INFO, WARNING, ERROR)"
322    )
323
324    model_config = ConfigDict(
325        env_prefix="DRIK_",
326        env_nested_delimiter="__",
327    )
328
329    def create_output_dirs(self) -> None:
330        """Create output directories if they don't exist.
331
332        Ensures that both the extractor output directory (if configured)
333        and the visualization output directory exist on disk, creating
334        any missing parent directories as needed.
335        """
336        if self.extractor.output_dir:
337            self.extractor.output_dir.mkdir(parents=True, exist_ok=True)
338        self.visualization.output_dir.mkdir(parents=True, exist_ok=True)
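The docstring states that `random_seed` is applied to numpy, torch, and random at pipeline start. A minimal sketch of such a seeding helper, assuming optional dependencies are skipped when absent (the `seed_everything` name and the import guards are illustrative, not part of drrik's API):

```python
import random


def seed_everything(seed: int) -> None:
    """Seed every available RNG source for reproducibility."""
    random.seed(seed)
    try:  # numpy is optional in this sketch
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:  # torch is optional in this sketch
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass


seed_everything(42)
first = random.random()
seed_everything(42)
assert random.random() == first  # same seed reproduces the same draw
```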

class EnvironmentSettings(BaseSettings):
    """
    Environment settings for API keys and tokens.

    These settings are loaded from environment variables or a .env file.
    Create a .env file in the project root with your credentials.

    Example .env file:
        ```bash
        HUGGINGFACE_HUB_TOKEN=hf_...
        WANDB_API_KEY=...
        WANDB_PROJECT=drrik-experiments
        WANDB_ENTITY=your-username
        ```

    Attributes:
        huggingface_hub_token: HuggingFace Hub API token for gated models
        wandb_api_key: Weights & Biases API key for experiment tracking
        wandb_project: Default wandb project name
        wandb_entity: Default wandb entity (username or team)
        wandb_mode: wandb mode ('online', 'offline', or 'disabled')
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        env_prefix="",
        case_sensitive=False,
        extra="ignore",
    )

    huggingface_hub_token: Optional[str] = Field(
        default=None,
        description="HuggingFace Hub API token for accessing gated models. "
        "Get your token at: https://huggingface.co/settings/tokens",
    )

    wandb_api_key: Optional[str] = Field(
        default=None,
        description="Weights & Biases API key for experiment tracking. "
        "Get your key at: https://wandb.ai/settings",
    )

    wandb_project: str = Field(
        default="drrik-experiments", description="Default wandb project name"
    )

    wandb_entity: Optional[str] = Field(
        default=None, description="Default wandb entity (username or team name)"
    )

    wandb_mode: str = Field(
        default="online",
        description="wandb mode: 'online' to sync, 'offline' to save locally, "
        "'disabled' to disable wandb",
    )

    @field_validator("wandb_mode")
    @classmethod
    def validate_wandb_mode(cls, v: str) -> str:
        """Validate that the wandb mode is one of the allowed values.

        Args:
            v: The wandb mode string to validate.

        Returns:
            The lowercased wandb mode string.

        Raises:
            ValueError: If the mode is not ``online``, ``offline``,
                or ``disabled``.
        """
        valid_modes = ["online", "offline", "disabled"]
        v = v.lower()
        if v not in valid_modes:
            raise ValueError(f"wandb_mode must be one of {valid_modes}, got '{v}'")
        return v

    @field_validator("huggingface_hub_token")
    @classmethod
    def validate_hf_token(cls, v: Optional[str]) -> Optional[str]:
        """Log a confirmation message when an HF token is provided.

        Args:
            v: The HuggingFace Hub token string, or ``None``.

        Returns:
            The unmodified token string.
        """
        if v:
            logger.info("HuggingFace Hub token is configured")
        return v

    @field_validator("wandb_api_key")
    @classmethod
    def validate_wandb_key(cls, v: Optional[str]) -> Optional[str]:
        """Log a confirmation message when a wandb API key is provided.

        Args:
            v: The Weights & Biases API key string, or ``None``.

        Returns:
            The unmodified API key string.
        """
        if v:
            logger.info("Weights & Biases API key is configured")
        return v

    @property
    def use_wandb(self) -> bool:
        """Check if wandb should be enabled based on API key and mode.

        Returns:
            ``True`` if an API key is set and wandb mode is not
            ``disabled``, ``False`` otherwise.
        """
        return self.wandb_api_key is not None and self.wandb_mode != "disabled"

    @property
    def has_hf_token(self) -> bool:
        """Check if a HuggingFace Hub token is available.

        Returns:
            ``True`` if ``huggingface_hub_token`` is set, ``False``
            otherwise.
        """
        return self.huggingface_hub_token is not None

    def get_hf_auth(self) -> Optional[tuple]:
        """Get HuggingFace authentication tuple for model loading.

        Constructs the authentication argument expected by the
        ``transformers`` library when loading gated models.

        Returns:
            A tuple of ``(True, token_string)`` if a token is
            configured, otherwise ``None``.
        """
        if self.has_hf_token:
            return (True, self.huggingface_hub_token)
        return None
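The gating logic above (case-insensitive mode validation, plus `use_wandb` requiring both a key and a non-disabled mode) can be sketched without pydantic using a plain dataclass. `MiniSettings` and `VALID_MODES` are illustrative names, not drrik API:

```python
from dataclasses import dataclass
from typing import Optional

VALID_MODES = {"online", "offline", "disabled"}


@dataclass
class MiniSettings:
    wandb_api_key: Optional[str] = None
    wandb_mode: str = "online"

    def __post_init__(self) -> None:
        # Mirror the field validator: lowercase, then reject unknown modes.
        self.wandb_mode = self.wandb_mode.lower()
        if self.wandb_mode not in VALID_MODES:
            raise ValueError(f"wandb_mode must be one of {sorted(VALID_MODES)}")

    @property
    def use_wandb(self) -> bool:
        # Tracking needs both an API key and a non-disabled mode.
        return self.wandb_api_key is not None and self.wandb_mode != "disabled"


assert not MiniSettings().use_wandb                 # no key, no tracking
assert not MiniSettings("k", "DISABLED").use_wandb  # disabled mode wins
assert MiniSettings("k", "Offline").use_wandb       # mode check is case-insensitive
```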


class WandbConfig:
    """
    Configuration for wandb experiment tracking.

    This class handles wandb initialization, logging, and cleanup.
    It's designed to be used as a context manager for automatic cleanup.

    Example:
        ```python
        from drrik.settings import WandbConfig

        # Use as context manager
        with WandbConfig(
            project="my-sae-experiment",
            config={"model": "gemma-2b", "expansion": 8}
        ) as wandb_logger:
            wandb_logger.log_metrics({"loss": 0.5, "l0_norm": 10})

        # Or manually initialize/finalize
        wandb_config = WandbConfig(project="my-sae-experiment")
        wandb_config.initialize()
        wandb_config.log_metrics({"loss": 0.5})
        wandb_config.finalize()
        ```
    """

    def __init__(
        self,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        name: Optional[str] = None,
        config: Optional[dict] = None,
        tags: Optional[list[str]] = None,
        settings: Optional[EnvironmentSettings] = None,
        enabled: bool = True,
    ):
        """
        Initialize wandb configuration.

        Args:
            project: wandb project name (uses settings default if None)
            entity: wandb entity (uses settings default if None)
            name: Run name (auto-generated if None)
            config: Configuration dict to log
            tags: List of tags for the run
            settings: EnvironmentSettings instance (uses global if None)
            enabled: If False, disables wandb even if configured
        """
        self.settings = settings or EnvironmentSettings()
        self.enabled = enabled and self.settings.use_wandb

        if self.enabled and not self.settings.wandb_api_key:
            logger.warning(
                "wandb is enabled but no API key is set. "
                "Set WANDB_API_KEY environment variable or disable wandb."
            )
            self.enabled = False

        self.project = project or self.settings.wandb_project
        self.entity = entity or self.settings.wandb_entity
        self.name = name
        self.config = config or {}
        self.tags = tags

        self._initialized = False
        self._run = None

    def initialize(self) -> bool:
        """
        Initialize wandb run.

        Returns:
            True if wandb was successfully initialized, False otherwise
        """
        if not self.enabled:
            logger.info("wandb is disabled")
            return False

        if self._initialized:
            logger.warning("wandb is already initialized")
            return True

        try:
            import wandb

            # Set API key
            os.environ["WANDB_API_KEY"] = self.settings.wandb_api_key
            os.environ["WANDB_MODE"] = self.settings.wandb_mode

            # Initialize run
            self._run = wandb.init(
                project=self.project,
                entity=self.entity,
                name=self.name,
                config=self.config,
                tags=self.tags,
                reinit=True,  # Allow re-initialization
            )

            self._initialized = True
            logger.info(f"wandb initialized: {self.get_run_url()}")
            return True

        except ImportError:
            logger.warning(
                "wandb package not installed. Install it with: pip install wandb"
            )
            self.enabled = False
            return False
        except Exception as e:
            logger.error(f"Failed to initialize wandb: {e}")
            self.enabled = False
            return False

    def finalize(self) -> None:
        """Finalize wandb run."""
        if self._initialized:
            try:
                import wandb

                wandb.finish()
                logger.info("wandb run finalized")
            except Exception as e:
                logger.error(f"Error finalizing wandb: {e}")
            finally:
                self._initialized = False
                self._run = None

    def log_metrics(
        self,
        metrics: dict,
        step: Optional[int] = None,
        commit: bool = True,
    ) -> None:
        """
        Log metrics to wandb.

        Args:
            metrics: Dictionary of metric names to values
            step: Current step (for logging to specific step)
            commit: Whether to commit the log
        """
        if self._initialized:
            try:
                import wandb

                wandb.log(metrics, step=step, commit=commit)
            except Exception as e:
                logger.error(f"Error logging to wandb: {e}")

    def log_histogram(
        self,
        values,
        name: str,
        step: Optional[int] = None,
    ) -> None:
        """
        Log a histogram to wandb.

        Args:
            values: Array-like values to histogram
            name: Name of the histogram
            step: Current step
        """
        if self._initialized:
            try:
                import wandb
                import numpy as np

                wandb.log({name: wandb.Histogram(np.array(values))}, step=step)
            except Exception as e:
                logger.error(f"Error logging histogram to wandb: {e}")

    def log_model(self, model_path: str, name: str = "model") -> None:
        """
        Log a model artifact to wandb.

        Supports both single files and directories.

        Args:
            model_path: Path to the model file or directory
            name: Artifact name
        """
        if self._initialized:
            try:
                import wandb
                from pathlib import Path

                artifact = wandb.Artifact(name, type="model")
                path = Path(model_path)
                if path.is_dir():
                    artifact.add_dir(str(path), name=path.name)
                else:
                    artifact.add_file(str(path))
                wandb.log_artifact(artifact)
                logger.info(f"Logged model artifact: {name} from {model_path}")
            except Exception as e:
                logger.error(f"Error logging model to wandb: {e}")

    def log_artifact(
        self, path: str, name: str, artifact_type: str = "dataset"
    ) -> None:
        """
        Log an artifact of any type to wandb (defaults to ``dataset``).

        Supports both single files and directories.

        Args:
            path: Path to the file or directory
            name: Artifact name
            artifact_type: Type of artifact (dataset, model, config, etc.)
        """
        if self._initialized:
            try:
                import wandb
                from pathlib import Path

                artifact = wandb.Artifact(name, type=artifact_type)
                p = Path(path)
                if p.is_dir():
                    artifact.add_dir(str(p), name=p.name)
                else:
                    artifact.add_file(str(p))
                wandb.log_artifact(artifact)
                logger.info(f"Logged {artifact_type} artifact: {name} from {path}")
            except Exception as e:
                logger.error(f"Error logging artifact to wandb: {e}")

    def get_run_url(self) -> Optional[str]:
        """Get the wandb run URL for the current experiment.

        Returns:
            The URL string of the active wandb run, or ``None`` if
            no run has been initialized.
        """
        if self._run:
            return self._run.url
        return None

    def get_run_id(self) -> Optional[str]:
        """Get the wandb run ID for the current experiment.

        Returns:
            The unique run identifier string, or ``None`` if no run
            has been initialized.
        """
        if self._run:
            return self._run.id
        return None

    def __enter__(self):
        """Enter the context manager, initializing the wandb run.

        Returns:
            The ``WandbConfig`` instance for use within the ``with``
            block.
        """
        self.initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context manager, finalizing the wandb run.

        Args:
            exc_type: Exception type, if an exception was raised.
            exc_val: Exception value, if an exception was raised.
            exc_tb: Exception traceback, if an exception was raised.

        Returns:
            ``False`` to propagate any exceptions.
        """
        self.finalize()
        return False
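The `__enter__`/`__exit__` pair above guarantees cleanup even when the `with` body raises, and returning `False` from `__exit__` lets the exception propagate after cleanup; logging calls before initialization are silent no-ops. A stdlib-only sketch of that same pattern (`MiniTracker` is an illustrative stand-in, not drrik API):

```python
class MiniTracker:
    def __init__(self) -> None:
        self.active = False
        self.events: list[dict] = []

    def __enter__(self):
        self.active = True   # plays the role of initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.active = False  # plays the role of finalize(); always runs
        return False         # do not swallow exceptions from the with-body

    def log(self, metrics: dict) -> None:
        if self.active:      # no-op unless initialized, like log_metrics()
            self.events.append(metrics)


tracker = MiniTracker()
try:
    with tracker as t:
        t.log({"loss": 0.5})
        raise RuntimeError("training crashed")
except RuntimeError:
    pass

assert tracker.events == [{"loss": 0.5}]
assert not tracker.active  # cleanup ran despite the exception
```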


def get_settings() -> EnvironmentSettings:
    """
    Get the global EnvironmentSettings instance.

    Returns:
        The global settings (creates one if it doesn't exist)
    """
    global _global_settings
    if _global_settings is None:
        _global_settings = EnvironmentSettings()
    return _global_settings
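`get_settings` is a lazy module-level singleton: the settings object (and its `.env` read) is created on first access rather than at import time, and every caller shares the same instance. A generic sketch of the pattern with a placeholder payload (`get_obj` and the dict are illustrative):

```python
from typing import Optional

_global_obj: Optional[dict] = None


def get_obj() -> dict:
    """Lazily create and cache a shared module-level object."""
    global _global_obj
    if _global_obj is None:
        _global_obj = {"created": True}  # expensive setup happens once
    return _global_obj


assert get_obj() is get_obj()  # every caller sees the same instance
```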
