drrik
Drrik: A Framework for Monosemantic Feature Extraction from Language Models
This framework is inspired by Anthropic's Towards Monosemanticity paper,
which applies dictionary learning to transformer language models: sparse
autoencoders are trained on MLP activations to decompose them into
interpretable, near-monosemantic features.
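The training objective behind that approach can be sketched concretely. Below is a minimal numpy sketch of an overcomplete sparse autoencoder's forward pass and loss (reconstruction error plus an L1 sparsity penalty); the function names and sizes are illustrative, not Drrik's actual API:

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Encode activations into sparse features, then reconstruct."""
    f = np.maximum(0.0, x @ W_enc + b_enc)   # (n, hidden) feature activations (ReLU)
    x_hat = f @ W_dec + b_dec                # (n, dim) reconstruction
    return f, x_hat

def sae_loss(x, x_hat, f, l1_coefficient=0.01):
    """Reconstruction MSE plus an L1 penalty that encourages sparse features."""
    recon = np.mean(np.sum((x - x_hat) ** 2, axis=-1))
    sparsity = np.mean(np.sum(np.abs(f), axis=-1))
    return recon + l1_coefficient * sparsity

rng = np.random.default_rng(0)
dim, hidden = 16, 128                        # overcomplete: hidden >> dim
x = rng.normal(size=(32, dim))
W_enc = rng.normal(scale=0.1, size=(dim, hidden))
W_dec = rng.normal(scale=0.1, size=(hidden, dim))
f, x_hat = sae_forward(x, W_enc, np.zeros(hidden), W_dec, np.zeros(dim))
loss = sae_loss(x, x_hat, f)
```

In the full framework the weights would be fit by gradient descent; this sketch only shows the quantity being minimized.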
Key Components:
- Model loading from HuggingFace Hub (with gated model support)
- Dataset loading and inference pipeline
- MLP activation collection using nnsight
- Sparse Autoencoder training for feature extraction
- SAE-based activation steering to bias model outputs
- Visualization of feature-specific activation vectors
- Optional Weights & Biases integration for experiment tracking
Public API:
- ActivationExtractor: Loads models and datasets, extracts MLP activations via nnsight.
- SparseAutoencoder: Overcomplete SAE with L1 regularization and dead neuron resampling.
- FeatureVisualizer: Generates density histograms, training curves, and feature dashboards.
- SAESteering: Steers language model generation by adding SAE feature directions to MLP activations during inference.
- Config: Top-level Pydantic settings model aggregating all sub-configurations.
- EnvironmentSettings: Loads API keys and environment variables from .env.
- WandbConfig: Context manager for wandb experiment tracking.
- get_settings: Returns the global EnvironmentSettings singleton.
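The SparseAutoencoder entry mentions dead neuron resampling. The detection half of that idea can be sketched standalone: a feature is "dead" if it never fires over a batch of post-ReLU feature activations (the function name and threshold are illustrative, not Drrik's API):

```python
import numpy as np

def dead_feature_mask(feature_acts, threshold=0.0):
    """Return a boolean mask of features that never fire across the batch.

    feature_acts: (n_samples, n_features) post-ReLU feature activations.
    """
    fire_rate = (feature_acts > threshold).mean(axis=0)
    return fire_rate == 0.0

acts = np.array([[0.0, 1.2, 0.0],
                 [0.0, 0.0, 0.3]])
print(dead_feature_mask(acts))  # [ True False False]
```

Resampling would then reinitialize the encoder/decoder rows of the masked features, typically toward poorly reconstructed inputs.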
Example Usage:
```python
from drrik import ActivationExtractor, SparseAutoencoder, FeatureVisualizer

# Extract activations
extractor = ActivationExtractor(
    model_name="google/gemma-2b",
    dataset_name="wikitext",
    dataset_split="train",
    mlp_layers=[0, 1, 2],
    num_samples=1000,
)
activations = extractor.extract()

# Train sparse autoencoder (with optional wandb logging)
sae = SparseAutoencoder(
    activation_dim=activations.shape[-1],
    hidden_dim=activations.shape[-1] * 8,  # 8x expansion
    l1_coefficient=0.01,
)
sae.fit(activations, wandb_enabled=True)

# Visualize features (with optional wandb logging)
visualizer = FeatureVisualizer(sae, activations)
visualizer.plot_feature_density()
visualizer.plot_top_activating_examples()
```
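The steering component listed under Key Components amounts to adding a scaled SAE feature direction to an MLP activation during the forward pass. A minimal numpy sketch of that arithmetic (in practice the direction would be a row of the SAE decoder; names and the strength value here are illustrative):

```python
import numpy as np

def steer(activation, feature_direction, strength=5.0):
    """Bias an activation vector along one unit-normalized feature direction."""
    direction = feature_direction / np.linalg.norm(feature_direction)
    return activation + strength * direction

act = np.zeros(4)                            # stand-in MLP activation
direction = np.array([0.0, 3.0, 0.0, 0.0])   # stand-in SAE decoder row
print(steer(act, direction))  # [0. 5. 0. 0.]
```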
````python
"""
Drrik: A Framework for Monosemantic Feature Extraction from Language Models

This framework is inspired by Anthropic's ``Towards Monosemanticity`` paper,
which applies dictionary learning to transformer language models: sparse
autoencoders are trained on MLP activations to decompose them into
interpretable, near-monosemantic features.

Key Components:
    - **Model loading** from HuggingFace Hub (with gated model support)
    - **Dataset loading** and inference pipeline
    - **MLP activation collection** using nnsight
    - **Sparse Autoencoder training** for feature extraction
    - **SAE-based activation steering** to bias model outputs
    - **Visualization** of feature-specific activation vectors
    - Optional **Weights & Biases** integration for experiment tracking

Public API:
    ``ActivationExtractor``
        Loads models and datasets, extracts MLP activations via nnsight.
    ``SparseAutoencoder``
        Overcomplete SAE with L1 regularization and dead neuron resampling.
    ``FeatureVisualizer``
        Generates density histograms, training curves, and feature dashboards.
    ``SAESteering``
        Steers language model generation by adding SAE feature directions
        to MLP activations during inference.
    ``Config``
        Top-level Pydantic settings model aggregating all sub-configurations.
    ``EnvironmentSettings``
        Loads API keys and environment variables from ``.env``.
    ``WandbConfig``
        Context manager for wandb experiment tracking.
    ``get_settings``
        Returns the global ``EnvironmentSettings`` singleton.

Example Usage:
    ```python
    from drrik import ActivationExtractor, SparseAutoencoder, FeatureVisualizer

    # Extract activations
    extractor = ActivationExtractor(
        model_name="google/gemma-2b",
        dataset_name="wikitext",
        dataset_split="train",
        mlp_layers=[0, 1, 2],
        num_samples=1000,
    )
    activations = extractor.extract()

    # Train sparse autoencoder (with optional wandb logging)
    sae = SparseAutoencoder(
        activation_dim=activations.shape[-1],
        hidden_dim=activations.shape[-1] * 8,  # 8x expansion
        l1_coefficient=0.01,
    )
    sae.fit(activations, wandb_enabled=True)

    # Visualize features (with optional wandb logging)
    visualizer = FeatureVisualizer(sae, activations)
    visualizer.plot_feature_density()
    visualizer.plot_top_activating_examples()
    ```
"""

__version__ = "0.1.0"

from drrik.models import ActivationExtractor
from drrik.autoencoder import SparseAutoencoder
from drrik.visualization import FeatureVisualizer
from drrik.steering import SAESteering
from drrik.config import Config
from drrik.settings import EnvironmentSettings, WandbConfig, get_settings

__all__ = [
    "ActivationExtractor",
    "SparseAutoencoder",
    "FeatureVisualizer",
    "SAESteering",
    "Config",
    "EnvironmentSettings",
    "WandbConfig",
    "get_settings",
]
````
````python
import logging
import pickle
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
from datasets import Dataset, load_dataset
from nnsight import LanguageModel
from tqdm import tqdm
from transformers import AutoTokenizer

from drrik.config import ActivationExtractorConfig, DatasetConfig, ModelConfig
from drrik.settings import get_settings
from drrik.steering import resolve_module_path

logger = logging.getLogger(__name__)


class ActivationExtractor:
    """
    Extract MLP activations from language models using nnsight.

    Handles model and dataset loading from HuggingFace Hub, runs
    inference with nnsight's tracing context to capture MLP outputs,
    and provides persistence via ``.npy`` / ``.pkl`` files.

    The HuggingFace auth token is read automatically from the
    ``HUGGINGFACE_HUB_TOKEN`` environment variable (or ``.env`` file)
    via :func:`~drrik.settings.get_settings`.

    Attributes:
        config: The resolved :class:`~drrik.config.ActivationExtractorConfig`.
        model: The loaded nnsight ``LanguageModel`` (``None`` until
            :meth:`load_model` is called).
        tokenizer: The HuggingFace tokenizer (``None`` until
            :meth:`load_model` is called).
        dataset: The loaded HuggingFace ``Dataset`` (``None`` until
            :meth:`load_dataset` is called).

    Example:
        ```python
        extractor = ActivationExtractor(
            model_name="google/gemma-2b",
            dataset_name="wikitext",
            dataset_config="wikitext-2-raw-v1",
            mlp_layers=[0],
            num_samples=1000,
        )
        activations, metadata = extractor.extract()
        extractor.save_activations(activations, metadata, "./output")
        ```
    """

    def __init__(self, config: Optional[ActivationExtractorConfig] = None, **kwargs):
        """
        Initialize the ActivationExtractor.

        Accepts either a pre-built config object or flat keyword
        arguments that are automatically sorted into
        :class:`~drrik.config.ModelConfig`,
        :class:`~drrik.config.DatasetConfig`, and the remaining
        extraction-level settings.

        Args:
            config: A fully constructed
                :class:`~drrik.config.ActivationExtractorConfig`.
                If ``None``, the config is built from ``**kwargs``.
            **kwargs: Flat config overrides. Recognised keys are
                distributed to the appropriate sub-config:

                - *Model*: ``model_name``, ``revision``,
                  ``torch_dtype``, ``device_map``,
                  ``trust_remote_code``
                - *Dataset*: ``dataset_name``, ``dataset_config``,
                  ``split``, ``num_samples``, ``text_column``,
                  ``max_length``
                - *Extraction*: ``mlp_layers``, ``batch_size``,
                  ``output_dir``
        """
        if config is None:
            model_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k
                in (
                    "model_name",
                    "revision",
                    "torch_dtype",
                    "device_map",
                    "trust_remote_code",
                )
            }
            dataset_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k
                in (
                    "dataset_name",
                    "dataset_config",
                    "split",
                    "num_samples",
                    "text_column",
                    "max_length",
                )
            }
            remaining_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k not in model_kwargs and k not in dataset_kwargs
            }

            model = ModelConfig(**model_kwargs) if model_kwargs else None
            dataset = DatasetConfig(**dataset_kwargs) if dataset_kwargs else None

            config = ActivationExtractorConfig(
                model=model,
                dataset=dataset,
                **remaining_kwargs,
            )

        self.config = config
        self.model = None
        self.tokenizer = None
        self.dataset = None
        self._activations = []
        self._metadata = []

    def load_model(self) -> LanguageModel:
        """
        Load the model from HuggingFace Hub using nnsight.

        If the model is already loaded, returns the cached instance.
        The HuggingFace token is read from ``HUGGINGFACE_HUB_TOKEN``
        via :func:`~drrik.settings.get_settings`.

        Returns:
            The loaded nnsight :class:`~nnsight.LanguageModel` wrapper.

        Raises:
            RuntimeError: If model loading fails.
        """
        if self.model is not None:
            return self.model

        try:
            logger.info(f"Loading model: {self.config.model.model_name}")

            # Get HF token from settings
            settings = get_settings()
            hf_token = settings.huggingface_hub_token

            if hf_token:
                logger.info("Using HuggingFace Hub token for authentication")

            # Load tokenizer with token
            tokenizer_kwargs = {
                "revision": self.config.model.revision,
                "trust_remote_code": self.config.model.trust_remote_code,
            }
            if hf_token:
                tokenizer_kwargs["token"] = hf_token

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.model.model_name, **tokenizer_kwargs
            )

            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Determine dtype
            dtype_map = {
                "float16": torch.float16,
                "bfloat16": torch.bfloat16,
                "float32": torch.float32,
            }
            torch_dtype = dtype_map.get(self.config.model.torch_dtype, torch.float16)

            # Load model with nnsight (with token if available)
            nnsight_kwargs = {
                "revision": self.config.model.revision,
                "torch_dtype": torch_dtype,
                "trust_remote_code": self.config.model.trust_remote_code,
                "device_map": self.config.model.device_map,
            }
            if hf_token:
                nnsight_kwargs["token"] = hf_token

            self.model = LanguageModel(self.config.model.model_name, **nnsight_kwargs)

            logger.info(f"Model loaded successfully on {self.model.device}")
            return self.model

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def load_dataset(self) -> Dataset:
        """
        Load the dataset from HuggingFace Hub.

        If the dataset is already loaded, returns the cached instance.
        The HuggingFace token is forwarded for gated datasets when
        available.

        Returns:
            The loaded HuggingFace :class:`~datasets.Dataset`.

        Raises:
            RuntimeError: If dataset loading fails.
        """
        if self.dataset is not None:
            return self.dataset

        try:
            logger.info(
                f"Loading dataset: {self.config.dataset.dataset_name} "
                f"({self.config.dataset.split} split)"
            )

            # Get HF token from settings (for gated datasets)
            settings = get_settings()
            hf_token = settings.huggingface_hub_token

            # Load dataset with token if available
            load_kwargs = {
                "path": self.config.dataset.dataset_name,
                "name": self.config.dataset.dataset_config,
                "split": self.config.dataset.split,
            }
            if hf_token:
                load_kwargs["token"] = hf_token

            self.dataset = load_dataset(**load_kwargs)

            logger.info(f"Dataset loaded with {len(self.dataset)} examples")
            return self.dataset

        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise RuntimeError(f"Dataset loading failed: {e}") from e

    def _get_mlp_layer_name(self, layer_idx: int) -> str:
        """
        Get the nnsight module path to an MLP layer.

        Auto-detects the correct path pattern based on the model name
        in the config. For unknown architectures a warning is logged
        and the default Gemma/Llama pattern is returned.

        Args:
            layer_idx: Zero-indexed transformer layer number.

        Returns:
            A dotted/bracket module path string, e.g.
            ``"model.layers[3].mlp"``.
        """
        model_name_lower = self.config.model.model_name.lower()

        # Gemma/Llama/Phi-style: model.layers.N.mlp
        if any(name in model_name_lower for name in ["gemma", "llama", "phi"]):
            return f"model.layers[{layer_idx}].mlp"

        # GPT-2-style: transformer.h.N.mlp
        if any(name in model_name_lower for name in ["gpt2", "gpt-2"]):
            return f"transformer.h[{layer_idx}].mlp"

        # Pythia (GPT-NeoX)-style: gpt_neox.layers.N.mlp
        if "pythia" in model_name_lower:
            return f"gpt_neox.layers[{layer_idx}].mlp"

        # BERT-style: bert.encoder.layer.N.output
        if "bert" in model_name_lower:
            return f"bert.encoder.layer[{layer_idx}].output"

        # Default: try common patterns
        common_patterns = [
            f"model.layers[{layer_idx}].mlp",
            f"model.layers[{layer_idx}].ffn",
            f"transformer.h[{layer_idx}].mlp",
            f"layers[{layer_idx}].mlp",
        ]

        logger.warning(
            f"Unknown model architecture '{self.config.model.model_name}', "
            f"trying common patterns"
        )
        return common_patterns[0]

    def _resolve_layer_path(self, path: str):
        """Resolve a dotted/bracket path to the actual nnsight module.

        Delegates to the shared :func:`~drrik.steering.resolve_module_path`
        utility so that path-resolution logic is maintained in one place.

        Args:
            path: A dot-separated module path with optional bracket
                indexing, e.g., ``model.layers[2].mlp``.

        Returns:
            The resolved nnsight module object corresponding to the
            given path.
        """
        return resolve_module_path(self.model, path)

    def extract(
        self,
        num_samples: Optional[int] = None,
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Extract MLP activations from the model.

        Loads the model and dataset (if not already loaded), tokenizes
        the data, and runs batched forward passes using nnsight's
        tracing context. For each sample, the **last-token** activation
        from each requested MLP layer is collected and concatenated.

        Args:
            num_samples: Override for the number of samples to process.
                If ``None``, uses ``config.dataset.num_samples``.

        Returns:
            A tuple of:

            - **activations** (:class:`numpy.ndarray`) — shape
              ``(n_samples, activation_dim * n_layers)``. When a
              single layer is requested the shape simplifies to
              ``(n_samples, activation_dim)``.
            - **metadata** (:class:`dict`) — contains the serialised
              config, sample count, activation dimension, layer paths,
              and per-sample metadata (index, truncated text, input
              ids).

        Raises:
            RuntimeError: If extraction fails at any stage.
        """
        try:
            # Load model and dataset
            self.load_model()
            self.load_dataset()

            n_samples = num_samples or self.config.dataset.num_samples
            logger.info(
                f"Extracting activations from {len(self.config.mlp_layers)} MLP layers "
                f"for {n_samples} samples"
            )

            # Prepare dataset
            dataset = self.dataset.select(range(min(n_samples, len(self.dataset))))

            # Tokenize dataset
            def tokenize_function(examples):
                return self.tokenizer(
                    examples[self.config.dataset.text_column],
                    padding="max_length",
                    truncation=True,
                    max_length=self.config.dataset.max_length,
                    return_tensors="pt",
                )

            tokenized = dataset.map(
                tokenize_function,
                batched=True,
                remove_columns=dataset.column_names,
                desc="Tokenizing",
            )

            # Get MLP layer names
            layer_paths = [
                self._get_mlp_layer_name(layer_idx)
                for layer_idx in self.config.mlp_layers
            ]

            logger.info(f"Extracting from layers: {layer_paths}")

            # Collect activations using nnsight
            self._activations = []
            self._metadata = []

            batch_size = self.config.batch_size
            n_batches = (len(tokenized) + batch_size - 1) // batch_size

            with torch.no_grad():
                for batch_idx in tqdm(range(n_batches), desc="Extracting activations"):
                    start_idx = batch_idx * batch_size
                    end_idx = min(start_idx + batch_size, len(tokenized))

                    batch = tokenized[start_idx:end_idx]
                    input_ids = torch.tensor(batch["input_ids"])
                    attention_mask = torch.tensor(batch["attention_mask"])

                    layer_outputs = []

                    # Use nnsight to extract activations
                    with self.model.trace(input_ids, attention_mask=attention_mask):
                        for layer_path in layer_paths:
                            module = self._resolve_layer_path(layer_path)
                            output = module.output.save()
                            layer_outputs.append(output)

                    # Process outputs
                    batch_input_ids = input_ids.cpu().numpy()
                    for sample_idx in range(len(input_ids)):
                        sample_activations = []
                        for layer_output in layer_outputs:
                            # Saved MLP output has shape (batch, seq_len,
                            # hidden_dim) or (seq_len, hidden_dim)
                            act = layer_output
                            if act.dim() == 3:
                                act = act[sample_idx]  # (seq_len, hidden_dim)

                            # Use the last token's activation (common practice)
                            act = act[-1]  # (hidden_dim,)

                            sample_activations.append(act.cpu().numpy())

                        # Concatenate activations from all layers
                        self._activations.append(np.concatenate(sample_activations))

                        # Store metadata
                        self._metadata.append(
                            {
                                "sample_idx": start_idx + sample_idx,
                                "text": dataset[start_idx + sample_idx][
                                    self.config.dataset.text_column
                                ][:200],
                                "input_ids": batch_input_ids[sample_idx],
                            }
                        )

            activations = np.array(self._activations)
            logger.info(f"Extracted activations shape: {activations.shape}")

            metadata = {
                "config": self.config.model_dump(),
                "n_samples": len(self._activations),
                "activation_dim": activations.shape[-1],
                "layer_paths": layer_paths,
                "samples_metadata": self._metadata,
            }

            return activations, metadata

        except Exception as e:
            logger.error(f"Failed to extract activations: {e}")
            raise RuntimeError(f"Activation extraction failed: {e}") from e

    def save_activations(
        self,
        activations: np.ndarray,
        metadata: Dict[str, Any],
        output_dir: Optional[Union[str, Path]] = None,
    ) -> Path:
        """
        Save extracted activations to disk.

        Writes two files to *output_dir*:

        - ``activations.npy`` — the activation array.
        - ``metadata.pkl`` — the metadata dictionary.

        Args:
            activations: Activation array of shape
                ``(n_samples, dim)``.
            metadata: Metadata dictionary (as returned by
                :meth:`extract`).
            output_dir: Target directory. If ``None``, uses
                ``config.output_dir``.

        Returns:
            Path to the saved ``activations.npy`` file.

        Raises:
            ValueError: If *output_dir* is ``None`` and no output
                directory is configured.
        """
        if output_dir is None:
            output_dir = self.config.output_dir

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        activations_path = output_dir / "activations.npy"
        np.save(str(activations_path), activations)

        metadata_path = output_dir / "metadata.pkl"
        with open(metadata_path, "wb") as f:
            pickle.dump(metadata, f)

        logger.info(f"Saved activations to {activations_path}")
        return activations_path

    def load_activations(
        self,
        filepath: Union[str, Path],
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Load saved activations from disk.

        Supports two formats:

        - ``.npy`` — loads the array directly and looks for a
          companion ``metadata.pkl`` in the same directory.
        - ``.pkl`` (legacy) — loads a dict with ``"activations"``
          and ``"metadata"`` keys.

        Args:
            filepath: Path to an ``.npy`` or ``.pkl`` file.

        Returns:
            A tuple of (activations array, metadata dict).

        Raises:
            FileNotFoundError: If the file (or companion metadata) does
                not exist.
        """
        filepath = Path(filepath)

        if filepath.suffix == ".npy":
            activations = np.load(str(filepath))
            metadata_path = filepath.parent / "metadata.pkl"
            with open(metadata_path, "rb") as f:
                metadata = pickle.load(f)
        else:
            with open(filepath, "rb") as f:
                data = pickle.load(f)
            activations = data["activations"]
            metadata = data["metadata"]

        logger.info(f"Loaded activations from {filepath}")
        return activations, metadata
````
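`_resolve_layer_path` delegates to `resolve_module_path` from `drrik.steering`, which is not shown in this listing. A minimal sketch of such a dotted/bracket path resolver (the real implementation may differ):

```python
import re
from types import SimpleNamespace

def resolve_module_path(root, path):
    """Resolve 'model.layers[2].mlp'-style paths via getattr and indexing."""
    obj = root
    for part in path.split("."):
        match = re.fullmatch(r"(\w+)((?:\[\d+\])*)", part)
        if match is None:
            raise ValueError(f"Malformed path segment: {part!r}")
        obj = getattr(obj, match.group(1))        # attribute access
        for idx in re.findall(r"\[(\d+)\]", match.group(2)):
            obj = obj[int(idx)]                   # bracket indexing
    return obj

# Toy module tree standing in for an nnsight LanguageModel
mlp = SimpleNamespace(name="mlp-2")
tree = SimpleNamespace(model=SimpleNamespace(layers=[None, None, SimpleNamespace(mlp=mlp)]))
print(resolve_module_path(tree, "model.layers[2].mlp").name)  # mlp-2
```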
Extract MLP activations from the model.
Loads the model and dataset (if not already loaded), tokenizes the data, and runs batched forward passes using nnsight's tracing context. For each sample, the last-token activation from each requested MLP layer is collected and concatenated.
Arguments:
- num_samples: Override for the number of samples to process. If None, uses config.dataset.num_samples.
Returns:
A tuple of:
- activations (numpy.ndarray) — shape (n_samples, activation_dim * n_layers). When a single layer is requested the shape simplifies to (n_samples, activation_dim).
- metadata (dict) — contains the serialised config, sample count, activation dimension, layer paths, and per-sample metadata (index, truncated text, input ids).
Raises:
- RuntimeError: If extraction fails at any stage.
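As a minimal numpy sketch of the concatenation step described above (hypothetical shapes; the nnsight tracing itself is omitted), each sample's last-token activation from every requested layer is joined into a single vector:

```python
import numpy as np

# Hypothetical sizes: 2 MLP layers, batch of 3 samples,
# sequence length 5, hidden size 4 per layer.
layer_outputs = [np.random.randn(3, 5, 4) for _ in range(2)]

samples = []
for sample_idx in range(3):
    # Take the last token's activation from each layer ...
    per_layer = [out[sample_idx, -1] for out in layer_outputs]  # each (4,)
    # ... and concatenate across layers into one feature vector
    samples.append(np.concatenate(per_layer))                   # (8,)

activations = np.array(samples)
print(activations.shape)  # (3, 8): (n_samples, activation_dim * n_layers)
```

With a single layer the inner list has one element, so the concatenation is a no-op and the per-sample vector is just that layer's hidden dimension.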
    def save_activations(
        self,
        activations: np.ndarray,
        metadata: Dict[str, Any],
        output_dir: Optional[Union[str, Path]] = None,
    ) -> Path:
        """
        Save extracted activations to disk.

        Writes two files to *output_dir*:

        - ``activations.npy`` — the activation array.
        - ``metadata.pkl`` — the metadata dictionary.

        Args:
            activations: Activation array of shape ``(n_samples, dim)``.
            metadata: Metadata dictionary (as returned by :meth:`extract`).
            output_dir: Target directory. If ``None``, uses ``config.output_dir``.

        Returns:
            Path to the saved ``activations.npy`` file.

        Raises:
            ValueError: If *output_dir* is ``None`` and no output
                directory is configured.
        """
        if output_dir is None:
            if self.config.output_dir is None:
                raise ValueError(
                    "output_dir is None and no output directory is configured"
                )
            output_dir = self.config.output_dir

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        activations_path = output_dir / "activations.npy"
        np.save(str(activations_path), activations)

        metadata_path = output_dir / "metadata.pkl"
        with open(metadata_path, "wb") as f:
            pickle.dump(metadata, f)

        logger.info(f"Saved activations to {activations_path}")
        return activations_path
Save extracted activations to disk.
Writes two files to output_dir:
- activations.npy — the activation array.
- metadata.pkl — the metadata dictionary.
Arguments:
- activations: Activation array of shape (n_samples, dim).
- metadata: Metadata dictionary (as returned by extract()).
- output_dir: Target directory. If None, uses config.output_dir.
Returns:
Path to the saved activations.npy file.
Raises:
- ValueError: If output_dir is None and no output directory is configured.
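The on-disk layout is just a numpy array plus a pickled dict, so it can be reproduced (and inspected) without the extractor. A minimal sketch using a temporary directory in place of config.output_dir:

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical data mirroring what save_activations() writes
activations = np.zeros((10, 16), dtype=np.float32)
metadata = {"n_samples": 10, "activation_dim": 16}

output_dir = Path(tempfile.mkdtemp())

# activations.npy — the activation array
np.save(str(output_dir / "activations.npy"), activations)

# metadata.pkl — the metadata dictionary
with open(output_dir / "metadata.pkl", "wb") as f:
    pickle.dump(metadata, f)

loaded = np.load(str(output_dir / "activations.npy"))
print(loaded.shape)  # (10, 16)
```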
    def load_activations(
        self,
        filepath: Union[str, Path],
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Load saved activations from disk.

        Supports two formats:

        - ``.npy`` — loads the array directly and looks for a
          companion ``metadata.pkl`` in the same directory.
        - ``.pkl`` (legacy) — loads a dict with ``"activations"``
          and ``"metadata"`` keys.

        Args:
            filepath: Path to an ``.npy`` or ``.pkl`` file.

        Returns:
            A tuple of (activations array, metadata dict).

        Raises:
            FileNotFoundError: If the file (or companion metadata) does
                not exist.
        """
        filepath = Path(filepath)

        if filepath.suffix == ".npy":
            activations = np.load(str(filepath))
            metadata_path = filepath.parent / "metadata.pkl"
            with open(metadata_path, "rb") as f:
                metadata = pickle.load(f)
        else:
            with open(filepath, "rb") as f:
                data = pickle.load(f)
            activations = data["activations"]
            metadata = data["metadata"]

        logger.info(f"Loaded activations from {filepath}")
        return activations, metadata
Load saved activations from disk.
Supports two formats:
- .npy — loads the array directly and looks for a companion metadata.pkl in the same directory.
- .pkl (legacy) — loads a dict with "activations" and "metadata" keys.
Arguments:
- filepath: Path to an .npy or .pkl file.
Returns:
A tuple of (activations array, metadata dict).
Raises:
- FileNotFoundError: If the file (or companion metadata) does not exist.
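The dual-format dispatch is simple enough to sketch standalone. This is a hedged stand-in (load_activations_sketch is a hypothetical name, not part of the drrik API) showing the branch on file suffix, exercised here with the legacy .pkl format:

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


def load_activations_sketch(filepath):
    """Standalone sketch of the dual-format loading logic."""
    filepath = Path(filepath)
    if filepath.suffix == ".npy":
        # Array on its own, metadata in a companion file
        activations = np.load(str(filepath))
        with open(filepath.parent / "metadata.pkl", "rb") as f:
            metadata = pickle.load(f)
    else:
        # Legacy .pkl bundle: one dict holding both
        with open(filepath, "rb") as f:
            data = pickle.load(f)
        activations, metadata = data["activations"], data["metadata"]
    return activations, metadata


legacy = Path(tempfile.mkdtemp()) / "activations.pkl"
with open(legacy, "wb") as f:
    pickle.dump({"activations": np.ones((4, 8)), "metadata": {"n_samples": 4}}, f)

acts, meta = load_activations_sketch(legacy)
print(acts.shape, meta["n_samples"])  # (4, 8) 4
```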
47class SparseAutoencoder(nn.Module): 48 """ 49 Sparse Autoencoder for extracting interpretable features from MLP activations. 50 51 Architecture: 52 Input (activation_dim) -> [bias] -> Encoder (hidden_dim) -> ReLU -> [bias] 53 -> Decoder (activation_dim) -> Output 54 55 The decoder weights are normalized to unit norm to prevent scaling collapse. 56 L1 regularization on hidden activations encourages sparsity. 57 58 Based on the architecture described in: 59 "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning" 60 https://transformer-circuits.pub/2023/monosemantic-features/index.html 61 62 Attributes: 63 encoder: Linear layer from activation_dim to hidden_dim 64 decoder: Linear layer from hidden_dim to activation_dim 65 pre_encoder_bias: Learnable bias subtracted from input 66 post_decoder_bias: Learnable bias added to output (tied to pre_encoder_bias) 67 68 Example: 69 ```python 70 sae = SparseAutoencoder( 71 activation_dim=2048, 72 hidden_dim=4096, # 2x expansion 73 l1_coefficient=0.01 74 ) 75 sae.fit(activations, num_epochs=100) 76 features = sae.encode(activations) 77 reconstructed = sae.decode(features) 78 ``` 79 """ 80 81 def __init__( 82 self, 83 activation_dim: int, 84 hidden_dim: int, 85 l1_coefficient: float = 0.01, 86 normalize_decoder: bool = True, 87 pre_encoder_bias: bool = True, 88 ): 89 """ 90 Initialize the Sparse Autoencoder. 
91 92 Args: 93 activation_dim: Dimension of input MLP activations 94 hidden_dim: Dimension of sparse hidden layer (overcomplete) 95 l1_coefficient: L1 regularization strength 96 normalize_decoder: Whether to normalize decoder columns to unit norm 97 pre_encoder_bias: Whether to use pre-encoder bias (as in the paper) 98 """ 99 super().__init__() 100 101 self.activation_dim = activation_dim 102 self.hidden_dim = hidden_dim 103 self.l1_coefficient = l1_coefficient 104 self.normalize_decoder = normalize_decoder 105 self.pre_encoder_bias = pre_encoder_bias 106 107 # Encoder: input -> hidden 108 self.encoder = nn.Linear(activation_dim, hidden_dim, bias=False) 109 110 # Decoder: hidden -> output 111 # Initialize with small random weights 112 self.decoder = nn.Linear(hidden_dim, activation_dim, bias=False) 113 114 # Initialize decoder weights to unit norm (Kaiming uniform) 115 nn.init.kaiming_uniform_(self.decoder.weight, a=np.sqrt(5)) 116 with torch.no_grad(): 117 self.decoder.weight.copy_( 118 self.decoder.weight / self.decoder.weight.norm(dim=0, keepdim=True) 119 ) 120 121 # Initialize encoder weights 122 nn.init.kaiming_uniform_(self.encoder.weight, a=np.sqrt(5)) 123 124 # Biases 125 if pre_encoder_bias: 126 # Pre-encoder bias (subtracted from input, added to output) 127 # Initialize to geometric median of data (will be set during fit) 128 self.bias = nn.Parameter(torch.zeros(activation_dim)) 129 else: 130 self.register_parameter("bias", None) 131 132 # Encoder bias 133 self.encoder_bias = nn.Parameter(torch.zeros(hidden_dim)) 134 135 # Training metrics 136 self.training_losses = [] 137 self.training_l0_norms = [] 138 139 def encode(self, x: torch.Tensor) -> torch.Tensor: 140 """ 141 Encode input activations to sparse features. 
142 143 Args: 144 x: Input tensor of shape (batch_size, activation_dim) 145 146 Returns: 147 Sparse feature activations of shape (batch_size, hidden_dim) 148 """ 149 # Subtract bias if using pre-encoder bias 150 if self.bias is not None: 151 x = x - self.bias 152 153 # Apply encoder 154 hidden = F.linear(x, self.encoder.weight, self.encoder_bias) 155 156 # ReLU activation for sparsity 157 features = F.relu(hidden) 158 159 return features 160 161 def decode(self, features: torch.Tensor) -> torch.Tensor: 162 """ 163 Decode sparse features back to activation space. 164 165 Args: 166 features: Sparse feature tensor of shape (batch_size, hidden_dim) 167 168 Returns: 169 Reconstructed activations of shape (batch_size, activation_dim) 170 """ 171 # Apply decoder (weights are normalized to unit norm) 172 output = F.linear(features, self.decoder.weight) 173 174 # Add bias if using 175 if self.bias is not None: 176 output = output + self.bias 177 178 return output 179 180 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 181 """ 182 Forward pass through the autoencoder. 183 184 Args: 185 x: Input tensor of shape (batch_size, activation_dim) 186 187 Returns: 188 Tuple of (reconstructed, features) 189 """ 190 features = self.encode(x) 191 reconstructed = self.decode(features) 192 return reconstructed, features 193 194 def loss( 195 self, x: torch.Tensor, reconstructed: torch.Tensor, features: torch.Tensor 196 ) -> torch.Tensor: 197 """ 198 Compute the loss with L1 regularization. 199 200 Loss = MSE(reconstructed, x) + lambda * L1(features) 201 202 Args: 203 x: Input tensor 204 reconstructed: Reconstructed tensor 205 features: Sparse feature activations 206 207 Returns: 208 Total loss 209 """ 210 mse_loss = F.mse_loss(reconstructed, x) 211 l1_loss = self.l1_coefficient * features.abs().sum(dim=1).mean() 212 return mse_loss + l1_loss 213 214 def normalize_decoder_weights(self) -> None: 215 """Normalize decoder weight columns to unit L2 norm. 
216 217 This prevents scaling collapse where the optimizer could 218 trivially reduce the loss by shrinking decoder weights and 219 scaling up encoder weights. Applied after each optimizer step 220 when ``normalize_decoder`` is ``True``. 221 """ 222 if self.normalize_decoder: 223 with torch.no_grad(): 224 self.decoder.weight.copy_( 225 self.decoder.weight 226 / self.decoder.weight.norm(dim=0, keepdim=True).clamp(min=1e-8) 227 ) 228 229 def resample_dead_neurons( 230 self, 231 activations: torch.Tensor, 232 dead_threshold: float = 1e-8, 233 dead_mask: Optional[torch.Tensor] = None, 234 ) -> int: 235 """ 236 Resample dead neurons that haven't fired recently. 237 238 This follows the procedure from the paper: 239 1. Identify neurons that haven't fired above threshold 240 2. Compute loss on a batch of data 241 3. Resample dead neurons to fit poorly reconstructed examples 242 243 Args: 244 activations: Batch of activations to use for resampling 245 dead_threshold: Activation threshold below which a neuron is considered dead 246 dead_mask: Pre-computed boolean mask of dead neurons. If None, computed 247 from the current batch. When provided, uses a sliding window 248 approach for more robust dead neuron detection. 
249 250 Returns: 251 Number of neurons resampled 252 """ 253 with torch.no_grad(): 254 # Get feature activations 255 features = self.encode(activations) 256 257 # Find dead neurons 258 if dead_mask is None: 259 neuron_activity = features.abs().sum(dim=0) 260 dead_mask = neuron_activity < dead_threshold 261 n_dead = dead_mask.sum().item() 262 263 if n_dead == 0: 264 return 0 265 266 logger.info(f"Resampling {n_dead} dead neurons") 267 268 # Compute reconstruction loss for each sample 269 reconstructed, _ = self.forward(activations) 270 sample_losses = F.mse_loss( 271 reconstructed, activations, reduction="none" 272 ).sum(dim=1) 273 274 # Sample based on loss (higher loss = more likely to be resampled) 275 probs = sample_losses**2 276 probs = probs / probs.sum() 277 278 # For each dead neuron, resample from a high-loss example 279 for dead_idx in torch.where(dead_mask)[0]: 280 # Sample an example 281 sample_idx = torch.multinomial(probs, 1).item() 282 example = activations[sample_idx] 283 284 # Normalize example 285 example_normalized = example / (example.norm() + 1e-8) 286 287 # Set decoder weight 288 self.decoder.weight[:, dead_idx] = example_normalized 289 290 # Set encoder weight (smaller scale to prevent immediate firing) 291 avg_encoder_norm = ( 292 self.encoder.weight[~dead_mask].norm(dim=1).mean().item() 293 if (~dead_mask).any() > 0 294 else 0.1 295 ) 296 self.encoder.weight[dead_idx, :] = ( 297 example_normalized * avg_encoder_norm * 0.2 298 ) 299 300 # Reset encoder bias 301 self.encoder_bias[dead_idx] = 0 302 303 # Re-normalize decoder 304 self.normalize_decoder_weights() 305 306 return n_dead 307 308 def fit( 309 self, 310 activations: np.ndarray, 311 batch_size: int = 256, 312 num_epochs: int = 100, 313 learning_rate: float = 1e-4, 314 validation_split: float = 0.1, 315 resample_dead_neurons: bool = True, 316 resample_interval: int = 10000, 317 dead_threshold: float = 1e-8, 318 window_size: int = 100, 319 device: Optional[str] = None, 320 verbose: bool 
= True, 321 wandb_config: Optional[WandbConfig] = None, 322 wandb_enabled: bool = True, 323 ) -> "SparseAutoencoder": 324 """ 325 Train the sparse autoencoder on MLP activations. 326 327 Args: 328 activations: Training data of shape (n_samples, activation_dim) 329 batch_size: Batch size for training 330 num_epochs: Number of training epochs 331 learning_rate: Learning rate for Adam optimizer 332 validation_split: Fraction of data to use for validation 333 resample_dead_neurons: Whether to resample dead neurons during training 334 resample_interval: Steps between resampling checks 335 dead_threshold: Activation threshold below which a neuron is considered dead 336 window_size: Number of recent batches to track for dead neuron detection 337 device: Device to train on (cuda/cpu). If None, auto-detect 338 verbose: Whether to show progress bars 339 wandb_config: Optional WandbConfig for experiment tracking 340 wandb_enabled: If True and wandb_config is None, create default WandbConfig 341 342 Returns: 343 self (trained model) 344 """ 345 # Setup wandb if enabled 346 wandb_logger = None 347 _owns_wandb = False 348 if wandb_enabled: 349 if wandb_config is None: 350 wandb_logger = WandbConfig( 351 config={ 352 "activation_dim": self.activation_dim, 353 "hidden_dim": self.hidden_dim, 354 "expansion_factor": self.hidden_dim / self.activation_dim, 355 "l1_coefficient": self.l1_coefficient, 356 "batch_size": batch_size, 357 "learning_rate": learning_rate, 358 "num_epochs": num_epochs, 359 } 360 ) 361 wandb_logger.initialize() 362 _owns_wandb = True 363 elif not wandb_config._initialized: 364 wandb_logger = wandb_config 365 wandb_logger.initialize() 366 _owns_wandb = True 367 else: 368 wandb_logger = wandb_config 369 370 # Determine device 371 if device is None: 372 device = "cuda" if torch.cuda.is_available() else "cpu" 373 374 self.to(device) 375 376 # Convert to tensor 377 if isinstance(activations, np.ndarray): 378 activations = torch.from_numpy(activations).float() 379 380 
n_samples = len(activations) 381 n_val = int(n_samples * validation_split) 382 n_train = n_samples - n_val 383 384 # Split data 385 train_data = activations[:n_train] 386 val_data = activations[n_train:] 387 388 # Create data loaders 389 train_dataset = TensorDataset(train_data) 390 train_loader = DataLoader( 391 train_dataset, 392 batch_size=batch_size, 393 shuffle=True, 394 ) 395 396 # Initialize bias to median if using 397 if self.bias is not None and self.pre_encoder_bias: 398 with torch.no_grad(): 399 # Approximate geometric median using median for simplicity 400 median = train_data.to(device).median(dim=0).values 401 self.bias.data = median 402 403 # Optimizer 404 optimizer = torch.optim.Adam( 405 self.parameters(), 406 lr=learning_rate, 407 ) 408 409 # Training loop 410 self.training_losses = [] 411 self.training_l0_norms = [] 412 413 n_steps_since_resample = 0 414 global_step = 0 415 416 # Sliding window for tracking neuron activity across batches 417 neuron_activity_window = [] 418 419 if verbose: 420 pbar = tqdm(total=num_epochs, desc="Training SAE") 421 else: 422 pbar = None 423 424 for epoch in range(num_epochs): 425 epoch_loss = 0.0 426 epoch_l0 = 0.0 427 428 for batch in train_loader: 429 batch = batch[0].to(device) 430 431 # Forward pass 432 reconstructed, features = self.forward(batch) 433 434 # Track neuron activity for sliding window dead detection 435 neuron_activity = features.abs().sum(dim=0) 436 neuron_activity_window.append(neuron_activity) 437 if len(neuron_activity_window) > window_size: 438 neuron_activity_window.pop(0) 439 440 # Compute loss 441 loss = self.loss(batch, reconstructed, features) 442 443 # Backward pass 444 optimizer.zero_grad() 445 loss.backward() 446 447 # Remove gradients parallel to decoder weights (Adam optimization trick from paper) 448 if self.normalize_decoder: 449 with torch.no_grad(): 450 # Project gradients to be orthogonal to decoder weights 451 decoder_w = self.decoder.weight 452 decoder_grad = 
self.decoder.weight.grad 453 454 # Compute projection 455 projection = (decoder_grad * decoder_w).sum( 456 dim=0, keepdim=True 457 ) * decoder_w 458 self.decoder.weight.grad = decoder_grad - projection 459 460 optimizer.step() 461 462 # Normalize decoder weights 463 self.normalize_decoder_weights() 464 465 # Track metrics 466 epoch_loss += loss.item() 467 epoch_l0 += (features > 0).sum(dim=1).float().mean().item() 468 469 n_steps_since_resample += 1 470 global_step += 1 471 472 # Resample dead neurons using sliding window 473 if ( 474 resample_dead_neurons 475 and n_steps_since_resample >= resample_interval 476 ): 477 if len(neuron_activity_window) > 0: 478 avg_activity = torch.stack(neuron_activity_window).mean(dim=0) 479 window_dead_mask = avg_activity < dead_threshold 480 else: 481 window_dead_mask = None 482 483 n_resampled = self.resample_dead_neurons( 484 batch, dead_threshold=dead_threshold, dead_mask=window_dead_mask 485 ) 486 if n_resampled > 0: 487 neuron_activity_window.clear() 488 n_steps_since_resample = 0 489 490 # Average metrics 491 avg_loss = epoch_loss / len(train_loader) 492 avg_l0 = epoch_l0 / len(train_loader) 493 494 self.training_losses.append(avg_loss) 495 self.training_l0_norms.append(avg_l0) 496 497 # Validation 498 with torch.no_grad(): 499 val_data_tensor = val_data.to(device) 500 val_reconstructed, val_features = self.forward(val_data_tensor) 501 val_loss = F.mse_loss(val_reconstructed, val_data_tensor).item() 502 val_l0 = (val_features > 0).sum(dim=1).float().mean().item() 503 504 # Log to wandb 505 if wandb_logger: 506 wandb_logger.log_metrics( 507 { 508 "epoch": epoch + 1, 509 "train_loss": avg_loss, 510 "train_l0_norm": avg_l0, 511 "val_loss": val_loss, 512 "val_l0_norm": val_l0, 513 "learning_rate": learning_rate, 514 }, 515 step=epoch + 1, 516 ) 517 518 if verbose: 519 if pbar: 520 pbar.update(1) 521 pbar.set_postfix( 522 { 523 "loss": f"{avg_loss:.6f}", 524 "val_loss": f"{val_loss:.6f}", 525 "L0": f"{avg_l0:.2f}", 526 "val_L0": 
f"{val_l0:.2f}", 527 } 528 ) 529 else: 530 logger.info( 531 f"Epoch {epoch + 1}/{num_epochs}: " 532 f"loss={avg_loss:.6f}, val_loss={val_loss:.6f}, " 533 f"L0={avg_l0:.2f}, val_L0={val_l0:.2f}" 534 ) 535 536 if pbar: 537 pbar.close() 538 539 # Finalize wandb and log summary metrics 540 if wandb_logger: 541 # Log final feature densities 542 densities = self.get_feature_density(activations.cpu().numpy()) 543 n_dead = (densities == 0).sum() 544 n_active = (densities > 0).sum() 545 546 wandb_logger.log_metrics( 547 { 548 "final_train_loss": self.training_losses[-1], 549 "final_val_loss": val_loss, 550 "final_l0_norm": self.training_l0_norms[-1], 551 "final_val_l0_norm": val_l0, 552 "n_dead_features": int(n_dead), 553 "n_active_features": int(n_active), 554 "feature_sparsity": float(n_active / len(densities)), 555 } 556 ) 557 558 # Log feature density histogram 559 wandb_logger.log_histogram(densities, "feature_density_histogram") 560 561 logger.info(f"wandb run URL: {wandb_logger.get_run_url()}") 562 563 # Finalize wandb (close the run) only if we created it 564 if _owns_wandb: 565 wandb_logger.finalize() 566 567 logger.info("Training complete!") 568 return self 569 570 def get_feature_density(self, activations: np.ndarray) -> np.ndarray: 571 """ 572 Compute the density (fraction of non-zero activations) for each feature. 
573 574 Args: 575 activations: Data to compute density on 576 577 Returns: 578 Array of feature densities of shape (hidden_dim,) 579 """ 580 self.eval() 581 device = next(self.parameters()).device 582 with torch.no_grad(): 583 if isinstance(activations, np.ndarray): 584 activations = torch.from_numpy(activations).float().to(device) 585 else: 586 activations = activations.to(device) 587 588 features = self.encode(activations) 589 density = (features > 0).float().mean(dim=0).cpu().numpy() 590 591 return density 592 593 def get_top_activating_examples( 594 self, 595 activations: np.ndarray, 596 feature_idx: int, 597 k: int = 10, 598 ) -> Tuple[np.ndarray, np.ndarray]: 599 """ 600 Get the top k examples that activate a given feature the most. 601 602 Args: 603 activations: Data to search through 604 feature_idx: Index of the feature to analyze 605 k: Number of top examples to return 606 607 Returns: 608 Tuple of (top activations values, top indices) 609 """ 610 self.eval() 611 device = next(self.parameters()).device 612 with torch.no_grad(): 613 if isinstance(activations, np.ndarray): 614 activations = torch.from_numpy(activations).float().to(device) 615 else: 616 activations = activations.to(device) 617 618 features = self.encode(activations) 619 feature_activations = features[:, feature_idx].cpu().numpy() 620 621 top_k_indices = np.argsort(feature_activations)[-k:][::-1] 622 top_k_values = feature_activations[top_k_indices] 623 624 return top_k_values, top_k_indices 625 626 def save(self, filepath: Union[str, Path]) -> None: 627 """Save model state, hyperparameters, and training history to disk. 628 629 Saves a dictionary containing the model architecture parameters, 630 state dict, and training loss/L0 history as a PyTorch ``.pt`` file. 631 The file can be loaded with ``SparseAutoencoder.load()``. 632 633 Args: 634 filepath: Path where the model file will be saved. Parent 635 directories are created automatically if they don't exist. 
636 """ 637 filepath = Path(filepath) 638 filepath.parent.mkdir(parents=True, exist_ok=True) 639 640 state = { 641 "activation_dim": self.activation_dim, 642 "hidden_dim": self.hidden_dim, 643 "l1_coefficient": self.l1_coefficient, 644 "state_dict": self.state_dict(), 645 "training_losses": self.training_losses, 646 "training_l0_norms": self.training_l0_norms, 647 } 648 649 torch.save(state, filepath) 650 logger.info(f"Saved model to {filepath}") 651 652 @classmethod 653 def load(cls, filepath: Union[str, Path]) -> "SparseAutoencoder": 654 """Load a saved SparseAutoencoder from disk. 655 656 Reconstructs the model from a ``.pt`` file saved by ``save()``, 657 restoring the architecture, learned weights, and training 658 history (losses and L0 norms). 659 660 Args: 661 filepath: Path to the saved ``.pt`` model file. 662 663 Returns: 664 A ``SparseAutoencoder`` instance with restored weights and 665 training metrics. 666 """ 667 filepath = Path(filepath) 668 state = torch.load(filepath, weights_only=False) 669 670 model = cls( 671 activation_dim=state["activation_dim"], 672 hidden_dim=state["hidden_dim"], 673 l1_coefficient=state["l1_coefficient"], 674 ) 675 model.load_state_dict(state["state_dict"]) 676 model.training_losses = state.get("training_losses", []) 677 model.training_l0_norms = state.get("training_l0_norms", []) 678 679 logger.info(f"Loaded model from {filepath}") 680 return model
Sparse Autoencoder for extracting interpretable features from MLP activations.
Architecture:
Input (activation_dim) -> [bias] -> Encoder (hidden_dim) -> ReLU -> [bias] -> Decoder (activation_dim) -> Output
The decoder weights are normalized to unit norm to prevent scaling collapse. L1 regularization on hidden activations encourages sparsity.
Based on the architecture described in: "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning" https://transformer-circuits.pub/2023/monosemantic-features/index.html
Attributes:
- encoder: Linear layer from activation_dim to hidden_dim
- decoder: Linear layer from hidden_dim to activation_dim
- bias: Learnable pre-encoder bias, subtracted from the input before encoding and added back to the decoder output (tied; present when pre_encoder_bias=True)
- encoder_bias: Learnable bias applied to the encoder pre-activation, before the ReLU
Example:
sae = SparseAutoencoder(
    activation_dim=2048,
    hidden_dim=4096,  # 2x expansion
    l1_coefficient=0.01
)
sae.fit(activations, num_epochs=100)
features = sae.encode(activations)
reconstructed = sae.decode(features)
    def __init__(
        self,
        activation_dim: int,
        hidden_dim: int,
        l1_coefficient: float = 0.01,
        normalize_decoder: bool = True,
        pre_encoder_bias: bool = True,
    ):
        """
        Initialize the Sparse Autoencoder.

        Args:
            activation_dim: Dimension of input MLP activations
            hidden_dim: Dimension of sparse hidden layer (overcomplete)
            l1_coefficient: L1 regularization strength
            normalize_decoder: Whether to normalize decoder columns to unit norm
            pre_encoder_bias: Whether to use pre-encoder bias (as in the paper)
        """
        super().__init__()

        self.activation_dim = activation_dim
        self.hidden_dim = hidden_dim
        self.l1_coefficient = l1_coefficient
        self.normalize_decoder = normalize_decoder
        self.pre_encoder_bias = pre_encoder_bias

        # Encoder: input -> hidden
        self.encoder = nn.Linear(activation_dim, hidden_dim, bias=False)

        # Decoder: hidden -> output
        self.decoder = nn.Linear(hidden_dim, activation_dim, bias=False)

        # Initialize decoder weights (Kaiming uniform), then normalize columns to unit norm
        nn.init.kaiming_uniform_(self.decoder.weight, a=np.sqrt(5))
        with torch.no_grad():
            self.decoder.weight.copy_(
                self.decoder.weight / self.decoder.weight.norm(dim=0, keepdim=True)
            )

        # Initialize encoder weights
        nn.init.kaiming_uniform_(self.encoder.weight, a=np.sqrt(5))

        # Biases
        if pre_encoder_bias:
            # Pre-encoder bias (subtracted from input, added to output)
            # Initialized to the (approximate) geometric median of the data during fit()
            self.bias = nn.Parameter(torch.zeros(activation_dim))
        else:
            self.register_parameter("bias", None)

        # Encoder bias
        self.encoder_bias = nn.Parameter(torch.zeros(hidden_dim))

        # Training metrics
        self.training_losses = []
        self.training_l0_norms = []
Initialize the Sparse Autoencoder.
Arguments:
- activation_dim: Dimension of input MLP activations
- hidden_dim: Dimension of sparse hidden layer (overcomplete)
- l1_coefficient: L1 regularization strength
- normalize_decoder: Whether to normalize decoder columns to unit norm
- pre_encoder_bias: Whether to use pre-encoder bias (as in the paper)
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode input activations to sparse features.

        Args:
            x: Input tensor of shape (batch_size, activation_dim)

        Returns:
            Sparse feature activations of shape (batch_size, hidden_dim)
        """
        # Subtract bias if using pre-encoder bias
        if self.bias is not None:
            x = x - self.bias

        # Apply encoder
        hidden = F.linear(x, self.encoder.weight, self.encoder_bias)

        # ReLU activation for sparsity
        features = F.relu(hidden)

        return features
Encode input activations to sparse features.
Arguments:
- x: Input tensor of shape (batch_size, activation_dim)
Returns:
Sparse feature activations of shape (batch_size, hidden_dim)
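The encode math can be written out in a few lines of numpy (a sketch with hypothetical sizes, standing in for the torch implementation): subtract the pre-encoder bias, apply the linear map with the encoder bias, then ReLU — which is what makes the feature vector sparse and non-negative.

```python
import numpy as np

rng = np.random.default_rng(1)
batch, activation_dim, hidden_dim = 8, 16, 64  # hypothetical, 4x overcomplete

x = rng.standard_normal((batch, activation_dim))
W_enc = rng.standard_normal((hidden_dim, activation_dim)) * 0.1
b_pre = np.zeros(activation_dim)  # pre-encoder bias
b_enc = np.zeros(hidden_dim)      # encoder bias

# encode: subtract pre-encoder bias, linear map, ReLU
features = np.maximum((x - b_pre) @ W_enc.T + b_enc, 0.0)

print(features.shape)         # (8, 64)
print((features >= 0).all())  # True: ReLU guarantees non-negativity
```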
    def decode(self, features: torch.Tensor) -> torch.Tensor:
        """
        Decode sparse features back to activation space.

        Args:
            features: Sparse feature tensor of shape (batch_size, hidden_dim)

        Returns:
            Reconstructed activations of shape (batch_size, activation_dim)
        """
        # Apply decoder (weights are normalized to unit norm)
        output = F.linear(features, self.decoder.weight)

        # Add bias back if using pre-encoder bias
        if self.bias is not None:
            output = output + self.bias

        return output
Decode sparse features back to activation space.
Arguments:
- features: Sparse feature tensor of shape (batch_size, hidden_dim)
Returns:
Reconstructed activations of shape (batch_size, activation_dim)
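Decoding is the mirror image: a linear map through the unit-norm decoder columns, followed by adding the shared bias back. A numpy sketch with hypothetical sizes:

```python
import numpy as np

rng = np.random.default_rng(3)
batch, activation_dim, hidden_dim = 8, 16, 64  # hypothetical sizes

features = np.maximum(rng.standard_normal((batch, hidden_dim)), 0.0)
W_dec = rng.standard_normal((activation_dim, hidden_dim))
W_dec /= np.linalg.norm(W_dec, axis=0, keepdims=True)  # unit-norm columns
b = rng.standard_normal(activation_dim)                # shared pre/post bias

# decode: linear map back to activation space, then add the bias back
reconstructed = features @ W_dec.T + b

print(reconstructed.shape)  # (8, 16)
```

Because the bias is tied, encode and decode together reconstruct x in the bias-centered coordinate frame.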
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the autoencoder.

        Args:
            x: Input tensor of shape (batch_size, activation_dim)

        Returns:
            Tuple of (reconstructed, features)
        """
        features = self.encode(x)
        reconstructed = self.decode(features)
        return reconstructed, features
Forward pass through the autoencoder.
Arguments:
- x: Input tensor of shape (batch_size, activation_dim)
Returns:
Tuple of (reconstructed, features)
    def loss(
        self, x: torch.Tensor, reconstructed: torch.Tensor, features: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the loss with L1 regularization.

        Loss = MSE(reconstructed, x) + lambda * L1(features)

        Args:
            x: Input tensor
            reconstructed: Reconstructed tensor
            features: Sparse feature activations

        Returns:
            Total loss
        """
        mse_loss = F.mse_loss(reconstructed, x)
        l1_loss = self.l1_coefficient * features.abs().sum(dim=1).mean()
        return mse_loss + l1_loss
Compute the loss with L1 regularization.
Loss = MSE(reconstructed, x) + lambda * L1(features)
Arguments:
- x: Input tensor
- reconstructed: Reconstructed tensor
- features: Sparse feature activations
Returns:
Total loss
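A tiny worked example of the loss formula (numpy stand-in, toy values): the MSE term is averaged over all elements, while the L1 term sums each row's feature magnitudes and averages over the batch before scaling by the coefficient.

```python
import numpy as np

l1_coefficient = 0.01
x = np.array([[1.0, 2.0], [3.0, 4.0]])
reconstructed = np.array([[1.0, 2.0], [3.0, 3.0]])  # one element off by 1
features = np.array([[0.0, 2.0, 0.0], [1.0, 0.0, 3.0]])

# MSE over all elements: one squared error of 1 across 4 elements = 0.25
mse = ((reconstructed - x) ** 2).mean()

# L1: per-row sums are 2 and 4, batch mean is 3, scaled by 0.01 = 0.03
l1 = l1_coefficient * np.abs(features).sum(axis=1).mean()

loss = mse + l1
print(loss)  # 0.28 (up to float rounding)
```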
    def normalize_decoder_weights(self) -> None:
        """Normalize decoder weight columns to unit L2 norm.

        This prevents scaling collapse, where the optimizer could
        trivially reduce the loss by shrinking decoder weights and
        scaling up encoder weights. Applied after each optimizer step
        when ``normalize_decoder`` is ``True``.
        """
        if self.normalize_decoder:
            with torch.no_grad():
                self.decoder.weight.copy_(
                    self.decoder.weight
                    / self.decoder.weight.norm(dim=0, keepdim=True).clamp(min=1e-8)
                )
Normalize decoder weight columns to unit L2 norm.
This prevents scaling collapse where the optimizer could
trivially reduce the loss by shrinking decoder weights and
scaling up encoder weights. Applied after each optimizer step
when normalize_decoder is True.
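The operation itself is just a column-wise renormalization, sketched here in numpy (hypothetical sizes; the torch version above additionally clamps the norms to avoid division by zero, reproduced here with np.maximum):

```python
import numpy as np

rng = np.random.default_rng(2)
decoder = rng.standard_normal((16, 64)) * 5.0  # deliberately badly scaled

# Column-wise renormalization with the same epsilon clamp
norms = np.maximum(np.linalg.norm(decoder, axis=0, keepdims=True), 1e-8)
decoder = decoder / norms

print(np.allclose(np.linalg.norm(decoder, axis=0), 1.0))  # True
```

Since every dictionary direction now has unit length, the magnitude of a feature's contribution is carried entirely by its (L1-penalized) activation, which is what makes the sparsity penalty meaningful.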
    def resample_dead_neurons(
        self,
        activations: torch.Tensor,
        dead_threshold: float = 1e-8,
        dead_mask: Optional[torch.Tensor] = None,
    ) -> int:
        """
        Resample dead neurons that haven't fired recently.

        This follows the procedure from the paper:
        1. Identify neurons that haven't fired above threshold
        2. Compute loss on a batch of data
        3. Resample dead neurons to fit poorly reconstructed examples

        Args:
            activations: Batch of activations to use for resampling
            dead_threshold: Activation threshold below which a neuron is considered dead
            dead_mask: Pre-computed boolean mask of dead neurons. If None, computed
                from the current batch. Passing a mask computed over a sliding
                window of batches gives more robust dead neuron detection.

        Returns:
            Number of neurons resampled
        """
        with torch.no_grad():
            # Get feature activations
            features = self.encode(activations)

            # Find dead neurons
            if dead_mask is None:
                neuron_activity = features.abs().sum(dim=0)
                dead_mask = neuron_activity < dead_threshold
            n_dead = dead_mask.sum().item()

            if n_dead == 0:
                return 0

            logger.info(f"Resampling {n_dead} dead neurons")

            # Compute reconstruction loss for each sample
            reconstructed, _ = self.forward(activations)
            sample_losses = F.mse_loss(
                reconstructed, activations, reduction="none"
            ).sum(dim=1)

            # Sample based on loss (higher loss = more likely to be resampled)
            probs = sample_losses**2
            probs = probs / probs.sum()

            # For each dead neuron, resample from a high-loss example
            for dead_idx in torch.where(dead_mask)[0]:
                # Sample an example
                sample_idx = torch.multinomial(probs, 1).item()
                example = activations[sample_idx]

                # Normalize example
                example_normalized = example / (example.norm() + 1e-8)

                # Set decoder weight
                self.decoder.weight[:, dead_idx] = example_normalized

                # Set encoder weight (smaller scale to prevent immediate firing)
                avg_encoder_norm = (
                    self.encoder.weight[~dead_mask].norm(dim=1).mean().item()
                    if (~dead_mask).any()
                    else 0.1
                )
                self.encoder.weight[dead_idx, :] = (
                    example_normalized * avg_encoder_norm * 0.2
                )

                # Reset encoder bias
                self.encoder_bias[dead_idx] = 0

            # Re-normalize decoder
            self.normalize_decoder_weights()

        return n_dead
Resample dead neurons that haven't fired recently.
This follows the procedure from the paper:
- Identify neurons that haven't fired above threshold
- Compute loss on a batch of data
- Resample dead neurons to fit poorly reconstructed examples
Arguments:
- activations: Batch of activations to use for resampling
- dead_threshold: Activation threshold below which a neuron is considered dead
- dead_mask: Pre-computed boolean mask of dead neurons. If None, computed from the current batch. When provided, uses a sliding window approach for more robust dead neuron detection.
Returns:
Number of neurons resampled
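The loss-weighted sampling step above can be sketched in isolation. A minimal numpy illustration (the loss values are made up for the example), showing how squaring per-sample reconstruction losses and normalizing them yields the distribution from which dead neurons draw their replacement directions:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical per-sample reconstruction losses for a batch of 6 activations.
sample_losses = np.array([0.1, 0.05, 2.0, 0.2, 1.5, 0.02])

# As in resample_dead_neurons: square the losses and normalize, so poorly
# reconstructed examples are much more likely to seed a dead neuron.
probs = sample_losses ** 2
probs = probs / probs.sum()

# Draw one seed example per dead neuron (here, two dead neurons).
seed_indices = rng.choice(len(probs), size=2, p=probs)
```

Sample 2 (loss 2.0) dominates the distribution, which is the point: dead features get re-initialized toward whatever the SAE currently reconstructs worst.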
```python
def fit(
    self,
    activations: np.ndarray,
    batch_size: int = 256,
    num_epochs: int = 100,
    learning_rate: float = 1e-4,
    validation_split: float = 0.1,
    resample_dead_neurons: bool = True,
    resample_interval: int = 10000,
    dead_threshold: float = 1e-8,
    window_size: int = 100,
    device: Optional[str] = None,
    verbose: bool = True,
    wandb_config: Optional[WandbConfig] = None,
    wandb_enabled: bool = True,
) -> "SparseAutoencoder":
    """
    Train the sparse autoencoder on MLP activations.

    Args:
        activations: Training data of shape (n_samples, activation_dim)
        batch_size: Batch size for training
        num_epochs: Number of training epochs
        learning_rate: Learning rate for Adam optimizer
        validation_split: Fraction of data to use for validation
        resample_dead_neurons: Whether to resample dead neurons during training
        resample_interval: Steps between resampling checks
        dead_threshold: Activation threshold below which a neuron is considered dead
        window_size: Number of recent batches to track for dead neuron detection
        device: Device to train on (cuda/cpu). If None, auto-detect
        verbose: Whether to show progress bars
        wandb_config: Optional WandbConfig for experiment tracking
        wandb_enabled: If True and wandb_config is None, create default WandbConfig

    Returns:
        self (trained model)
    """
    # Setup wandb if enabled
    wandb_logger = None
    _owns_wandb = False
    if wandb_enabled:
        if wandb_config is None:
            wandb_logger = WandbConfig(
                config={
                    "activation_dim": self.activation_dim,
                    "hidden_dim": self.hidden_dim,
                    "expansion_factor": self.hidden_dim / self.activation_dim,
                    "l1_coefficient": self.l1_coefficient,
                    "batch_size": batch_size,
                    "learning_rate": learning_rate,
                    "num_epochs": num_epochs,
                }
            )
            wandb_logger.initialize()
            _owns_wandb = True
        elif not wandb_config._initialized:
            wandb_logger = wandb_config
            wandb_logger.initialize()
            _owns_wandb = True
        else:
            wandb_logger = wandb_config

    # Determine device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    self.to(device)

    # Convert to tensor
    if isinstance(activations, np.ndarray):
        activations = torch.from_numpy(activations).float()

    n_samples = len(activations)
    n_val = int(n_samples * validation_split)
    n_train = n_samples - n_val

    # Split data
    train_data = activations[:n_train]
    val_data = activations[n_train:]

    # Create data loaders
    train_dataset = TensorDataset(train_data)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
    )

    # Initialize bias to median if using
    if self.bias is not None and self.pre_encoder_bias:
        with torch.no_grad():
            # Approximate geometric median using median for simplicity
            median = train_data.to(device).median(dim=0).values
            self.bias.data = median

    # Optimizer
    optimizer = torch.optim.Adam(
        self.parameters(),
        lr=learning_rate,
    )

    # Training loop
    self.training_losses = []
    self.training_l0_norms = []

    n_steps_since_resample = 0
    global_step = 0

    # Sliding window for tracking neuron activity across batches
    neuron_activity_window = []

    if verbose:
        pbar = tqdm(total=num_epochs, desc="Training SAE")
    else:
        pbar = None

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        epoch_l0 = 0.0

        for batch in train_loader:
            batch = batch[0].to(device)

            # Forward pass
            reconstructed, features = self.forward(batch)

            # Track neuron activity for sliding window dead detection
            neuron_activity = features.abs().sum(dim=0)
            neuron_activity_window.append(neuron_activity)
            if len(neuron_activity_window) > window_size:
                neuron_activity_window.pop(0)

            # Compute loss
            loss = self.loss(batch, reconstructed, features)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Remove gradients parallel to decoder weights
            # (Adam optimization trick from paper)
            if self.normalize_decoder:
                with torch.no_grad():
                    # Project gradients to be orthogonal to decoder weights
                    decoder_w = self.decoder.weight
                    decoder_grad = self.decoder.weight.grad

                    # Compute projection
                    projection = (decoder_grad * decoder_w).sum(
                        dim=0, keepdim=True
                    ) * decoder_w
                    self.decoder.weight.grad = decoder_grad - projection

            optimizer.step()

            # Normalize decoder weights
            self.normalize_decoder_weights()

            # Track metrics
            epoch_loss += loss.item()
            epoch_l0 += (features > 0).sum(dim=1).float().mean().item()

            n_steps_since_resample += 1
            global_step += 1

            # Resample dead neurons using sliding window
            if (
                resample_dead_neurons
                and n_steps_since_resample >= resample_interval
            ):
                if len(neuron_activity_window) > 0:
                    avg_activity = torch.stack(neuron_activity_window).mean(dim=0)
                    window_dead_mask = avg_activity < dead_threshold
                else:
                    window_dead_mask = None

                n_resampled = self.resample_dead_neurons(
                    batch, dead_threshold=dead_threshold, dead_mask=window_dead_mask
                )
                if n_resampled > 0:
                    neuron_activity_window.clear()
                n_steps_since_resample = 0

        # Average metrics
        avg_loss = epoch_loss / len(train_loader)
        avg_l0 = epoch_l0 / len(train_loader)

        self.training_losses.append(avg_loss)
        self.training_l0_norms.append(avg_l0)

        # Validation
        with torch.no_grad():
            val_data_tensor = val_data.to(device)
            val_reconstructed, val_features = self.forward(val_data_tensor)
            val_loss = F.mse_loss(val_reconstructed, val_data_tensor).item()
            val_l0 = (val_features > 0).sum(dim=1).float().mean().item()

        # Log to wandb
        if wandb_logger:
            wandb_logger.log_metrics(
                {
                    "epoch": epoch + 1,
                    "train_loss": avg_loss,
                    "train_l0_norm": avg_l0,
                    "val_loss": val_loss,
                    "val_l0_norm": val_l0,
                    "learning_rate": learning_rate,
                },
                step=epoch + 1,
            )

        if verbose:
            if pbar:
                pbar.update(1)
                pbar.set_postfix(
                    {
                        "loss": f"{avg_loss:.6f}",
                        "val_loss": f"{val_loss:.6f}",
                        "L0": f"{avg_l0:.2f}",
                        "val_L0": f"{val_l0:.2f}",
                    }
                )
            else:
                logger.info(
                    f"Epoch {epoch + 1}/{num_epochs}: "
                    f"loss={avg_loss:.6f}, val_loss={val_loss:.6f}, "
                    f"L0={avg_l0:.2f}, val_L0={val_l0:.2f}"
                )

    if pbar:
        pbar.close()

    # Finalize wandb and log summary metrics
    if wandb_logger:
        # Log final feature densities
        densities = self.get_feature_density(activations.cpu().numpy())
        n_dead = (densities == 0).sum()
        n_active = (densities > 0).sum()

        wandb_logger.log_metrics(
            {
                "final_train_loss": self.training_losses[-1],
                "final_val_loss": val_loss,
                "final_l0_norm": self.training_l0_norms[-1],
                "final_val_l0_norm": val_l0,
                "n_dead_features": int(n_dead),
                "n_active_features": int(n_active),
                "feature_sparsity": float(n_active / len(densities)),
            }
        )

        # Log feature density histogram
        wandb_logger.log_histogram(densities, "feature_density_histogram")

        logger.info(f"wandb run URL: {wandb_logger.get_run_url()}")

    # Finalize wandb (close the run) only if we created it
    if _owns_wandb:
        wandb_logger.finalize()

    logger.info("Training complete!")
    return self
```
Train the sparse autoencoder on MLP activations.
Arguments:
- activations: Training data of shape (n_samples, activation_dim)
- batch_size: Batch size for training
- num_epochs: Number of training epochs
- learning_rate: Learning rate for Adam optimizer
- validation_split: Fraction of data to use for validation
- resample_dead_neurons: Whether to resample dead neurons during training
- resample_interval: Steps between resampling checks
- dead_threshold: Activation threshold below which a neuron is considered dead
- window_size: Number of recent batches to track for dead neuron detection
- device: Device to train on (cuda/cpu). If None, auto-detect
- verbose: Whether to show progress bars
- wandb_config: Optional WandbConfig for experiment tracking
- wandb_enabled: If True and wandb_config is None, create default WandbConfig
Returns:
self (trained model)
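The sliding-window dead-neuron check that fit performs every resample_interval steps can be sketched on its own. A minimal numpy version with made-up activity values (the real code stacks per-batch torch tensors of summed absolute feature activations):

```python
import numpy as np

window_size = 3
dead_threshold = 1e-8
neuron_activity_window = []

# Hypothetical per-batch summed |activation| for 4 neurons; neuron 3 never fires.
for batch_activity in [
    np.array([0.5, 0.0, 1.2, 0.0]),
    np.array([0.3, 0.9, 0.0, 0.0]),
    np.array([0.0, 0.4, 0.7, 0.0]),
    np.array([0.6, 0.1, 0.2, 0.0]),
]:
    # Append the newest batch and drop the oldest once the window is full.
    neuron_activity_window.append(batch_activity)
    if len(neuron_activity_window) > window_size:
        neuron_activity_window.pop(0)

# A neuron is "dead" if its average activity over the window is below threshold.
avg_activity = np.stack(neuron_activity_window).mean(axis=0)
window_dead_mask = avg_activity < dead_threshold
```

Averaging over a window rather than a single batch keeps a neuron that merely skipped one batch from being resampled.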
```python
def get_feature_density(self, activations: np.ndarray) -> np.ndarray:
    """
    Compute the density (fraction of non-zero activations) for each feature.

    Args:
        activations: Data to compute density on

    Returns:
        Array of feature densities of shape (hidden_dim,)
    """
    self.eval()
    device = next(self.parameters()).device
    with torch.no_grad():
        if isinstance(activations, np.ndarray):
            activations = torch.from_numpy(activations).float().to(device)
        else:
            activations = activations.to(device)

        features = self.encode(activations)
        density = (features > 0).float().mean(dim=0).cpu().numpy()

    return density
```
Compute the density (fraction of non-zero activations) for each feature.
Arguments:
- activations: Data to compute density on
Returns:
Array of feature densities of shape (hidden_dim,)
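As a quick illustration of the density computation (a numpy sketch with toy feature activations, mirroring the `(features > 0).mean(dim=0)` step above):

```python
import numpy as np

# Hypothetical encoded feature activations: 4 samples x 3 features.
features = np.array([
    [0.0, 1.2, 0.0],
    [0.5, 0.0, 0.0],
    [0.0, 0.3, 0.0],
    [0.8, 0.9, 0.0],
])

# Density = fraction of samples on which each feature fires (is > 0).
density = (features > 0).mean(axis=0)
```

Here feature 0 fires on half the samples, feature 1 on three quarters, and feature 2 is dead (density 0).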
```python
def get_top_activating_examples(
    self,
    activations: np.ndarray,
    feature_idx: int,
    k: int = 10,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Get the top k examples that activate a given feature the most.

    Args:
        activations: Data to search through
        feature_idx: Index of the feature to analyze
        k: Number of top examples to return

    Returns:
        Tuple of (top activation values, top indices)
    """
    self.eval()
    device = next(self.parameters()).device
    with torch.no_grad():
        if isinstance(activations, np.ndarray):
            activations = torch.from_numpy(activations).float().to(device)
        else:
            activations = activations.to(device)

        features = self.encode(activations)
        feature_activations = features[:, feature_idx].cpu().numpy()

    top_k_indices = np.argsort(feature_activations)[-k:][::-1]
    top_k_values = feature_activations[top_k_indices]

    return top_k_values, top_k_indices
```
Get the top k examples that activate a given feature the most.
Arguments:
- activations: Data to search through
- feature_idx: Index of the feature to analyze
- k: Number of top examples to return
Returns:
Tuple of (top activation values, top indices)
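The top-k selection boils down to an argsort. A small numpy sketch with toy activation values for one feature:

```python
import numpy as np

# Hypothetical activations of a single feature across 5 examples.
feature_activations = np.array([0.1, 2.5, 0.0, 1.7, 0.3])
k = 3

# argsort is ascending; take the last k indices and reverse for descending order.
top_k_indices = np.argsort(feature_activations)[-k:][::-1]
top_k_values = feature_activations[top_k_indices]
```

This yields indices [1, 3, 4] with values [2.5, 1.7, 0.3], i.e. the examples most strongly activating the feature first.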
```python
def save(self, filepath: Union[str, Path]) -> None:
    """Save model state, hyperparameters, and training history to disk.

    Saves a dictionary containing the model architecture parameters,
    state dict, and training loss/L0 history as a PyTorch ``.pt`` file.
    The file can be loaded with ``SparseAutoencoder.load()``.

    Args:
        filepath: Path where the model file will be saved. Parent
            directories are created automatically if they don't exist.
    """
    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    state = {
        "activation_dim": self.activation_dim,
        "hidden_dim": self.hidden_dim,
        "l1_coefficient": self.l1_coefficient,
        "state_dict": self.state_dict(),
        "training_losses": self.training_losses,
        "training_l0_norms": self.training_l0_norms,
    }

    torch.save(state, filepath)
    logger.info(f"Saved model to {filepath}")
```
Save model state, hyperparameters, and training history to disk.
Saves a dictionary containing the model architecture parameters,
state dict, and training loss/L0 history as a PyTorch .pt file.
The file can be loaded with SparseAutoencoder.load().
Arguments:
- filepath: Path where the model file will be saved. Parent directories are created automatically if they don't exist.
```python
@classmethod
def load(cls, filepath: Union[str, Path]) -> "SparseAutoencoder":
    """Load a saved SparseAutoencoder from disk.

    Reconstructs the model from a ``.pt`` file saved by ``save()``,
    restoring the architecture, learned weights, and training
    history (losses and L0 norms).

    Args:
        filepath: Path to the saved ``.pt`` model file.

    Returns:
        A ``SparseAutoencoder`` instance with restored weights and
        training metrics.
    """
    filepath = Path(filepath)
    state = torch.load(filepath, weights_only=False)

    model = cls(
        activation_dim=state["activation_dim"],
        hidden_dim=state["hidden_dim"],
        l1_coefficient=state["l1_coefficient"],
    )
    model.load_state_dict(state["state_dict"])
    model.training_losses = state.get("training_losses", [])
    model.training_l0_norms = state.get("training_l0_norms", [])

    logger.info(f"Loaded model from {filepath}")
    return model
```
Load a saved SparseAutoencoder from disk.
Reconstructs the model from a .pt file saved by save(),
restoring the architecture, learned weights, and training
history (losses and L0 norms).
Arguments:
- filepath: Path to the saved .pt model file.
Returns:
A SparseAutoencoder instance with restored weights and training metrics.
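A round-trip of the checkpoint format used by save() and load() can be sketched directly with torch.save/torch.load. The dictionary keys below mirror the state written by save(); the dimension and history values are arbitrary stand-ins:

```python
import tempfile
from pathlib import Path

import torch

# Checkpoint layout mirroring SparseAutoencoder.save() (toy values).
state = {
    "activation_dim": 512,
    "hidden_dim": 4096,
    "l1_coefficient": 0.01,
    "state_dict": {"w": torch.zeros(2, 2)},
    "training_losses": [0.5, 0.3],
    "training_l0_norms": [40.0, 25.0],
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "sae.pt"
    torch.save(state, path)
    # weights_only=False because the file stores plain Python metadata
    # (dims, loss history) alongside tensors, matching load().
    restored = torch.load(path, weights_only=False)
```

Because the checkpoint mixes plain Python objects with tensors, loading it requires trusting the file's source, which is why load() passes weights_only=False explicitly.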
````python
class FeatureVisualizer:
    """
    Visualize sparse autoencoder features and activations.

    This class provides methods to create various plots that help understand
    the learned features, similar to the visualizations in the
    "Towards Monosemanticity" paper.

    Example:
        ```python
        visualizer = FeatureVisualizer(sae, activations, metadata)
        visualizer.plot_feature_density()
        visualizer.plot_top_features(n_features=10)
        visualizer.plot_feature_examples(feature_idx=0)
        ```
    """

    def __init__(
        self,
        sae: SparseAutoencoder,
        activations: np.ndarray,
        metadata: Optional[Dict[str, Any]] = None,
        output_dir: Union[str, Path] = "./visualizations",
        style: str = "whitegrid",
        dpi: int = 150,
        wandb_config: Optional[WandbConfig] = None,
        log_to_wandb: bool = True,
    ):
        """
        Initialize the visualizer.

        Args:
            sae: Trained sparse autoencoder
            activations: The MLP activations used to train the SAE
            metadata: Optional metadata dictionary (e.g., from ActivationExtractor)
            output_dir: Directory to save plots
            style: Seaborn style to use
            dpi: DPI for saved figures
            wandb_config: Optional WandbConfig for logging visualizations
            log_to_wandb: If True, log plots to wandb when wandb_config is provided
        """
        self.sae = sae
        self.activations = activations
        self.metadata = metadata or {}
        self.output_dir = Path(output_dir)
        self.dpi = dpi
        self.wandb_config = wandb_config
        self.log_to_wandb = log_to_wandb

        # Set style
        sns.set_style(style)
        sns.set_palette("husl")

        # Create output directory
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Compute feature activations
        self._compute_features()

    def _log_figure_to_wandb(self, fig: plt.Figure, name: str) -> None:
        """
        Log a matplotlib figure to wandb.

        Args:
            fig: The matplotlib figure
            name: Name for the plot in wandb
        """
        if self.wandb_config and self.log_to_wandb:
            try:
                import wandb

                wandb.log({name: wandb.Image(fig)})
                logger.debug(f"Logged {name} to wandb")
            except Exception as e:
                logger.warning(f"Failed to log {name} to wandb: {e}")

    def _compute_features(self) -> None:
        """Compute sparse feature activations from the SAE.

        Runs the encoder forward pass on the stored activations to
        produce a matrix of feature activations. The result is stored
        in ``self.features`` as a numpy array of shape
        ``(n_samples, hidden_dim)`` and is used by all plotting methods.
        """
        import torch

        self.sae.eval()
        device = next(self.sae.parameters()).device
        with torch.no_grad():
            activations_tensor = torch.from_numpy(self.activations).float().to(device)
            self.features = self.sae.encode(activations_tensor).cpu().numpy()

        logger.info(f"Computed features with shape: {self.features.shape}")

    def plot_feature_density(
        self,
        bins: int = 50,
        log_scale: bool = True,
        save_path: Optional[str] = None,
    ) -> plt.Figure:
        """
        Plot the distribution of feature densities.

        Feature density is the fraction of examples on which a feature fires.
        This histogram helps understand the sparsity distribution.

        Args:
            bins: Number of histogram bins
            log_scale: Whether to use log scale on x-axis
            save_path: Optional path to save the figure

        Returns:
            The matplotlib Figure object
        """
        # Compute feature densities
        densities = (self.features > 0).mean(axis=0)

        fig, ax = plt.subplots(figsize=(10, 6))

        # Remove zero-density features (dead neurons)
        active_densities = densities[densities > 0]

        # Plot histogram on log scale
        if log_scale:
            bins_log = np.logspace(-8, 0, bins)
            ax.hist(active_densities, bins=bins_log, edgecolor="black", alpha=0.7)
            ax.set_xscale("log")
        else:
            ax.hist(active_densities, bins=bins, edgecolor="black", alpha=0.7)

        ax.set_xlabel("Feature Density (fraction of examples)", fontsize=12)
        ax.set_ylabel("Number of Features", fontsize=12)
        ax.set_title("Distribution of Feature Densities", fontsize=14)

        # Add statistics
        n_dead = (densities == 0).sum()
        n_active = len(active_densities)

        stats_text = f"Dead features: {n_dead} ({n_dead / len(densities) * 100:.1f}%)\n"
        stats_text += (
            f"Active features: {n_active} ({n_active / len(densities) * 100:.1f}%)\n"
        )
        stats_text += f"Median density: {np.median(active_densities):.2e}"

        ax.text(
            0.95,
            0.95,
            stats_text,
            transform=ax.transAxes,
            verticalalignment="top",
            horizontalalignment="right",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
        )

        plt.tight_layout()

        if save_path or self.output_dir:
            path = save_path or self.output_dir / "feature_density.png"
            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
            logger.info(f"Saved feature density plot to {path}")

        # Log to wandb
        self._log_figure_to_wandb(fig, "feature_density")

        return fig

    def plot_activation_histogram(
        self,
        feature_idx: int,
        bins: int = 100,
        log_y: bool = True,
        save_path: Optional[str] = None,
    ) -> plt.Figure:
        """
        Plot activation histogram for a specific feature.

        Args:
            feature_idx: Index of the feature to visualize
            bins: Number of histogram bins
            log_y: Whether to use log scale on y-axis
            save_path: Optional path to save the figure

        Returns:
            The matplotlib Figure object
        """
        feature_acts = self.features[:, feature_idx]

        fig, ax = plt.subplots(figsize=(10, 6))

        # Plot histogram
        counts, bin_edges, patches = ax.hist(
            feature_acts[feature_acts > 0],  # Only non-zero activations
            bins=bins,
            edgecolor="black",
            alpha=0.7,
        )

        if log_y:
            ax.set_yscale("log")

        ax.set_xlabel("Activation Value", fontsize=12)
        ax.set_ylabel("Count (log scale)" if log_y else "Count", fontsize=12)
        ax.set_title(f"Activation Histogram for Feature {feature_idx}", fontsize=14)

        # Add statistics
        density = (feature_acts > 0).mean()
        max_act = feature_acts.max()
        mean_act = feature_acts[feature_acts > 0].mean() if density > 0 else 0

        stats_text = f"Density: {density:.2e}\n"
        stats_text += f"Max activation: {max_act:.4f}\n"
        stats_text += f"Mean (active): {mean_act:.4f}"

        ax.text(
            0.95,
            0.95,
            stats_text,
            transform=ax.transAxes,
            verticalalignment="top",
            horizontalalignment="right",
            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.5),
        )

        plt.tight_layout()

        if save_path or self.output_dir:
            path = (
                save_path
                or self.output_dir / f"activation_histogram_feature_{feature_idx}.png"
            )
            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
            logger.info(f"Saved activation histogram to {path}")

        return fig

    def plot_training_curves(
        self,
        save_path: Optional[str] = None,
    ) -> plt.Figure:
        """
        Plot training loss and L0 norm curves.

        Args:
            save_path: Optional path to save the figure

        Returns:
            The matplotlib Figure object
        """
        if not self.sae.training_losses:
            logger.warning("No training data available")
            return None

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

        epochs = range(1, len(self.sae.training_losses) + 1)

        # Loss curve
        ax1.plot(epochs, self.sae.training_losses, "b-", linewidth=2)
        ax1.set_xlabel("Epoch", fontsize=12)
        ax1.set_ylabel("Loss", fontsize=12)
        ax1.set_title("Training Loss", fontsize=14)
        ax1.grid(True, alpha=0.3)

        # L0 norm curve
        ax2.plot(epochs, self.sae.training_l0_norms, "r-", linewidth=2)
        ax2.set_xlabel("Epoch", fontsize=12)
        ax2.set_ylabel("L0 Norm (avg # active features)", fontsize=12)
        ax2.set_title("Sparsity (L0 Norm)", fontsize=14)
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path or self.output_dir:
            path = save_path or self.output_dir / "training_curves.png"
            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
            logger.info(f"Saved training curves to {path}")

        return fig

    def plot_top_features(
        self,
        n_features: int = 10,
        by: str = "density",
        save_path: Optional[str] = None,
    ) -> plt.Figure:
        """
        Plot top N features by various metrics.

        Args:
            n_features: Number of top features to show
            by: Metric to sort by ("density", "max_activation", "mean_activation")
            save_path: Optional path to save the figure

        Returns:
            The matplotlib Figure object
        """
        fig, ax = plt.subplots(figsize=(12, 6))

        # Compute metric for each feature
        if by == "density":
            values = (self.features > 0).mean(axis=0)
            ylabel = "Feature Density"
            title = f"Top {n_features} Features by Density"
        elif by == "max_activation":
            values = self.features.max(axis=0)
            ylabel = "Max Activation"
            title = f"Top {n_features} Features by Max Activation"
        elif by == "mean_activation":
            values = self.features.mean(axis=0)
            ylabel = "Mean Activation"
            title = f"Top {n_features} Features by Mean Activation"
        else:
            raise ValueError(f"Unknown metric: {by}")

        # Get top features
        top_indices = np.argsort(values)[-n_features:][::-1]
        top_values = values[top_indices]

        # Plot bar chart
        _ = ax.barh(range(n_features), top_values[::-1])
        ax.set_yticks(range(n_features))
        ax.set_yticklabels([f"Feature {i}" for i in top_indices[::-1]])
        ax.set_xlabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14)
        ax.grid(True, axis="x", alpha=0.3)

        plt.tight_layout()

        if save_path or self.output_dir:
            path = save_path or self.output_dir / f"top_features_{by}.png"
            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
            logger.info(f"Saved top features plot to {path}")

        return fig

    def plot_feature_examples(
        self,
        feature_idx: int,
        k: int = 10,
        show_text: bool = True,
        max_text_length: int = 200,
        save_path: Optional[str] = None,
    ) -> plt.Figure:
        """
        Plot top k activating examples for a specific feature.

        Args:
            feature_idx: Index of the feature to visualize
            k: Number of examples to show
            show_text: Whether to show the text examples
            max_text_length: Maximum length of text to display
            save_path: Optional path to save the figure

        Returns:
            The matplotlib Figure object
        """
        # Get top activating examples
        top_values, top_indices = self.sae.get_top_activating_examples(
            self.activations, feature_idx, k
        )

        fig, ax = plt.subplots(figsize=(12, k * 0.5))

        # Plot bar chart
        _ = ax.barh(range(k), top_values[::-1])
        ax.set_yticks(range(k))
        ax.set_xlabel("Activation Value", fontsize=12)
        ax.set_title(
            f"Top {k} Activating Examples for Feature {feature_idx}", fontsize=14
        )
        ax.grid(True, axis="x", alpha=0.3)

        # Add text examples if available
        if show_text and "samples_metadata" in self.metadata:
            samples_metadata = self.metadata["samples_metadata"]

            labels = []
            for i, idx in enumerate(top_indices[::-1]):
                if idx < len(samples_metadata):
                    text = samples_metadata[idx].get("text", "")
                    # Clean up text
                    text = text.replace("\n", " ").strip()
                    if len(text) > max_text_length:
                        text = text[:max_text_length] + "..."
                    labels.append(f"{i + 1}. {text}")
                else:
                    labels.append(f"{i + 1}. (No metadata)")

            ax.set_yticklabels(labels, fontsize=9)

        plt.tight_layout()

        if save_path or self.output_dir:
            path = save_path or self.output_dir / f"feature_{feature_idx}_examples.png"
            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
            logger.info(f"Saved feature examples to {path}")

        return fig

    def plot_decoder_weights(
        self,
        feature_indices: Optional[List[int]] = None,
        n_features: int = 10,
        save_path: Optional[str] = None,
    ) -> plt.Figure:
        """
        Visualize decoder weights for specific features.

        This shows how each feature maps back to the MLP activation space.

        Args:
            feature_indices: Specific feature indices to plot
            n_features: Number of features to plot (if feature_indices is None)
            save_path: Optional path to save the figure

        Returns:
            The matplotlib Figure object
        """
        import torch

        if feature_indices is None:
            # Get features by density
            densities = (self.features > 0).mean(axis=0)
            feature_indices = np.argsort(densities)[-n_features:][::-1]

        n_features_plot = len(feature_indices)
        fig, axes = plt.subplots(1, n_features_plot, figsize=(3 * n_features_plot, 4))

        if n_features_plot == 1:
            axes = [axes]

        with torch.no_grad():
            for i, feat_idx in enumerate(feature_indices):
                decoder_weight = self.sae.decoder.weight[:, feat_idx].cpu().numpy()

                axes[i].hist(decoder_weight, bins=50, edgecolor="black", alpha=0.7)
                axes[i].set_title(f"Feature {feat_idx}", fontsize=12)
                axes[i].set_xlabel("Weight Value", fontsize=10)
                axes[i].set_ylabel("Count", fontsize=10)
                axes[i].grid(True, alpha=0.3)

        plt.suptitle("Decoder Weight Distributions", fontsize=14)
        plt.tight_layout()

        if save_path or self.output_dir:
            path = save_path or self.output_dir / "decoder_weights.png"
            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
            logger.info(f"Saved decoder weights plot to {path}")

        return fig

    def create_feature_dashboard(
        self,
        feature_idx: int,
        save_path: Optional[str] = None,
    ) -> plt.Figure:
        """
        Create a comprehensive dashboard for a single feature.

        Includes activation histogram, top examples, and decoder weights.

        Args:
            feature_idx: Index of the feature to visualize
            save_path: Optional path to save the figure

        Returns:
            The matplotlib Figure object
        """
        fig = plt.figure(figsize=(16, 10))
        gs = GridSpec(2, 2, figure=fig, hspace=0.3, wspace=0.3)

        # 1. Activation histogram
        ax1 = fig.add_subplot(gs[0, 0])
        feature_acts = self.features[:, feature_idx]
        ax1.hist(feature_acts[feature_acts > 0], bins=50, edgecolor="black", alpha=0.7)
        ax1.set_xlabel("Activation Value", fontsize=11)
        ax1.set_ylabel("Count", fontsize=11)
        ax1.set_title(f"Activation Distribution (Feature {feature_idx})", fontsize=12)
        ax1.set_yscale("log")

        # 2. Top activating examples
        ax2 = fig.add_subplot(gs[0, 1])
        top_values, top_indices = self.sae.get_top_activating_examples(
            self.activations, feature_idx, 10
        )
        ax2.barh(range(10), top_values[::-1])
        ax2.set_xlabel("Activation Value", fontsize=11)
        ax2.set_title("Top 10 Activating Examples", fontsize=12)
        ax2.set_yticks(range(10))
        ax2.grid(True, axis="x", alpha=0.3)

        # 3. Decoder weights
        ax3 = fig.add_subplot(gs[1, :])
        import torch

        with torch.no_grad():
            decoder_weight = self.sae.decoder.weight[:, feature_idx].cpu().numpy()

        ax3.hist(decoder_weight, bins=100, edgecolor="black", alpha=0.7)
        ax3.set_xlabel("Weight Value", fontsize=11)
        ax3.set_ylabel("Count", fontsize=11)
        ax3.set_title("Decoder Weight Distribution", fontsize=12)
        ax3.grid(True, alpha=0.3)

        # Add statistics box
        density = (feature_acts > 0).mean()
        max_act = feature_acts.max()

        stats_text = f"Feature {feature_idx} Statistics:\n"
        stats_text += f"Density: {density:.2e}\n"
        stats_text += f"Max Activation: {max_act:.4f}\n"
        stats_text += f"Decoder Norm: {np.linalg.norm(decoder_weight):.4f}"

        fig.text(
            0.5,
            0.02,
            stats_text,
            ha="center",
            fontsize=12,
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
        )

        plt.suptitle(f"Feature {feature_idx} Dashboard", fontsize=16)

        if save_path or self.output_dir:
            path = save_path or self.output_dir / f"feature_{feature_idx}_dashboard.png"
            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
            logger.info(f"Saved feature dashboard to {path}")

        return fig

    def save_all(self, n_features: int = 10) -> None:
        """
        Generate and save all standard visualizations.

        Args:
            n_features: Number of features to include in detailed plots
        """
        logger.info("Generating all visualizations...")

        self.plot_feature_density()
        self.plot_training_curves()

        for metric in ["density", "max_activation"]:
            self.plot_top_features(n_features=n_features, by=metric)

        # Create dashboards for top features by density
        densities = (self.features > 0).mean(axis=0)
        top_features = np.argsort(densities)[-n_features:]

        for feat_idx in top_features:
            self.create_feature_dashboard(feat_idx)

        logger.info(f"All visualizations saved to {self.output_dir}")
````
Visualize sparse autoencoder features and activations.
This class provides methods to create various plots that help understand the learned features, similar to the visualizations in the "Towards Monosemanticity" paper.
Example:
visualizer = FeatureVisualizer(sae, activations, metadata) visualizer.plot_feature_density() visualizer.plot_top_features(n_features=10) visualizer.plot_feature_examples(feature_idx=0)
Initialize the visualizer.
Arguments:
- sae: Trained sparse autoencoder
- activations: The MLP activations used to train the SAE
- metadata: Optional metadata dictionary (e.g., from ActivationExtractor)
- output_dir: Directory to save plots
- style: Seaborn style to use
- dpi: DPI for saved figures
- wandb_config: Optional WandbConfig for logging visualizations
- log_to_wandb: If True, log plots to wandb when wandb_config is provided
def plot_feature_density(
    self,
    bins: int = 50,
    log_scale: bool = True,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Plot the distribution of feature densities.

    Feature density is the fraction of examples on which a feature fires.
    This histogram helps understand the sparsity distribution.

    Args:
        bins: Number of histogram bins
        log_scale: Whether to use log scale on x-axis
        save_path: Optional path to save the figure

    Returns:
        The matplotlib Figure object
    """
    # Compute feature densities
    densities = (self.features > 0).mean(axis=0)

    fig, ax = plt.subplots(figsize=(10, 6))

    # Remove zero-density features (dead neurons)
    active_densities = densities[densities > 0]

    # Plot histogram on log scale
    if log_scale:
        bins_log = np.logspace(-8, 0, bins)
        ax.hist(active_densities, bins=bins_log, edgecolor="black", alpha=0.7)
        ax.set_xscale("log")
    else:
        ax.hist(active_densities, bins=bins, edgecolor="black", alpha=0.7)

    ax.set_xlabel("Feature Density (fraction of examples)", fontsize=12)
    ax.set_ylabel("Number of Features", fontsize=12)
    ax.set_title("Distribution of Feature Densities", fontsize=14)

    # Add statistics
    n_dead = (densities == 0).sum()
    n_active = len(active_densities)

    stats_text = f"Dead features: {n_dead} ({n_dead / len(densities) * 100:.1f}%)\n"
    stats_text += (
        f"Active features: {n_active} ({n_active / len(densities) * 100:.1f}%)\n"
    )
    stats_text += f"Median density: {np.median(active_densities):.2e}"

    ax.text(
        0.95,
        0.95,
        stats_text,
        transform=ax.transAxes,
        verticalalignment="top",
        horizontalalignment="right",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
    )

    plt.tight_layout()

    if save_path or self.output_dir:
        path = save_path or self.output_dir / "feature_density.png"
        plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
        logger.info(f"Saved feature density plot to {path}")

    # Log to wandb
    self._log_figure_to_wandb(fig, "feature_density")

    return fig
Plot the distribution of feature densities.
Feature density is the fraction of examples on which a feature fires. This histogram helps understand the sparsity distribution.
Arguments:
- bins: Number of histogram bins
- log_scale: Whether to use log scale on x-axis
- save_path: Optional path to save the figure
Returns:
The matplotlib Figure object
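The density statistics annotated on this plot can be reproduced directly from the feature matrix. A minimal NumPy sketch, using a toy matrix as a stand-in for the visualizer's `self.features`:

```python
import numpy as np

# Toy feature matrix: 6 examples x 4 features (stand-in for self.features)
features = np.array([
    [0.0, 1.2, 0.0, 0.0],
    [0.0, 0.8, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
    [0.0, 2.1, 0.5, 0.0],
    [0.0, 0.0, 0.3, 0.0],
    [0.0, 1.0, 0.0, 0.0],
])

# Density: fraction of examples on which each feature fires
densities = (features > 0).mean(axis=0)

# Dead features never fire on any example
n_dead = int((densities == 0).sum())
active_densities = densities[densities > 0]  # feature 1 fires on 4/6, feature 2 on 2/6
```

Features 0 and 3 are dead here, so only two densities survive into the histogram.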
def plot_activation_histogram(
    self,
    feature_idx: int,
    bins: int = 100,
    log_y: bool = True,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Plot activation histogram for a specific feature.

    Args:
        feature_idx: Index of the feature to visualize
        bins: Number of histogram bins
        log_y: Whether to use log scale on y-axis
        save_path: Optional path to save the figure

    Returns:
        The matplotlib Figure object
    """
    feature_acts = self.features[:, feature_idx]

    fig, ax = plt.subplots(figsize=(10, 6))

    # Plot histogram
    counts, bin_edges, patches = ax.hist(
        feature_acts[feature_acts > 0],  # Only non-zero activations
        bins=bins,
        edgecolor="black",
        alpha=0.7,
    )

    if log_y:
        ax.set_yscale("log")

    ax.set_xlabel("Activation Value", fontsize=12)
    ax.set_ylabel("Count (log scale)" if log_y else "Count", fontsize=12)
    ax.set_title(f"Activation Histogram for Feature {feature_idx}", fontsize=14)

    # Add statistics
    density = (feature_acts > 0).mean()
    max_act = feature_acts.max()
    mean_act = feature_acts[feature_acts > 0].mean() if density > 0 else 0

    stats_text = f"Density: {density:.2e}\n"
    stats_text += f"Max activation: {max_act:.4f}\n"
    stats_text += f"Mean (active): {mean_act:.4f}"

    ax.text(
        0.95,
        0.95,
        stats_text,
        transform=ax.transAxes,
        verticalalignment="top",
        horizontalalignment="right",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.5),
    )

    plt.tight_layout()

    if save_path or self.output_dir:
        path = (
            save_path
            or self.output_dir / f"activation_histogram_feature_{feature_idx}.png"
        )
        plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
        logger.info(f"Saved activation histogram to {path}")

    return fig
Plot activation histogram for a specific feature.
Arguments:
- feature_idx: Index of the feature to visualize
- bins: Number of histogram bins
- log_y: Whether to use log scale on y-axis
- save_path: Optional path to save the figure
Returns:
The matplotlib Figure object
def plot_training_curves(
    self,
    save_path: Optional[str] = None,
) -> Optional[plt.Figure]:
    """
    Plot training loss and L0 norm curves.

    Args:
        save_path: Optional path to save the figure

    Returns:
        The matplotlib Figure object, or None if no training data is available
    """
    if not self.sae.training_losses:
        logger.warning("No training data available")
        return None

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    epochs = range(1, len(self.sae.training_losses) + 1)

    # Loss curve
    ax1.plot(epochs, self.sae.training_losses, "b-", linewidth=2)
    ax1.set_xlabel("Epoch", fontsize=12)
    ax1.set_ylabel("Loss", fontsize=12)
    ax1.set_title("Training Loss", fontsize=14)
    ax1.grid(True, alpha=0.3)

    # L0 norm curve
    ax2.plot(epochs, self.sae.training_l0_norms, "r-", linewidth=2)
    ax2.set_xlabel("Epoch", fontsize=12)
    ax2.set_ylabel("L0 Norm (avg # active features)", fontsize=12)
    ax2.set_title("Sparsity (L0 Norm)", fontsize=14)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path or self.output_dir:
        path = save_path or self.output_dir / "training_curves.png"
        plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
        logger.info(f"Saved training curves to {path}")

    return fig
Plot training loss and L0 norm curves.
Arguments:
- save_path: Optional path to save the figure
Returns:
The matplotlib Figure object
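The L0 "norm" plotted in the right panel is the average number of features that fire per example. A minimal sketch of that computation on a toy batch (the real values come from the SAE's training history):

```python
import numpy as np

# Toy SAE feature activations: 4 examples x 5 features
features = np.array([
    [0.0, 0.9, 0.0, 0.0, 0.2],
    [0.0, 0.0, 0.0, 0.0, 0.0],
    [1.1, 0.0, 0.0, 0.3, 0.0],
    [0.0, 0.4, 0.0, 0.0, 0.0],
])

# L0 per example: count of non-zero (active) features
l0_per_example = (features > 0).sum(axis=1)  # [2, 0, 2, 1]

# Batch-level sparsity metric tracked over training
avg_l0 = l0_per_example.mean()  # 1.25
```

A falling L0 curve alongside a falling loss curve is the usual sign that the L1 penalty is producing sparse codes without destroying reconstruction.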
def plot_top_features(
    self,
    n_features: int = 10,
    by: str = "density",
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Plot top N features by various metrics.

    Args:
        n_features: Number of top features to show
        by: Metric to sort by ("density", "max_activation", "mean_activation")
        save_path: Optional path to save the figure

    Returns:
        The matplotlib Figure object
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    # Compute metric for each feature
    if by == "density":
        values = (self.features > 0).mean(axis=0)
        ylabel = "Feature Density"
        title = f"Top {n_features} Features by Density"
    elif by == "max_activation":
        values = self.features.max(axis=0)
        ylabel = "Max Activation"
        title = f"Top {n_features} Features by Max Activation"
    elif by == "mean_activation":
        values = self.features.mean(axis=0)
        ylabel = "Mean Activation"
        title = f"Top {n_features} Features by Mean Activation"
    else:
        raise ValueError(f"Unknown metric: {by}")

    # Get top features
    top_indices = np.argsort(values)[-n_features:][::-1]
    top_values = values[top_indices]

    # Plot bar chart
    _ = ax.barh(range(n_features), top_values[::-1])
    ax.set_yticks(range(n_features))
    ax.set_yticklabels([f"Feature {i}" for i in top_indices[::-1]])
    ax.set_xlabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, axis="x", alpha=0.3)

    plt.tight_layout()

    if save_path or self.output_dir:
        path = save_path or self.output_dir / f"top_features_{by}.png"
        plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
        logger.info(f"Saved top features plot to {path}")

    return fig
Plot top N features by various metrics.
Arguments:
- n_features: Number of top features to show
- by: Metric to sort by ("density", "max_activation", "mean_activation")
- save_path: Optional path to save the figure
Returns:
The matplotlib Figure object
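The ranking step uses `np.argsort`, which sorts ascending, so the top-N selection takes the last N indices and reverses them. A toy sketch of that idiom with made-up metric values:

```python
import numpy as np

# Toy per-feature metric values (e.g., densities for 6 features)
values = np.array([0.02, 0.50, 0.10, 0.90, 0.00, 0.30])
n_features = 3

# argsort is ascending: take the last n entries, then reverse for descending order
top_indices = np.argsort(values)[-n_features:][::-1]  # [3, 1, 5]
top_values = values[top_indices]                       # [0.9, 0.5, 0.3]
```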
def plot_feature_examples(
    self,
    feature_idx: int,
    k: int = 10,
    show_text: bool = True,
    max_text_length: int = 200,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Plot top k activating examples for a specific feature.

    Args:
        feature_idx: Index of the feature to visualize
        k: Number of examples to show
        show_text: Whether to show the text examples
        max_text_length: Maximum length of text to display
        save_path: Optional path to save the figure

    Returns:
        The matplotlib Figure object
    """
    # Get top activating examples
    top_values, top_indices = self.sae.get_top_activating_examples(
        self.activations, feature_idx, k
    )

    fig, ax = plt.subplots(figsize=(12, k * 0.5))

    # Plot bar chart
    _ = ax.barh(range(k), top_values[::-1])
    ax.set_yticks(range(k))
    ax.set_xlabel("Activation Value", fontsize=12)
    ax.set_title(
        f"Top {k} Activating Examples for Feature {feature_idx}", fontsize=14
    )
    ax.grid(True, axis="x", alpha=0.3)

    # Add text examples if available
    if show_text and "samples_metadata" in self.metadata:
        samples_metadata = self.metadata["samples_metadata"]

        labels = []
        for i, idx in enumerate(top_indices[::-1]):
            if idx < len(samples_metadata):
                text = samples_metadata[idx].get("text", "")
                # Clean up text
                text = text.replace("\n", " ").strip()
                if len(text) > max_text_length:
                    text = text[:max_text_length] + "..."
                labels.append(f"{i + 1}. {text}")
            else:
                labels.append(f"{i + 1}. (No metadata)")

        ax.set_yticklabels(labels, fontsize=9)

    plt.tight_layout()

    if save_path or self.output_dir:
        path = save_path or self.output_dir / f"feature_{feature_idx}_examples.png"
        plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
        logger.info(f"Saved feature examples to {path}")

    return fig
Plot top k activating examples for a specific feature.
Arguments:
- feature_idx: Index of the feature to visualize
- k: Number of examples to show
- show_text: Whether to show the text examples
- max_text_length: Maximum length of text to display
- save_path: Optional path to save the figure
Returns:
The matplotlib Figure object
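The core of this method is selecting a feature's strongest examples and turning their metadata into truncated labels. A simplified, self-contained sketch (ordered highest-first; toy activations and metadata, and labeling logic mirrors the method's cleanup and truncation):

```python
import numpy as np

# Toy activations for one feature, plus matching sample texts
acts = np.array([0.1, 0.9, 0.0, 2.3, 0.0, 1.5])
samples_metadata = [
    {"text": "the cat sat"},
    {"text": "a very\nlong line about rivers and deltas"},
    {"text": "hello"},
    {"text": "golden gate bridge at dawn"},
    {"text": ""},
    {"text": "short"},
]
k, max_text_length = 3, 20

# Example indices with the highest activations, best first
top_indices = np.argsort(acts)[-k:][::-1]  # [3, 5, 1]

labels = []
for rank, idx in enumerate(top_indices, start=1):
    # Clean up text: flatten newlines, strip, then truncate long examples
    text = samples_metadata[idx]["text"].replace("\n", " ").strip()
    if len(text) > max_text_length:
        text = text[:max_text_length] + "..."
    labels.append(f"{rank}. {text}")
```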
def plot_decoder_weights(
    self,
    feature_indices: Optional[List[int]] = None,
    n_features: int = 10,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Visualize decoder weights for specific features.

    This shows how each feature maps back to the MLP activation space.

    Args:
        feature_indices: Specific feature indices to plot
        n_features: Number of features to plot (if feature_indices is None)
        save_path: Optional path to save the figure

    Returns:
        The matplotlib Figure object
    """
    import torch

    if feature_indices is None:
        # Get features by density
        densities = (self.features > 0).mean(axis=0)
        feature_indices = np.argsort(densities)[-n_features:][::-1]

    n_features_plot = len(feature_indices)
    fig, axes = plt.subplots(1, n_features_plot, figsize=(3 * n_features_plot, 4))

    if n_features_plot == 1:
        axes = [axes]

    with torch.no_grad():
        for i, feat_idx in enumerate(feature_indices):
            decoder_weight = self.sae.decoder.weight[:, feat_idx].cpu().numpy()

            axes[i].hist(decoder_weight, bins=50, edgecolor="black", alpha=0.7)
            axes[i].set_title(f"Feature {feat_idx}", fontsize=12)
            axes[i].set_xlabel("Weight Value", fontsize=10)
            axes[i].set_ylabel("Count", fontsize=10)
            axes[i].grid(True, alpha=0.3)

    plt.suptitle("Decoder Weight Distributions", fontsize=14)
    plt.tight_layout()

    if save_path or self.output_dir:
        path = save_path or self.output_dir / "decoder_weights.png"
        plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
        logger.info(f"Saved decoder weights plot to {path}")

    return fig
Visualize decoder weights for specific features.
This shows how each feature maps back to the MLP activation space.
Arguments:
- feature_indices: Specific feature indices to plot
- n_features: Number of features to plot (if feature_indices is None)
- save_path: Optional path to save the figure
Returns:
The matplotlib Figure object
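Each decoder column is the feature's direction in MLP activation space, which is also what `SAESteering.get_steering_direction` returns (optionally L2-normalised). A minimal NumPy sketch with a toy matrix standing in for `sae.decoder.weight`:

```python
import numpy as np

# Toy decoder weight matrix: activation_dim x hidden_dim (4 x 3)
decoder_weight = np.array([
    [0.0, 3.0, 1.0],
    [0.0, 4.0, 0.0],
    [1.0, 0.0, 2.0],
    [0.0, 0.0, 2.0],
])
feature_idx = 1

# Column feature_idx is that feature's direction in activation space
direction = decoder_weight[:, feature_idx]   # [3, 4, 0, 0]
norm = np.linalg.norm(direction)             # 5.0

# L2-normalise, guarding against an all-zero (dead) column
unit_direction = direction / norm if norm > 0 else direction
```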
def create_feature_dashboard(
    self,
    feature_idx: int,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Create a comprehensive dashboard for a single feature.

    Includes activation histogram, top examples, and decoder weights.

    Args:
        feature_idx: Index of the feature to visualize
        save_path: Optional path to save the figure

    Returns:
        The matplotlib Figure object
    """
    fig = plt.figure(figsize=(16, 10))
    gs = GridSpec(2, 2, figure=fig, hspace=0.3, wspace=0.3)

    # 1. Activation histogram
    ax1 = fig.add_subplot(gs[0, 0])
    feature_acts = self.features[:, feature_idx]
    ax1.hist(feature_acts[feature_acts > 0], bins=50, edgecolor="black", alpha=0.7)
    ax1.set_xlabel("Activation Value", fontsize=11)
    ax1.set_ylabel("Count", fontsize=11)
    ax1.set_title(f"Activation Distribution (Feature {feature_idx})", fontsize=12)
    ax1.set_yscale("log")

    # 2. Top activating examples
    ax2 = fig.add_subplot(gs[0, 1])
    top_values, top_indices = self.sae.get_top_activating_examples(
        self.activations, feature_idx, 10
    )
    ax2.barh(range(10), top_values[::-1])
    ax2.set_xlabel("Activation Value", fontsize=11)
    ax2.set_title("Top 10 Activating Examples", fontsize=12)
    ax2.set_yticks(range(10))
    ax2.grid(True, axis="x", alpha=0.3)

    # 3. Decoder weights
    ax3 = fig.add_subplot(gs[1, :])
    import torch

    with torch.no_grad():
        decoder_weight = self.sae.decoder.weight[:, feature_idx].cpu().numpy()

    ax3.hist(decoder_weight, bins=100, edgecolor="black", alpha=0.7)
    ax3.set_xlabel("Weight Value", fontsize=11)
    ax3.set_ylabel("Count", fontsize=11)
    ax3.set_title("Decoder Weight Distribution", fontsize=12)
    ax3.grid(True, alpha=0.3)

    # Add statistics box
    density = (feature_acts > 0).mean()
    max_act = feature_acts.max()

    stats_text = f"Feature {feature_idx} Statistics:\n"
    stats_text += f"Density: {density:.2e}\n"
    stats_text += f"Max Activation: {max_act:.4f}\n"
    stats_text += f"Decoder Norm: {np.linalg.norm(decoder_weight):.4f}"

    fig.text(
        0.5,
        0.02,
        stats_text,
        ha="center",
        fontsize=12,
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
    )

    plt.suptitle(f"Feature {feature_idx} Dashboard", fontsize=16)

    if save_path or self.output_dir:
        path = save_path or self.output_dir / f"feature_{feature_idx}_dashboard.png"
        plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
        logger.info(f"Saved feature dashboard to {path}")

    return fig
Create a comprehensive dashboard for a single feature.
Includes activation histogram, top examples, and decoder weights.
Arguments:
- feature_idx: Index of the feature to visualize
- save_path: Optional path to save the figure
Returns:
The matplotlib Figure object
def save_all(self, n_features: int = 10) -> None:
    """
    Generate and save all standard visualizations.

    Args:
        n_features: Number of features to include in detailed plots
    """
    logger.info("Generating all visualizations...")

    self.plot_feature_density()
    self.plot_training_curves()

    for metric in ["density", "max_activation"]:
        self.plot_top_features(n_features=n_features, by=metric)

    # Create dashboards for top features by density
    densities = (self.features > 0).mean(axis=0)
    top_features = np.argsort(densities)[-n_features:]

    for feat_idx in top_features:
        self.create_feature_dashboard(feat_idx)

    logger.info(f"All visualizations saved to {self.output_dir}")
Generate and save all standard visualizations.
Arguments:
- n_features: Number of features to include in detailed plots
class SAESteering:
    """
    Steer language model generation using trained SAE features.

    Wraps a HuggingFace language model (via nnsight) and applies SAE
    feature interventions during generation. At each generation step,
    a forward hook on the target MLP module adds a scaled decoder
    weight vector, biasing the residual stream toward the chosen
    feature direction.

    The HuggingFace auth token is read automatically from
    ``HUGGINGFACE_HUB_TOKEN`` (via :func:`~drrik.settings.get_settings`)
    when the *token* parameter is not explicitly provided.

    Attributes:
        sae: The trained :class:`~drrik.autoencoder.SparseAutoencoder`
            providing feature directions.
        model: The nnsight :class:`~nnsight.LanguageModel` wrapper.
        tokenizer: The HuggingFace tokenizer for the model.
        target_layer: The layer index where interventions are applied.
        device: The device the SAE parameters live on.

    Example:
        ```python
        sae = SparseAutoencoder.load("sae_model.pt")
        steering = SAESteering(sae, model_name="google/gemma-2b", layer=5)

        # Steered generation
        result = steering.generate(
            "The sky is",
            feature_idx=128,
            strength=2.5,
            max_new_tokens=50,
        )

        # Baseline (no steering)
        baseline = steering.generate("The sky is", max_new_tokens=50)

        # Multi-feature steering
        result = steering.generate(
            "The sky is",
            feature_indices=[10, 128],
            strengths=[1.0, 2.5],
        )
        ```
    """

    def __init__(
        self,
        sae: SparseAutoencoder,
        model_name: str,
        layer: int,
        revision: str = "main",
        torch_dtype: str = "float16",
        device_map: str = "auto",
        trust_remote_code: bool = True,
        token: Optional[str] = None,
    ):
        """
        Initialize the SAE steering controller.

        If ``token`` is not provided, it is read from the environment via
        ``get_settings()`` (i.e. ``HUGGINGFACE_HUB_TOKEN`` in ``.env``).

        Args:
            sae: A trained SparseAutoencoder whose decoder weights provide
                steering directions.
            model_name: HuggingFace model identifier for generation (e.g.,
                ``"google/gemma-2b"``).
            layer: The transformer layer index where MLP activations are
                intercepted and modified.
            revision: Model revision to load.
            torch_dtype: Weight dtype for model loading.
            device_map: Device mapping strategy.
            trust_remote_code: Whether to trust remote code from the repo.
            token: HuggingFace token for gated models. Falls back to
                ``HUGGINGFACE_HUB_TOKEN`` from ``.env`` when ``None``.
        """
        if token is None:
            settings = get_settings()
            token = settings.huggingface_hub_token

        self.sae = sae
        self.target_layer = layer
        self.device = next(sae.parameters()).device

        logger.info(f"Loading model '{model_name}' for steering at layer {layer}")

        dtype_map = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        load_dtype = dtype_map.get(torch_dtype, torch.float16)

        model_kwargs = {
            "revision": revision,
            "torch_dtype": load_dtype,
            "device_map": device_map,
            "trust_remote_code": trust_remote_code,
        }
        if token:
            model_kwargs["token"] = token

        self.model = LanguageModel(model_name, **model_kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            revision=revision,
            trust_remote_code=trust_remote_code,
            token=token,
        )

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info(f"Model loaded on {self.model.device}")

    def _tokenize(self, text: str) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Tokenize text and move tensors to the model device.

        Centralises tokenization so that :meth:`generate` and
        :meth:`_generate_baseline` don't duplicate this work.

        Args:
            text: Raw prompt string.

        Returns:
            A tuple of ``(input_ids, attention_mask)`` tensors on
            the model's device. ``attention_mask`` may be ``None``
            if the tokenizer does not produce one.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs.get("attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.model.device)
        return input_ids, attention_mask

    def _generate_with_hooks(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        steering_direction: torch.Tensor,
        strength: float,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
    ) -> str:
        """
        Generate text using forward-hook-based MLP intervention.

        Registers a forward hook on the target MLP module that adds
        ``strength * steering_direction`` to the output at every
        forward pass. The hook is removed in a ``finally`` block to
        guarantee cleanup even on error.

        Generation is performed token-by-token using the underlying
        HuggingFace model (``self.model._module``) with top-p
        (nucleus) sampling.

        Args:
            input_ids: Tokenized input tensor of shape ``(1, seq_len)``.
            attention_mask: Optional attention mask tensor.
            steering_direction: The combined SAE decoder weight vector
                of shape ``(activation_dim,)``.
            strength: Multiplicative steering magnitude (typically ``1.0``
                because strengths are baked into *steering_direction* by
                the caller).
            max_new_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.

        Returns:
            The decoded text string (special tokens stripped).
        """
        generated_tokens = input_ids.clone()
        steering_direction = steering_direction.to(generated_tokens.device)
        target_module = resolve_module_path(
            self.model, f"model.layers[{self.target_layer}].mlp"
        )
        hook_handle = None

        try:

            def steering_hook(module, input_args, output):
                if isinstance(output, torch.Tensor):
                    output = output + strength * steering_direction
                elif isinstance(output, tuple):
                    output = tuple(
                        o + strength * steering_direction
                        if isinstance(o, torch.Tensor)
                        else o
                        for o in output
                    )
                return output

            hook_handle = target_module.register_forward_hook(steering_hook)

            base_model = self.model._module
            device = generated_tokens.device

            # no_grad mirrors _generate_baseline and avoids accumulating
            # autograd state during token-by-token generation
            with torch.no_grad():
                for _ in tqdm(range(max_new_tokens), desc="Steering generation"):
                    if generated_tokens.shape[-1] >= self.tokenizer.model_max_length:
                        break

                    outputs = base_model(
                        input_ids=generated_tokens,
                        attention_mask=attention_mask,
                        use_cache=True,
                    )

                    logits = outputs.logits[:, -1, :]
                    next_token = _sample_next_token(logits, temperature, top_p)

                    if next_token.item() == self.tokenizer.eos_token_id:
                        break

                    generated_tokens = torch.cat(
                        [generated_tokens, next_token], dim=-1
                    )
                    if attention_mask is not None:
                        attention_mask = torch.cat(
                            [attention_mask, torch.ones(1, 1, device=device)], dim=-1
                        )

        finally:
            if hook_handle is not None:
                hook_handle.remove()

        return self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

    def generate(
        self,
        text: str,
        feature_idx: Optional[int] = None,
        strength: float = 1.0,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_p: float = 0.9,
        feature_indices: Optional[List[int]] = None,
        strengths: Optional[List[float]] = None,
    ) -> str:
        """
        Generate text with optional SAE feature steering.

        This is the main entry point for generation.

        - **No features** → calls :meth:`_generate_baseline` (plain
          autoregressive generation).
        - **Single feature** → pass ``feature_idx`` and ``strength``.
        - **Multiple features** → pass ``feature_indices`` and
          ``strengths`` as equal-length lists; the weighted decoder
          directions are summed into a single combined vector.

        Args:
            text: The prompt text to generate from.
            feature_idx: Index of a single SAE feature to steer with.
            strength: Steering magnitude for a single feature
                (default ``1.0``).
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            feature_indices: List of feature indices for multi-feature
                steering.
            strengths: List of steering magnitudes, one per entry in
                ``feature_indices``. Defaults to ``[1.0, …]`` if
                ``None``.

        Returns:
            The generated text string (special tokens stripped).

        Raises:
            ValueError: If ``feature_indices`` and ``strengths`` have
                mismatched lengths, or a feature index is out of range.

        Example:
            ```python
            # Single feature
            steering.generate("Hello", feature_idx=42, strength=2.0)

            # Multi-feature
            steering.generate(
                "Hello",
                feature_indices=[10, 42],
                strengths=[1.0, 3.0],
            )

            # Baseline
            steering.generate("Hello", max_new_tokens=100)
            ```
        """
        if feature_idx is None and feature_indices is None:
            return self._generate_baseline(text, max_new_tokens, temperature, top_p)

        if feature_idx is not None:
            feature_indices = [feature_idx]
            strengths = [strength]
        elif strengths is None:
            strengths = [1.0] * len(feature_indices)

        if len(feature_indices) != len(strengths):
            raise ValueError("feature_indices and strengths must have the same length")

        combined_direction = torch.zeros(self.sae.activation_dim, device=self.device)
        for fid, s in zip(feature_indices, strengths):
            if fid < 0 or fid >= self.sae.hidden_dim:
                raise ValueError(
                    f"feature_idx {fid} out of range [0, {self.sae.hidden_dim})"
                )
            combined_direction += s * self.sae.decoder.weight[:, fid]

        input_ids, attention_mask = self._tokenize(text)

        return self._generate_with_hooks(
            input_ids,
            attention_mask,
            combined_direction,
            1.0,
            max_new_tokens,
            temperature,
            top_p,
        )

    def _generate_baseline(
        self,
        text: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
    ) -> str:
        """
        Generate text without any SAE steering (baseline).

        Runs the same token-by-token autoregressive loop as
        :meth:`_generate_with_hooks` but without registering a
        steering hook.

        Args:
            text: The prompt text.
            max_new_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.

        Returns:
            The generated text string (special tokens stripped).
        """
        input_ids, attention_mask = self._tokenize(text)

        base_model = self.model._module
        generated_tokens = input_ids.clone()

        with torch.no_grad():
            for _ in tqdm(range(max_new_tokens), desc="Baseline generation"):
                if generated_tokens.shape[-1] >= self.tokenizer.model_max_length:
                    break

                outputs = base_model(
                    input_ids=generated_tokens,
                    attention_mask=attention_mask,
                    use_cache=True,
                )

                logits = outputs.logits[:, -1, :]
                next_token = _sample_next_token(logits, temperature, top_p)

                if next_token.item() == self.tokenizer.eos_token_id:
                    break

                generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
                if attention_mask is not None:
                    attention_mask = torch.cat(
                        [
                            attention_mask,
                            torch.ones(1, 1, device=generated_tokens.device),
                        ],
                        dim=-1,
                    )

        return self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

    def compare_steering(
        self,
        text: str,
        feature_idx: int,
        strengths: Optional[List[float]] = None,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_p: float = 0.9,
    ) -> Dict[str, str]:
        """
        Compare baseline generation against steered generations at different strengths.

        Runs one generation per strength value (including ``0.0`` for a
        no-steering baseline) and returns a dict mapping descriptive
        labels to the generated text.

        Args:
            text: The prompt text.
            feature_idx: SAE feature index to steer with.
            strengths: List of strength values to test. Defaults to
                ``[0.0, 1.0, 2.0, 3.0, 5.0]``.
            max_new_tokens: Maximum tokens to generate per run.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.

        Returns:
            Dictionary mapping strength labels to generated text.
            Keys are ``"baseline"`` (for strength ``0.0``) and
            ``"strength_1.0"``, ``"strength_2.0"``, etc.

        Example:
            ```python
            results = steering.compare_steering(
                "The sky is",
                feature_idx=42,
                strengths=[0.0, 1.0, 3.0],
            )
            for label, text in results.items():
                print(f"{label}: {text}")
            ```
        """
        if strengths is None:
            strengths = [0.0, 1.0, 2.0, 3.0, 5.0]

        results = {}

        for s in strengths:
            if s == 0.0:
                label = "baseline"
            else:
                label = f"strength_{s}"

            logger.info(f"Generating with {label} (strength={s})")
            generated = self.generate(
                text,
                feature_idx=feature_idx,
                strength=s,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
            )
            results[label] = generated

        return results

    def find_steering_features(
        self,
        text: str,
        activations: np.ndarray,
        top_k: int = 20,
        min_activation: float = 0.0,
    ) -> List[Dict[str, Any]]:
        """
        Find the top-k SAE features that most activate on given activations.

        Encodes the provided activations through the SAE and returns the
        features with the highest mean activation values, sorted by
        magnitude. Useful for identifying which features are relevant
        to a given input before steering.

        Args:
            text: The prompt text (used for logging only).
            activations: MLP activations array of shape
                ``(n_samples, activation_dim)``.
            top_k: Number of top features to return.
            min_activation: Minimum mean activation threshold to
                consider a feature.

        Returns:
            List of dicts sorted by descending activation, each
            containing:

            - ``feature_idx`` (``int``): The feature index.
            - ``activation`` (``float``): Mean activation value.
            - ``normalized_activation`` (``float``): Activation
              divided by the decoder weight L2 norm.
            - ``decoder_weight_norm`` (``float``): L2 norm of the
              decoder weight vector.
        """
        self.sae.eval()
        device = self.device

        with torch.no_grad():
            if isinstance(activations, np.ndarray):
                acts_tensor = torch.from_numpy(activations).float().to(device)
            else:
                acts_tensor = activations.to(device)

            features = self.sae.encode(acts_tensor)

            avg_activations = features.mean(dim=0).cpu().numpy()
            decoder_norms = self.sae.decoder.weight.norm(dim=0).cpu().numpy()

        active_mask = avg_activations >= min_activation
        active_indices = np.where(active_mask)[0]

        sorted_indices = active_indices[
            np.argsort(avg_activations[active_indices])[::-1]
        ][:top_k]

        results = []
        for idx in sorted_indices:
            results.append(
                {
                    "feature_idx": int(idx),
                    "activation": float(avg_activations[idx]),
                    "normalized_activation": float(
                        avg_activations[idx] / (decoder_norms[idx] + 1e-8)
                    ),
                    "decoder_weight_norm": float(decoder_norms[idx]),
                }
            )

        logger.info(f"Found {len(results)} active features for text: {text[:100]}")
        return results

    def get_steering_direction(
        self,
        feature_idx: int,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Get the steering direction vector for a given feature.

        The steering direction is the SAE decoder weight column for the
        specified feature. This vector represents the direction in MLP
        activation space that the feature encodes.

        Args:
            feature_idx: Index of the SAE feature.
            normalize: If ``True``, L2-normalise the direction vector.

        Returns:
            NumPy array of shape ``(activation_dim,)`` containing the
            steering direction.

        Raises:
            ValueError: If ``feature_idx`` is out of range
                ``[0, sae.hidden_dim)``.
        """
        if feature_idx < 0 or feature_idx >= self.sae.hidden_dim:
            raise ValueError(
                f"feature_idx {feature_idx} out of range [0, {self.sae.hidden_dim})"
            )

        direction = self.sae.decoder.weight[:, feature_idx].detach().cpu().numpy()

        if normalize:
            norm = np.linalg.norm(direction)
            if norm > 0:
                direction = direction / norm

        return direction

    def save_steering_analysis(
        self,
        results: Dict[str, str],
        output_dir: Union[str, Path],
        prompt: str = "",
        feature_label: str = "",
    ) -> Path:
        """
        Save steering comparison results to disk.

        Writes a text file containing the prompt and all generated
        outputs (baseline and steered) for later review. The
        filename includes a timestamp (and optional feature label)
        to prevent overwriting previous analyses.

        Args:
            results: Dictionary mapping labels to generated text
                (as returned by :meth:`compare_steering`).
            output_dir: Directory to save the analysis file.
            prompt: The original prompt text.
            feature_label: Optional label included in the filename
                (e.g. ``"feature_42"``).

        Returns:
            Path to the saved analysis file.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        suffix = f"_{feature_label}" if feature_label else ""
        filepath = output_dir / f"steering_analysis{suffix}_{timestamp}.txt"

        with open(filepath, "w") as f:
            f.write(f"Prompt: {prompt}\n")
            f.write("=" * 80 + "\n\n")

            for label, text in results.items():
                f.write(f"--- {label} ---\n")
                f.write(f"{text}\n\n")
                f.write("-" * 40 + "\n")

        logger.info(f"Saved steering analysis to {filepath}")
        return filepath
Steer language model generation using trained SAE features.

Wraps a HuggingFace language model (via nnsight) and applies SAE feature interventions during generation. At each generation step, a forward hook on the target MLP module adds a scaled decoder weight vector, biasing the residual stream toward the chosen feature direction.

The HuggingFace auth token is read automatically from `HUGGINGFACE_HUB_TOKEN` (via `drrik.settings.get_settings()`) when the `token` parameter is not explicitly provided.

Attributes:
- sae: The trained `drrik.autoencoder.SparseAutoencoder` providing feature directions.
- model: The `nnsight.LanguageModel` wrapper.
- tokenizer: The HuggingFace tokenizer for the model.
- target_layer: The layer index where interventions are applied.
- device: The device the SAE parameters live on.

Example:
```python
sae = SparseAutoencoder.load("sae_model.pt")
steering = SAESteering(sae, model_name="google/gemma-2b", layer=5)

# Steered generation
result = steering.generate(
    "The sky is",
    feature_idx=128,
    strength=2.5,
    max_new_tokens=50,
)

# Baseline (no steering)
baseline = steering.generate("The sky is", max_new_tokens=50)

# Multi-feature steering
result = steering.generate(
    "The sky is",
    feature_indices=[10, 128],
    strengths=[1.0, 2.5],
)
```
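The intervention described above reduces to adding a scaled decoder column to the MLP output at every token position. A minimal numpy sketch of that arithmetic, with toy shapes chosen for illustration (the real hook operates on torch tensors inside nnsight):

```python
import numpy as np

rng = np.random.default_rng(0)

activation_dim, hidden_dim = 8, 32  # toy sizes; real models are much larger
W_dec = rng.normal(size=(activation_dim, hidden_dim))  # stand-in SAE decoder weight

def steer(mlp_out: np.ndarray, feature_idx: int, strength: float) -> np.ndarray:
    """Add `strength` times the decoder column for `feature_idx` at every position."""
    direction = W_dec[:, feature_idx]      # shape (activation_dim,)
    return mlp_out + strength * direction  # broadcasts over (seq_len, activation_dim)

mlp_out = rng.normal(size=(5, activation_dim))  # 5 token positions
steered = steer(mlp_out, feature_idx=3, strength=2.0)

# Exactly the chosen direction, scaled, was added at each position
assert np.allclose(steered - mlp_out, 2.0 * W_dec[:, 3])
```

Because the addition broadcasts, one direction vector biases every position in the sequence uniformly.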
```python
def __init__(
    self,
    sae: SparseAutoencoder,
    model_name: str,
    layer: int,
    revision: str = "main",
    torch_dtype: str = "float16",
    device_map: str = "auto",
    trust_remote_code: bool = True,
    token: Optional[str] = None,
):
    """
    Initialize the SAE steering controller.

    If ``token`` is not provided, it is read from the environment via
    ``get_settings()`` (i.e. ``HUGGINGFACE_HUB_TOKEN`` in ``.env``).

    Args:
        sae: A trained SparseAutoencoder whose decoder weights provide
            steering directions.
        model_name: HuggingFace model identifier for generation (e.g.,
            ``"google/gemma-2b"``).
        layer: The transformer layer index where MLP activations are
            intercepted and modified.
        revision: Model revision to load.
        torch_dtype: Weight dtype for model loading.
        device_map: Device mapping strategy.
        trust_remote_code: Whether to trust remote code from the repo.
        token: HuggingFace token for gated models. Falls back to
            ``HUGGINGFACE_HUB_TOKEN`` from ``.env`` when ``None``.
    """
    if token is None:
        settings = get_settings()
        token = settings.huggingface_hub_token

    self.sae = sae
    self.target_layer = layer
    self.device = next(sae.parameters()).device

    logger.info(f"Loading model '{model_name}' for steering at layer {layer}")

    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    load_dtype = dtype_map.get(torch_dtype, torch.float16)

    model_kwargs = {
        "revision": revision,
        "torch_dtype": load_dtype,
        "device_map": device_map,
        "trust_remote_code": trust_remote_code,
    }
    if token:
        model_kwargs["token"] = token

    self.model = LanguageModel(model_name, **model_kwargs)
    self.tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        revision=revision,
        trust_remote_code=trust_remote_code,
        token=token,
    )

    if self.tokenizer.pad_token is None:
        self.tokenizer.pad_token = self.tokenizer.eos_token

    logger.info(f"Model loaded on {self.model.device}")
```
Initialize the SAE steering controller.

If `token` is not provided, it is read from the environment via `get_settings()` (i.e. `HUGGINGFACE_HUB_TOKEN` in `.env`).

Arguments:
- sae: A trained SparseAutoencoder whose decoder weights provide steering directions.
- model_name: HuggingFace model identifier for generation (e.g., `"google/gemma-2b"`).
- layer: The transformer layer index where MLP activations are intercepted and modified.
- revision: Model revision to load.
- torch_dtype: Weight dtype for model loading.
- device_map: Device mapping strategy.
- trust_remote_code: Whether to trust remote code from the repo.
- token: HuggingFace token for gated models. Falls back to `HUGGINGFACE_HUB_TOKEN` from `.env` when `None`.
````python
def generate(
    self,
    text: str,
    feature_idx: Optional[int] = None,
    strength: float = 1.0,
    max_new_tokens: int = 50,
    temperature: float = 0.8,
    top_p: float = 0.9,
    feature_indices: Optional[List[int]] = None,
    strengths: Optional[List[float]] = None,
) -> str:
    """
    Generate text with optional SAE feature steering.

    This is the main entry point for generation.

    - **No features** → calls :meth:`_generate_baseline` (plain
      autoregressive generation).
    - **Single feature** → pass ``feature_idx`` and ``strength``.
    - **Multiple features** → pass ``feature_indices`` and
      ``strengths`` as equal-length lists; the weighted decoder
      directions are summed into a single combined vector.

    Args:
        text: The prompt text to generate from.
        feature_idx: Index of a single SAE feature to steer with.
        strength: Steering magnitude for a single feature
            (default ``1.0``).
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        feature_indices: List of feature indices for multi-feature
            steering.
        strengths: List of steering magnitudes, one per entry in
            ``feature_indices``. Defaults to ``[1.0, …]`` if
            ``None``.

    Returns:
        The generated text string (special tokens stripped).

    Raises:
        ValueError: If ``feature_indices`` and ``strengths`` have
            mismatched lengths, or a feature index is out of range.

    Example:
        ```python
        # Single feature
        steering.generate("Hello", feature_idx=42, strength=2.0)

        # Multi-feature
        steering.generate(
            "Hello",
            feature_indices=[10, 42],
            strengths=[1.0, 3.0],
        )

        # Baseline
        steering.generate("Hello", max_new_tokens=100)
        ```
    """
    if feature_idx is None and feature_indices is None:
        return self._generate_baseline(text, max_new_tokens, temperature, top_p)

    if feature_idx is not None:
        feature_indices = [feature_idx]
        strengths = [strength]
    elif strengths is None:
        strengths = [1.0] * len(feature_indices)

    if len(feature_indices) != len(strengths):
        raise ValueError("feature_indices and strengths must have the same length")

    combined_direction = torch.zeros(self.sae.activation_dim, device=self.device)
    for fid, s in zip(feature_indices, strengths):
        if fid < 0 or fid >= self.sae.hidden_dim:
            raise ValueError(
                f"feature_idx {fid} out of range [0, {self.sae.hidden_dim})"
            )
        combined_direction += s * self.sae.decoder.weight[:, fid]

    input_ids, attention_mask = self._tokenize(text)

    return self._generate_with_hooks(
        input_ids,
        attention_mask,
        combined_direction,
        1.0,
        max_new_tokens,
        temperature,
        top_p,
    )
````
Generate text with optional SAE feature steering.

This is the main entry point for generation.

- No features → calls `_generate_baseline()` (plain autoregressive generation).
- Single feature → pass `feature_idx` and `strength`.
- Multiple features → pass `feature_indices` and `strengths` as equal-length lists; the weighted decoder directions are summed into a single combined vector.

Arguments:
- text: The prompt text to generate from.
- feature_idx: Index of a single SAE feature to steer with.
- strength: Steering magnitude for a single feature (default `1.0`).
- max_new_tokens: Maximum number of tokens to generate.
- temperature: Sampling temperature.
- top_p: Nucleus sampling threshold.
- feature_indices: List of feature indices for multi-feature steering.
- strengths: List of steering magnitudes, one per entry in `feature_indices`. Defaults to `[1.0, …]` if `None`.

Returns:
The generated text string (special tokens stripped).

Raises:
- ValueError: If `feature_indices` and `strengths` have mismatched lengths, or a feature index is out of range.

Example:
```python
# Single feature
steering.generate("Hello", feature_idx=42, strength=2.0)

# Multi-feature
steering.generate(
    "Hello",
    feature_indices=[10, 42],
    strengths=[1.0, 3.0],
)

# Baseline
steering.generate("Hello", max_new_tokens=100)
```
515 def compare_steering( 516 self, 517 text: str, 518 feature_idx: int, 519 strengths: Optional[List[float]] = None, 520 max_new_tokens: int = 50, 521 temperature: float = 0.8, 522 top_p: float = 0.9, 523 ) -> Dict[str, str]: 524 """ 525 Compare baseline generation against steered generations at different strengths. 526 527 Runs one generation per strength value (including ``0.0`` for a 528 no-steering baseline) and returns a dict mapping descriptive 529 labels to the generated text. 530 531 Args: 532 text: The prompt text. 533 feature_idx: SAE feature index to steer with. 534 strengths: List of strength values to test. Defaults to 535 ``[0.0, 1.0, 2.0, 3.0, 5.0]``. 536 max_new_tokens: Maximum tokens to generate per run. 537 temperature: Sampling temperature. 538 top_p: Nucleus sampling threshold. 539 540 Returns: 541 Dictionary mapping strength labels to generated text. 542 Keys are ``"baseline"`` (for strength ``0.0``) and 543 ``"strength_1.0"``, ``"strength_2.0"``, etc. 544 545 Example: 546 ```python 547 results = steering.compare_steering( 548 "The sky is", 549 feature_idx=42, 550 strengths=[0.0, 1.0, 3.0], 551 ) 552 for label, text in results.items(): 553 print(f"{label}: {text}") 554 ``` 555 """ 556 if strengths is None: 557 strengths = [0.0, 1.0, 2.0, 3.0, 5.0] 558 559 results = {} 560 561 for s in strengths: 562 if s == 0.0: 563 label = "baseline" 564 else: 565 label = f"strength_{s}" 566 567 logger.info(f"Generating with {label} (strength={s})") 568 generated = self.generate( 569 text, 570 feature_idx=feature_idx, 571 strength=s, 572 max_new_tokens=max_new_tokens, 573 temperature=temperature, 574 top_p=top_p, 575 ) 576 results[label] = generated 577 578 return results
Compare baseline generation against steered generations at different strengths.

Runs one generation per strength value (including `0.0` for a no-steering baseline) and returns a dict mapping descriptive labels to the generated text.

Arguments:
- text: The prompt text.
- feature_idx: SAE feature index to steer with.
- strengths: List of strength values to test. Defaults to `[0.0, 1.0, 2.0, 3.0, 5.0]`.
- max_new_tokens: Maximum tokens to generate per run.
- temperature: Sampling temperature.
- top_p: Nucleus sampling threshold.

Returns:
Dictionary mapping strength labels to generated text. Keys are `"baseline"` (for strength `0.0`) and `"strength_1.0"`, `"strength_2.0"`, etc.

Example:
```python
results = steering.compare_steering(
    "The sky is",
    feature_idx=42,
    strengths=[0.0, 1.0, 3.0],
)
for label, text in results.items():
    print(f"{label}: {text}")
```
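The result keys follow directly from the strength values. A one-line sketch of the labelling rule (the `strength_labels` helper is illustrative, not part of the API):

```python
def strength_labels(strengths):
    """Label each strength the way compare_steering keys its results dict."""
    return ["baseline" if s == 0.0 else f"strength_{s}" for s in strengths]

assert strength_labels([0.0, 1.0, 3.0]) == ["baseline", "strength_1.0", "strength_3.0"]
```

Note the labels embed the float's `str()` form, so `1.0` yields `"strength_1.0"`, not `"strength_1"`.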
```python
def find_steering_features(
    self,
    text: str,
    activations: np.ndarray,
    top_k: int = 20,
    min_activation: float = 0.0,
) -> List[Dict[str, Any]]:
    """
    Find the top-k SAE features that most activate on given activations.

    Encodes the provided activations through the SAE and returns the
    features with the highest mean activation values, sorted by
    magnitude. Useful for identifying which features are relevant
    to a given input before steering.

    Args:
        text: The prompt text (used for logging only).
        activations: MLP activations array of shape
            ``(n_samples, activation_dim)``.
        top_k: Number of top features to return.
        min_activation: Minimum mean activation threshold to
            consider a feature.

    Returns:
        List of dicts sorted by descending activation, each
        containing:

        - ``feature_idx`` (``int``): The feature index.
        - ``activation`` (``float``): Mean activation value.
        - ``normalized_activation`` (``float``): Activation
          divided by the decoder weight L2 norm.
        - ``decoder_weight_norm`` (``float``): L2 norm of the
          decoder weight vector.
    """
    self.sae.eval()
    device = self.device

    with torch.no_grad():
        if isinstance(activations, np.ndarray):
            acts_tensor = torch.from_numpy(activations).float().to(device)
        else:
            acts_tensor = activations.to(device)

        features = self.sae.encode(acts_tensor)

    avg_activations = features.mean(dim=0).cpu().numpy()
    decoder_norms = self.sae.decoder.weight.norm(dim=0).cpu().numpy()

    active_mask = avg_activations >= min_activation
    active_indices = np.where(active_mask)[0]

    sorted_indices = active_indices[
        np.argsort(avg_activations[active_indices])[::-1]
    ][:top_k]

    results = []
    for idx in sorted_indices:
        results.append(
            {
                "feature_idx": int(idx),
                "activation": float(avg_activations[idx]),
                "normalized_activation": float(
                    avg_activations[idx] / (decoder_norms[idx] + 1e-8)
                ),
                "decoder_weight_norm": float(decoder_norms[idx]),
            }
        )

    logger.info(f"Found {len(results)} active features for text: {text[:100]}")
    return results
```
Find the top-k SAE features that most activate on given activations.

Encodes the provided activations through the SAE and returns the features with the highest mean activation values, sorted by magnitude. Useful for identifying which features are relevant to a given input before steering.

Arguments:
- text: The prompt text (used for logging only).
- activations: MLP activations array of shape `(n_samples, activation_dim)`.
- top_k: Number of top features to return.
- min_activation: Minimum mean activation threshold to consider a feature.

Returns:
List of dicts sorted by descending activation, each containing:
- `feature_idx` (`int`): The feature index.
- `activation` (`float`): Mean activation value.
- `normalized_activation` (`float`): Activation divided by the decoder weight L2 norm.
- `decoder_weight_norm` (`float`): L2 norm of the decoder weight vector.
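The ranking itself can be reproduced on synthetic data: average the non-negative SAE codes over samples, divide by decoder column norms, and sort descending. A numpy sketch with toy shapes (the random `codes` matrix stands in for real SAE encodings):

```python
import numpy as np

rng = np.random.default_rng(2)
n_samples, activation_dim, hidden_dim = 16, 8, 32  # toy sizes
W_dec = rng.normal(size=(activation_dim, hidden_dim))
codes = np.maximum(rng.normal(size=(n_samples, hidden_dim)), 0.0)  # stand-in SAE codes

avg = codes.mean(axis=0)                      # mean activation per feature
decoder_norms = np.linalg.norm(W_dec, axis=0)

active = np.where(avg >= 0.0)[0]              # min_activation = 0.0
top = active[np.argsort(avg[active])[::-1]][:5]

report = [
    {
        "feature_idx": int(i),
        "activation": float(avg[i]),
        "normalized_activation": float(avg[i] / (decoder_norms[i] + 1e-8)),
        "decoder_weight_norm": float(decoder_norms[i]),
    }
    for i in top
]

# Entries come back sorted by descending mean activation
assert all(
    report[k]["activation"] >= report[k + 1]["activation"]
    for k in range(len(report) - 1)
)
```

Dividing by the decoder norm makes features with large decoder columns comparable to features with small ones, which is why `normalized_activation` is often the better sort key when decoder columns are not unit-normalized.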
```python
def get_steering_direction(
    self,
    feature_idx: int,
    normalize: bool = True,
) -> np.ndarray:
    """
    Get the steering direction vector for a given feature.

    The steering direction is the SAE decoder weight column for the
    specified feature. This vector represents the direction in MLP
    activation space that the feature encodes.

    Args:
        feature_idx: Index of the SAE feature.
        normalize: If ``True``, L2-normalise the direction vector.

    Returns:
        NumPy array of shape ``(activation_dim,)`` containing the
        steering direction.

    Raises:
        ValueError: If ``feature_idx`` is out of range
            ``[0, sae.hidden_dim)``.
    """
    if feature_idx < 0 or feature_idx >= self.sae.hidden_dim:
        raise ValueError(
            f"feature_idx {feature_idx} out of range [0, {self.sae.hidden_dim})"
        )

    direction = self.sae.decoder.weight[:, feature_idx].detach().cpu().numpy()

    if normalize:
        norm = np.linalg.norm(direction)
        if norm > 0:
            direction = direction / norm

    return direction
```
Get the steering direction vector for a given feature.

The steering direction is the SAE decoder weight column for the specified feature. This vector represents the direction in MLP activation space that the feature encodes.

Arguments:
- feature_idx: Index of the SAE feature.
- normalize: If `True`, L2-normalise the direction vector.

Returns:
NumPy array of shape `(activation_dim,)` containing the steering direction.

Raises:
- ValueError: If `feature_idx` is out of range `[0, sae.hidden_dim)`.
```python
def save_steering_analysis(
    self,
    results: Dict[str, str],
    output_dir: Union[str, Path],
    prompt: str = "",
    feature_label: str = "",
) -> Path:
    """
    Save steering comparison results to disk.

    Writes a text file containing the prompt and all generated
    outputs (baseline and steered) for later review. The
    filename includes a timestamp (and optional feature label)
    to prevent overwriting previous analyses.

    Args:
        results: Dictionary mapping labels to generated text
            (as returned by :meth:`compare_steering`).
        output_dir: Directory to save the analysis file.
        prompt: The original prompt text.
        feature_label: Optional label included in the filename
            (e.g. ``"feature_42"``).

    Returns:
        Path to the saved analysis file.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    suffix = f"_{feature_label}" if feature_label else ""
    filepath = output_dir / f"steering_analysis{suffix}_{timestamp}.txt"

    with open(filepath, "w") as f:
        f.write(f"Prompt: {prompt}\n")
        f.write("=" * 80 + "\n\n")

        for label, text in results.items():
            f.write(f"--- {label} ---\n")
            f.write(f"{text}\n\n")
            f.write("-" * 40 + "\n")

    logger.info(f"Saved steering analysis to {filepath}")
    return filepath
```
Save steering comparison results to disk.

Writes a text file containing the prompt and all generated outputs (baseline and steered) for later review. The filename includes a timestamp (and optional feature label) to prevent overwriting previous analyses.

Arguments:
- results: Dictionary mapping labels to generated text (as returned by `compare_steering()`).
- output_dir: Directory to save the analysis file.
- prompt: The original prompt text.
- feature_label: Optional label included in the filename (e.g. `"feature_42"`).

Returns:
Path to the saved analysis file.
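The collision-safe naming scheme can be sketched in isolation; this helper only builds the path and does no I/O (`analysis_filename` is an illustrative helper, not part of the API):

```python
from datetime import datetime
from pathlib import Path

def analysis_filename(output_dir: str, feature_label: str = "") -> Path:
    """Build a timestamped, optionally labelled filename for a steering analysis."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    suffix = f"_{feature_label}" if feature_label else ""
    return Path(output_dir) / f"steering_analysis{suffix}_{timestamp}.txt"

name = analysis_filename("out", "feature_42").name
assert name.startswith("steering_analysis_feature_42_")
assert name.endswith(".txt")
```

Second-resolution timestamps make overwrites unlikely in interactive use, though two calls within the same second would still collide.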
````python
class Config(BaseSettings):
    """Main configuration class for the Drrik framework.

    Aggregates all sub-configurations (extractor, autoencoder, visualization)
    into a single object. Settings can be loaded from environment variables
    (with ``DRIK_`` prefix) or from a YAML config file via the CLI.

    Attributes:
        extractor: Activation extraction configuration.
        autoencoder: Sparse autoencoder configuration.
        visualization: Feature visualization configuration.
        random_seed: Global random seed for reproducibility. Applied
            to numpy, torch, and random at pipeline start.
        log_level: Logging verbosity. One of ``DEBUG``, ``INFO``,
            ``WARNING``, ``ERROR``, ``CRITICAL``.

    Example:
        Load from environment variables:

        ```bash
        export DRIK_RANDOM_SEED=42
        export DRIK_AUTOENCODER__HIDDEN_DIM=16384
        config = Config()
        ```

        Or construct programmatically:

        ```python
        config = Config(
            extractor=ActivationExtractorConfig(...),
            autoencoder=SparseAutoencoderConfig(activation_dim=2048),
        )
        ```
    """

    extractor: ActivationExtractorConfig = Field(
        default_factory=ActivationExtractorConfig
    )
    autoencoder: SparseAutoencoderConfig = Field(
        default_factory=SparseAutoencoderConfig
    )
    visualization: VisualizationConfig = Field(default_factory=VisualizationConfig)

    random_seed: int = Field(default=42, description="Random seed for reproducibility")
    log_level: str = Field(
        default="INFO", description="Logging level (DEBUG, INFO, WARNING, ERROR)"
    )

    model_config = ConfigDict(
        env_prefix="DRIK_",
        env_nested_delimiter="__",
    )

    def create_output_dirs(self) -> None:
        """Create output directories if they don't exist.

        Ensures that both the extractor output directory (if configured)
        and the visualization output directory exist on disk, creating
        any missing parent directories as needed.
        """
        if self.extractor.output_dir:
            self.extractor.output_dir.mkdir(parents=True, exist_ok=True)
        self.visualization.output_dir.mkdir(parents=True, exist_ok=True)
````
Main configuration class for the Drrik framework.

Aggregates all sub-configurations (extractor, autoencoder, visualization) into a single object. Settings can be loaded from environment variables (with `DRIK_` prefix) or from a YAML config file via the CLI.

Attributes:
- extractor: Activation extraction configuration.
- autoencoder: Sparse autoencoder configuration.
- visualization: Feature visualization configuration.
- random_seed: Global random seed for reproducibility. Applied to numpy, torch, and random at pipeline start.
- log_level: Logging verbosity. One of `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`.

Example:
Load from environment variables:
```bash
export DRIK_RANDOM_SEED=42
export DRIK_AUTOENCODER__HIDDEN_DIM=16384
```
```python
config = Config()
```
Or construct programmatically:
```python
config = Config(
    extractor=ActivationExtractorConfig(...),
    autoencoder=SparseAutoencoderConfig(activation_dim=2048),
)
```
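The `DRIK_` prefix and `__` nested delimiter come from pydantic-settings: prefixed variables are split on the delimiter and routed into nested sub-configurations. A simplified stand-alone sketch of that routing (not the actual pydantic implementation, and without pydantic's type coercion):

```python
def parse_env(env: dict, prefix: str = "DRIK_", delimiter: str = "__") -> dict:
    """Collect prefixed variables into a nested dict (field names lowercased)."""
    out: dict = {}
    for key, value in env.items():
        if not key.startswith(prefix):
            continue  # unrelated environment variable
        parts = key[len(prefix):].lower().split(delimiter)
        node = out
        for part in parts[:-1]:          # descend into nested sections
            node = node.setdefault(part, {})
        node[parts[-1]] = value          # leaf field keeps its string value
    return out

env = {
    "DRIK_RANDOM_SEED": "42",
    "DRIK_AUTOENCODER__HIDDEN_DIM": "16384",
    "PATH": "/usr/bin",                  # ignored: no DRIK_ prefix
}
assert parse_env(env) == {
    "random_seed": "42",
    "autoencoder": {"hidden_dim": "16384"},
}
```

In the real `Config`, pydantic additionally coerces `"42"` to `int` and validates the nested values against the sub-config models.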
```python
def create_output_dirs(self) -> None:
    """Create output directories if they don't exist.

    Ensures that both the extractor output directory (if configured)
    and the visualization output directory exist on disk, creating
    any missing parent directories as needed.
    """
    if self.extractor.output_dir:
        self.extractor.output_dir.mkdir(parents=True, exist_ok=True)
    self.visualization.output_dir.mkdir(parents=True, exist_ok=True)
```
Create output directories if they don't exist.
Ensures that both the extractor output directory (if configured) and the visualization output directory exist on disk, creating any missing parent directories as needed.
````python
class EnvironmentSettings(BaseSettings):
    """
    Environment settings for API keys and tokens.

    These settings are loaded from environment variables or a .env file.
    Create a .env file in the project root with your credentials.

    Example .env file:
    ```bash
    HUGGINGFACE_HUB_TOKEN=hf_...
    WANDB_API_KEY=...
    WANDB_PROJECT=drrik-experiments
    WANDB_ENTITY=your-username
    ```

    Attributes:
        huggingface_hub_token: HuggingFace Hub API token for gated models
        wandb_api_key: Weights & Biases API key for experiment tracking
        wandb_project: Default wandb project name
        wandb_entity: Default wandb entity (username or team)
        wandb_mode: wandb mode ('online', 'offline', or 'disabled')
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        env_prefix="",
        case_sensitive=False,
        extra="ignore",
    )

    huggingface_hub_token: Optional[str] = Field(
        default=None,
        description="HuggingFace Hub API token for accessing gated models. "
        "Get your token at: https://huggingface.co/settings/tokens",
    )

    wandb_api_key: Optional[str] = Field(
        default=None,
        description="Weights & Biases API key for experiment tracking. "
        "Get your key at: https://wandb.ai/settings",
    )

    wandb_project: str = Field(
        default="drrik-experiments", description="Default wandb project name"
    )

    wandb_entity: Optional[str] = Field(
        default=None, description="Default wandb entity (username or team name)"
    )

    wandb_mode: str = Field(
        default="online",
        description="wandb mode: 'online' to sync, 'offline' to save locally, "
        "'disabled' to disable wandb",
    )

    @field_validator("wandb_mode")
    @classmethod
    def validate_wandb_mode(cls, v: str) -> str:
        """Validate that the wandb mode is one of the allowed values.

        Args:
            v: The wandb mode string to validate.

        Returns:
            The lowercased wandb mode string.

        Raises:
            ValueError: If the mode is not ``online``, ``offline``,
                or ``disabled``.
        """
        valid_modes = ["online", "offline", "disabled"]
        v = v.lower()
        if v not in valid_modes:
            raise ValueError(f"wandb_mode must be one of {valid_modes}, got '{v}'")
        return v

    @field_validator("huggingface_hub_token")
    @classmethod
    def validate_hf_token(cls, v: Optional[str]) -> Optional[str]:
        """Log a confirmation message when an HF token is provided.

        Args:
            v: The HuggingFace Hub token string, or ``None``.

        Returns:
            The unmodified token string.
        """
        if v:
            logger.info("HuggingFace Hub token is configured")
        return v

    @field_validator("wandb_api_key")
    @classmethod
    def validate_wandb_key(cls, v: Optional[str]) -> Optional[str]:
        """Log a confirmation message when a wandb API key is provided.

        Args:
            v: The Weights & Biases API key string, or ``None``.

        Returns:
            The unmodified API key string.
        """
        if v:
            logger.info("Weights & Biases API key is configured")
        return v

    @property
    def use_wandb(self) -> bool:
        """Check if wandb should be enabled based on API key and mode.

        Returns:
            ``True`` if an API key is set and wandb mode is not
            ``disabled``, ``False`` otherwise.
        """
        return self.wandb_api_key is not None and self.wandb_mode != "disabled"

    @property
    def has_hf_token(self) -> bool:
        """Check if a HuggingFace Hub token is available.

        Returns:
            ``True`` if ``huggingface_hub_token`` is set, ``False``
            otherwise.
        """
        return self.huggingface_hub_token is not None

    def get_hf_auth(self) -> Optional[tuple]:
        """Get HuggingFace authentication tuple for model loading.

        Constructs the authentication argument expected by the
        ``transformers`` library when loading gated models.

        Returns:
            A tuple of ``(True, token_string)`` if a token is
            configured, otherwise ``None``.
        """
        if self.has_hf_token:
            return (True, self.huggingface_hub_token)
        return None
````
Environment settings for API keys and tokens.
These settings are loaded from environment variables or a .env file. Create a .env file in the project root with your credentials.
Example .env file:
```bash
HUGGINGFACE_HUB_TOKEN=hf_...
WANDB_API_KEY=...
WANDB_PROJECT=drrik-experiments
WANDB_ENTITY=your-username
```
Attributes:
- huggingface_hub_token: HuggingFace Hub API token for accessing gated models. Get your token at: https://huggingface.co/settings/tokens
- wandb_api_key: Weights & Biases API key for experiment tracking. Get your key at: https://wandb.ai/settings
- wandb_project: Default wandb project name
- wandb_entity: Default wandb entity (username or team)
- wandb_mode: wandb mode: 'online' to sync, 'offline' to save locally, 'disabled' to disable wandb
```python
@field_validator("wandb_mode")
@classmethod
def validate_wandb_mode(cls, v: str) -> str:
    """Validate that the wandb mode is one of the allowed values.

    Args:
        v: The wandb mode string to validate.

    Returns:
        The lowercased wandb mode string.

    Raises:
        ValueError: If the mode is not ``online``, ``offline``,
            or ``disabled``.
    """
    valid_modes = ["online", "offline", "disabled"]
    v = v.lower()
    if v not in valid_modes:
        raise ValueError(f"wandb_mode must be one of {valid_modes}, got '{v}'")
    return v
```
Validate that the wandb mode is one of the allowed values.
Arguments:
- v: The wandb mode string to validate.
Returns:
The lowercased wandb mode string.
Raises:
- ValueError: If the mode is not `online`, `offline`, or `disabled`.
```python
@field_validator("huggingface_hub_token")
@classmethod
def validate_hf_token(cls, v: Optional[str]) -> Optional[str]:
    """Log a confirmation message when an HF token is provided.

    Args:
        v: The HuggingFace Hub token string, or ``None``.

    Returns:
        The unmodified token string.
    """
    if v:
        logger.info("HuggingFace Hub token is configured")
    return v
```
Log a confirmation message when an HF token is provided.
Arguments:
- v: The HuggingFace Hub token string, or `None`.
Returns:
The unmodified token string.
```python
@field_validator("wandb_api_key")
@classmethod
def validate_wandb_key(cls, v: Optional[str]) -> Optional[str]:
    """Log a confirmation message when a wandb API key is provided.

    Args:
        v: The Weights & Biases API key string, or ``None``.

    Returns:
        The unmodified API key string.
    """
    if v:
        logger.info("Weights & Biases API key is configured")
    return v
```
Log a confirmation message when a wandb API key is provided.
Arguments:
- v: The Weights & Biases API key string, or `None`.
Returns:
The unmodified API key string.
```python
@property
def use_wandb(self) -> bool:
    """Check if wandb should be enabled based on API key and mode.

    Returns:
        ``True`` if an API key is set and wandb mode is not
        ``disabled``, ``False`` otherwise.
    """
    return self.wandb_api_key is not None and self.wandb_mode != "disabled"
```
Check if wandb should be enabled based on API key and mode.
Returns:
`True` if an API key is set and wandb mode is not `disabled`, `False` otherwise.
```python
@property
def has_hf_token(self) -> bool:
    """Check if a HuggingFace Hub token is available.

    Returns:
        ``True`` if ``huggingface_hub_token`` is set, ``False``
        otherwise.
    """
    return self.huggingface_hub_token is not None
```
Check if a HuggingFace Hub token is available.
Returns:
`True` if `huggingface_hub_token` is set, `False` otherwise.
```python
def get_hf_auth(self) -> Optional[tuple]:
    """Get HuggingFace authentication tuple for model loading.

    Constructs the authentication argument expected by the
    ``transformers`` library when loading gated models.

    Returns:
        A tuple of ``(True, token_string)`` if a token is
        configured, otherwise ``None``.
    """
    if self.has_hf_token:
        return (True, self.huggingface_hub_token)
    return None
```
Get HuggingFace authentication tuple for model loading.
Constructs the authentication argument expected by the `transformers` library when loading gated models.
Returns:
A tuple of `(True, token_string)` if a token is configured, otherwise `None`.
class WandbConfig:
    """
    Configuration for wandb experiment tracking.

    This class handles wandb initialization, logging, and cleanup.
    It's designed to be used as a context manager for automatic cleanup.

    Example:
        ```python
        from drrik.settings import WandbConfig

        # Use as context manager
        with WandbConfig(
            project="my-sae-experiment",
            config={"model": "gemma-2b", "expansion": 8}
        ) as wandb_logger:
            wandb_logger.log_metrics({"loss": 0.5, "l0_norm": 10})

        # Or manually initialize/finalize
        wandb_config = WandbConfig(project="my-sae-experiment")
        wandb_config.initialize()
        wandb_config.log_metrics({"loss": 0.5})
        wandb_config.finalize()
        ```
    """

    def __init__(
        self,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        name: Optional[str] = None,
        config: Optional[dict] = None,
        tags: Optional[list[str]] = None,
        settings: Optional[EnvironmentSettings] = None,
        enabled: bool = True,
    ):
        """
        Initialize wandb configuration.

        Args:
            project: wandb project name (uses settings default if None)
            entity: wandb entity (uses settings default if None)
            name: Run name (auto-generated if None)
            config: Configuration dict to log
            tags: List of tags for the run
            settings: EnvironmentSettings instance (a new one is created if None)
            enabled: If False, disables wandb even if configured
        """
        self.settings = settings or EnvironmentSettings()
        self.enabled = enabled and self.settings.use_wandb

        if self.enabled and not self.settings.wandb_api_key:
            logger.warning(
                "wandb is enabled but no API key is set. "
                "Set WANDB_API_KEY environment variable or disable wandb."
            )
            self.enabled = False

        self.project = project or self.settings.wandb_project
        self.entity = entity or self.settings.wandb_entity
        self.name = name
        self.config = config or {}
        self.tags = tags

        self._initialized = False
        self._run = None

    def initialize(self) -> bool:
        """
        Initialize wandb run.

        Returns:
            True if wandb was successfully initialized, False otherwise
        """
        if not self.enabled:
            logger.info("wandb is disabled")
            return False

        if self._initialized:
            logger.warning("wandb is already initialized")
            return True

        try:
            import wandb

            # Set API key
            os.environ["WANDB_API_KEY"] = self.settings.wandb_api_key
            os.environ["WANDB_MODE"] = self.settings.wandb_mode

            # Initialize run
            self._run = wandb.init(
                project=self.project,
                entity=self.entity,
                name=self.name,
                config=self.config,
                tags=self.tags,
                reinit=True,  # Allow re-initialization
            )

            self._initialized = True
            logger.info(f"wandb initialized: {self.get_run_url()}")
            return True

        except ImportError:
            logger.warning(
                "wandb package not installed. Install it with: pip install wandb"
            )
            self.enabled = False
            return False
        except Exception as e:
            logger.error(f"Failed to initialize wandb: {e}")
            self.enabled = False
            return False

    def finalize(self) -> None:
        """Finalize wandb run."""
        if self._initialized:
            try:
                import wandb

                wandb.finish()
                logger.info("wandb run finalized")
            except Exception as e:
                logger.error(f"Error finalizing wandb: {e}")
            finally:
                self._initialized = False
                self._run = None

    def log_metrics(
        self,
        metrics: dict,
        step: Optional[int] = None,
        commit: bool = True,
    ) -> None:
        """
        Log metrics to wandb.

        Args:
            metrics: Dictionary of metric names to values
            step: Current step (for logging to specific step)
            commit: Whether to commit the log
        """
        if self._initialized:
            try:
                import wandb

                wandb.log(metrics, step=step, commit=commit)
            except Exception as e:
                logger.error(f"Error logging to wandb: {e}")

    def log_histogram(
        self,
        values,
        name: str,
        step: Optional[int] = None,
    ) -> None:
        """
        Log a histogram to wandb.

        Args:
            values: Array-like values to histogram
            name: Name of the histogram
            step: Current step
        """
        if self._initialized:
            try:
                import wandb
                import numpy as np

                wandb.log({name: wandb.Histogram(np.array(values))}, step=step)
            except Exception as e:
                logger.error(f"Error logging histogram to wandb: {e}")

    def log_model(self, model_path: str, name: str = "model") -> None:
        """
        Log a model artifact to wandb.

        Supports both single files and directories.

        Args:
            model_path: Path to the model file or directory
            name: Artifact name
        """
        if self._initialized:
            try:
                import wandb
                from pathlib import Path

                artifact = wandb.Artifact(name, type="model")
                path = Path(model_path)
                if path.is_dir():
                    artifact.add_dir(str(path), name=path.name)
                else:
                    artifact.add_file(str(path))
                wandb.log_artifact(artifact)
                logger.info(f"Logged model artifact: {name} from {model_path}")
            except Exception as e:
                logger.error(f"Error logging model to wandb: {e}")

    def log_artifact(
        self, path: str, name: str, artifact_type: str = "dataset"
    ) -> None:
        """
        Log a file or directory artifact to wandb.

        Supports both single files and directories.

        Args:
            path: Path to the file or directory
            name: Artifact name
            artifact_type: Type of artifact (dataset, model, config, etc.)
        """
        if self._initialized:
            try:
                import wandb
                from pathlib import Path

                artifact = wandb.Artifact(name, type=artifact_type)
                p = Path(path)
                if p.is_dir():
                    artifact.add_dir(str(p), name=p.name)
                else:
                    artifact.add_file(str(p))
                wandb.log_artifact(artifact)
                logger.info(f"Logged {artifact_type} artifact: {name} from {path}")
            except Exception as e:
                logger.error(f"Error logging artifact to wandb: {e}")

    def get_run_url(self) -> Optional[str]:
        """Get the wandb run URL for the current experiment.

        Returns:
            The URL string of the active wandb run, or ``None`` if
            no run has been initialized.
        """
        if self._run:
            return self._run.url
        return None

    def get_run_id(self) -> Optional[str]:
        """Get the wandb run ID for the current experiment.

        Returns:
            The unique run identifier string, or ``None`` if no run
            has been initialized.
        """
        if self._run:
            return self._run.id
        return None

    def __enter__(self):
        """Enter the context manager, initializing the wandb run.

        Returns:
            The ``WandbConfig`` instance for use within the ``with``
            block.
        """
        self.initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context manager, finalizing the wandb run.

        Args:
            exc_type: Exception type, if an exception was raised.
            exc_val: Exception value, if an exception was raised.
            exc_tb: Exception traceback, if an exception was raised.

        Returns:
            ``False`` to propagate any exceptions.
        """
        self.finalize()
        return False
def get_settings() -> EnvironmentSettings:
    """
    Get the global EnvironmentSettings instance.

    Returns:
        The global settings (creates one if it doesn't exist)
    """
    global _global_settings
    if _global_settings is None:
        _global_settings = EnvironmentSettings()
    return _global_settings