diff --git a/docs/api.md b/docs/api.md deleted file mode 100644 index f861f32..0000000 --- a/docs/api.md +++ /dev/null @@ -1,363 +0,0 @@ -# API Reference - -This API reference covers the core components you'll actually use with AION-1, based on the working implementation. - -## Core Model - -### `aion.AION` - -The main AION model class that provides multimodal astronomical analysis. - -```python -from aion import AION - -class AION(FM): - """ - AION-1 multimodal astronomical foundation model. - - Inherits from FM (4M) architecture and adds astronomical-specific - functionality for processing multiple data modalities. - """ - - @classmethod - def from_pretrained(cls, model_name: str, **kwargs) -> 'AION': - """ - Load a pre-trained AION model from HuggingFace Hub. - - Args: - model_name: HuggingFace model identifier - - 'polymathic-ai/aion-base': 300M parameter model - - Returns: - AION model instance - - Example: - >>> model = AION.from_pretrained('polymathic-ai/aion-base') - >>> model = model.to('cuda').eval() - """ - - def forward( - self, - input_tokens: Dict[str, torch.Tensor], - target_mask: Optional[Dict[str, torch.Tensor]] = None, - num_encoder_tokens: int = 600, - **kwargs - ) -> Dict[str, torch.Tensor]: - """ - Forward pass through the model. - - Args: - input_tokens: Dictionary mapping modality token keys to token tensors - target_mask: Dictionary specifying which tokens to predict - Format: {"tok_z": torch.zeros(batch_size, num_target_tokens)} - num_encoder_tokens: Number of tokens to use in encoder - - Returns: - Dictionary mapping target keys to prediction logits - - Example: - >>> predictions = model( - ... tokens, - ... target_mask={"tok_z": torch.zeros(32, 1)}, - ... num_encoder_tokens=600 - ... ) - >>> redshift_probs = torch.softmax(predictions["tok_z"], dim=-1) - """ - - def encode( - self, - input_tokens: Dict[str, torch.Tensor], - num_encoder_tokens: int = 600 - ) -> torch.Tensor: - """ - Extract embeddings from input tokens. 
- - Args: - input_tokens: Dictionary of tokenized modality data - num_encoder_tokens: Number of tokens for encoder processing - - Returns: - Encoder embeddings with shape [batch, seq_len, hidden_dim] - - Example: - >>> embeddings = model.encode(tokens, num_encoder_tokens=600) - >>> # Use embeddings for downstream tasks - >>> pooled = embeddings.mean(dim=1) # [batch, hidden_dim] - """ -``` - -## Codec Management - -### `aion.codecs.CodecManager` - -Manages automatic loading and application of modality-specific codecs. - -```python -from aion.codecs import CodecManager - -class CodecManager: - """ - Central manager for encoding/decoding between modalities and tokens. - """ - - def __init__(self, device: str = 'cuda'): - """ - Initialize codec manager. - - Args: - device: Device to load codecs on ('cuda', 'cpu') - - Example: - >>> codec_manager = CodecManager(device='cuda') - """ - - def encode(self, *modalities) -> Dict[str, torch.Tensor]: - """ - Encode modalities into discrete tokens. - - Args: - *modalities: Variable number of modality objects - - Returns: - Dictionary mapping token keys to token tensors - - Example: - >>> tokens = codec_manager.encode(image, spectrum, flux_g) - >>> # Returns: {"tok_image": tensor(...), "tok_spectrum_sdss": tensor(...), "tok_flux_g": tensor(...)} - """ - - def decode( - self, - tokens: Dict[str, torch.Tensor], - modality_class: type, - **metadata - ): - """ - Decode tokens back to modality objects. - - Args: - tokens: Dictionary of token tensors - modality_class: Class of modality to decode (e.g., LegacySurveyImage) - **metadata: Additional metadata required for reconstruction - - Returns: - Reconstructed modality object - - Example: - >>> reconstructed = codec_manager.decode( - ... tokens, - ... LegacySurveyImage, - ... bands=["DES-G", "DES-R", "DES-I", "DES-Z"] - ... ) - """ -``` - -## Modalities - -AION-1 uses a typed modality system to ensure data compatibility and provenance tracking. 
- -### Base Classes - -```python -from aion.modalities import BaseModality - -class BaseModality: - """Base class for all astronomical modalities.""" - - @property - def token_key(self) -> str: - """Unique identifier for this modality type in the model.""" -``` - -### Image Modalities - -```python -from aion.modalities import LegacySurveyImage, HSCImage - -class LegacySurveyImage(BaseModality): - """ - Legacy Survey multi-band image. - - Attributes: - flux: Image tensor with shape [batch, 4, height, width] for g,r,i,z bands - bands: List of band identifiers (e.g., ['DES-G', 'DES-R', 'DES-I', 'DES-Z']) - """ - - flux: torch.Tensor - bands: List[str] - - @property - def token_key(self) -> str: - return "tok_image" - -class HSCImage(BaseModality): - """ - HSC multi-band image. - - Attributes: - flux: Image tensor with shape [batch, 5, height, width] for g,r,i,z,y bands - bands: List of band identifiers - """ - - flux: torch.Tensor - bands: List[str] - - @property - def token_key(self) -> str: - return "tok_image" -``` - -### Spectrum Modalities - -```python -from aion.modalities import DESISpectrum, SDSSSpectrum - -class DESISpectrum(BaseModality): - """ - DESI spectroscopic observation. 
- - Attributes: - flux: Flux density array - ivar: Inverse variance array - mask: Boolean mask array - wavelength: Wavelength array in Angstroms - """ - - flux: torch.Tensor - ivar: torch.Tensor - mask: torch.Tensor - wavelength: torch.Tensor - - @property - def token_key(self) -> str: - return "tok_spectrum_desi" - -class SDSSSpectrum(BaseModality): - """SDSS spectroscopic observation.""" - - @property - def token_key(self) -> str: - return "tok_spectrum_sdss" -``` - -### Scalar Modalities - -```python -from aion.modalities import ( - LegacySurveyFluxG, LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ, - Z, GaiaParallax -) - -class LegacySurveyFluxG(BaseModality): - """Legacy Survey g-band flux measurement.""" - - value: torch.Tensor - - @property - def token_key(self) -> str: - return "tok_flux_g" - -class Z(BaseModality): - """Spectroscopic redshift.""" - - value: torch.Tensor - - @property - def token_key(self) -> str: - return "tok_z" -``` - -## Complete Usage Example - -Here's a comprehensive example showing the full workflow: - -```python -import torch -from aion import AION -from aion.codecs import CodecManager -from aion.modalities import ( - LegacySurveyImage, DESISpectrum, - LegacySurveyFluxG, LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ -) - -# 1. Load model and codec manager -model = AION.from_pretrained('polymathic-ai/aion-base').to('cuda').eval() -codec_manager = CodecManager(device='cuda') - -# 2. Prepare data -image = LegacySurveyImage( - flux=torch.tensor(image_data, dtype=torch.float32), - bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] -) - -spectrum = DESISpectrum( - flux=torch.tensor(flux_data), - ivar=torch.tensor(ivar_data), - mask=torch.tensor(mask_data, dtype=torch.bool), - wavelength=torch.tensor(wavelength_data) -) - -flux_g = LegacySurveyFluxG(value=torch.tensor([flux_g_value])) - -# 3. Encode to tokens -tokens = codec_manager.encode(image, spectrum, flux_g) - -# 4. 
Extract embeddings for downstream tasks -with torch.no_grad(): - embeddings = model.encode(tokens, num_encoder_tokens=600) - pooled_embeddings = embeddings.mean(dim=1) # [batch, hidden_dim] - -# 5. Predict redshift -with torch.no_grad(): - predictions = model( - tokens, - target_mask={"tok_z": torch.zeros(1, 1)}, - num_encoder_tokens=600 - ) - redshift_probs = torch.softmax(predictions["tok_z"][0], dim=-1) - -# 6. Decode tokens back to modalities -reconstructed_image = codec_manager.decode( - tokens, - LegacySurveyImage, - bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] -) -``` - -## Model Variants - -Currently available pre-trained models: - -| Model | Parameters | HuggingFace ID | -|-------|------------|----------------| -| AION-Base | 300M | `polymathic-ai/aion-base` | - -More model variants will be added as they become available. - -## Common Patterns - -### Similarity Search -```python -def compute_similarities(query_tokens, database_tokens, model): - """Compute embedding similarities between query and database.""" - with torch.no_grad(): - query_emb = model.encode(query_tokens).mean(dim=1) - db_embs = model.encode(database_tokens).mean(dim=1) - - from sklearn.metrics.pairwise import cosine_similarity - return cosine_similarity(query_emb.cpu(), db_embs.cpu()) -``` - -### Batch Processing -```python -def process_batch(batch_data, model, codec_manager): - """Process a batch of astronomical objects.""" - batch_tokens = codec_manager.encode(*batch_data) - - with torch.no_grad(): - embeddings = model.encode(batch_tokens, num_encoder_tokens=600) - - return embeddings.mean(dim=1) # Pooled embeddings -``` - -For more examples, see the [Usage Guide](usage.md) and [Tutorial Notebook](https://colab.research.google.com/github/PolymathicAI/AION/blob/main/notebooks/Tutorial.ipynb). 
diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..0942c8b --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,265 @@ +API Reference +============= + +This page provides comprehensive API documentation for all AION components, automatically generated from the source code. + +.. currentmodule:: aion + +Main Model +---------- + +.. automodule:: aion.model + :members: + :undoc-members: + :show-inheritance: + +Modalities +---------- + +The modality system defines data structures for all 39 astronomical data types supported by AION. + +Base Classes +~~~~~~~~~~~~ + +.. automodule:: aion.modalities + :members: Modality, Image, Spectrum, Scalar + :undoc-members: + :show-inheritance: + +Image Modalities +~~~~~~~~~~~~~~~~ + +.. automodule:: aion.modalities + :members: LegacySurveyImage, HSCImage + :undoc-members: + :show-inheritance: + +Spectrum Modalities +~~~~~~~~~~~~~~~~~~~ + +.. automodule:: aion.modalities + :members: DESISpectrum, SDSSSpectrum + :undoc-members: + :show-inheritance: + +Catalog Modalities +~~~~~~~~~~~~~~~~~~ + +.. automodule:: aion.modalities + :members: LegacySurveyCatalog, LegacySurveySegmentationMap + :undoc-members: + :show-inheritance: + +Scalar Modalities +~~~~~~~~~~~~~~~~~ + +Legacy Survey Scalars +^^^^^^^^^^^^^^^^^^^^^^ + +.. automodule:: aion.modalities + :members: LegacySurveyFluxG, LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ, LegacySurveyFluxW1, LegacySurveyFluxW2, LegacySurveyFluxW3, LegacySurveyFluxW4, LegacySurveyShapeR, LegacySurveyShapeE1, LegacySurveyShapeE2, LegacySurveyEBV + :undoc-members: + :show-inheritance: + +HSC Scalars +^^^^^^^^^^^ + +.. automodule:: aion.modalities + :members: HSCAG, HSCAR, HSCAI, HSCAZ, HSCAY, HSCMagG, HSCMagR, HSCMagI, HSCMagZ, HSCMagY, HSCShape11, HSCShape22, HSCShape12 + :undoc-members: + :show-inheritance: + +Gaia Scalars +^^^^^^^^^^^^ + +.. 
automodule:: aion.modalities + :members: GaiaFluxG, GaiaFluxBp, GaiaFluxRp, GaiaParallax, GaiaXpBp, GaiaXpRp + :undoc-members: + :show-inheritance: + +Coordinate Scalars +^^^^^^^^^^^^^^^^^^ + +.. automodule:: aion.modalities + :members: Ra, Dec, Z + :undoc-members: + :show-inheritance: + +Utility Types +~~~~~~~~~~~~~ + +.. automodule:: aion.modalities + :members: ScalarModalities, ModalityType + :undoc-members: + :show-inheritance: + +Codec System +------------ + +The codec system handles tokenization of different modality types. + +Core Codec Classes +~~~~~~~~~~~~~~~~~~ + +.. automodule:: aion.codecs.manager + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.codecs.base + :members: + :undoc-members: + :show-inheritance: + +Codec Implementations +~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: aion.codecs.image + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.codecs.spectrum + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.codecs.catalog + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.codecs.scalar_field + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.codecs.scalar + :members: + :undoc-members: + :show-inheritance: + +Quantizers +~~~~~~~~~~ + +.. automodule:: aion.codecs.quantizers + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.codecs.quantizers.scalar + :members: + :undoc-members: + :show-inheritance: + +4M Transformer +-------------- + +Core transformer architecture and components. + +Main Transformer +~~~~~~~~~~~~~~~~ + +.. automodule:: aion.fourm.fm + :members: + :undoc-members: + :show-inheritance: + +Embedding Layers +~~~~~~~~~~~~~~~~ + +.. automodule:: aion.fourm.encoder_embeddings + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.fourm.decoder_embeddings + :members: + :undoc-members: + :show-inheritance: + +Transformer Components +~~~~~~~~~~~~~~~~~~~~~~ + +.. 
automodule:: aion.fourm.fm_utils + :members: + :undoc-members: + :show-inheritance: + +Generation +~~~~~~~~~~ + +.. automodule:: aion.fourm.generate + :members: + :undoc-members: + :show-inheritance: + +LoRA Support +~~~~~~~~~~~~ + +.. automodule:: aion.fourm.lora_utils + :members: + :undoc-members: + :show-inheritance: + +Modality Configuration +~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: aion.fourm.modality_info + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.fourm.modality_transforms + :members: + :undoc-members: + :show-inheritance: + +Codec Modules +------------- + +Specialized neural network modules used in codecs. + +Architecture Components +~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: aion.codecs.modules.magvit + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.codecs.modules.convnext + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.codecs.modules.convblocks + :members: + :undoc-members: + :show-inheritance: + +Specialized Modules +~~~~~~~~~~~~~~~~~~~ + +.. automodule:: aion.codecs.modules.spectrum + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.codecs.modules.ema + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: aion.codecs.modules.subsampler + :members: + :undoc-members: + :show-inheritance: + +Configuration and Utilities +---------------------------- + +.. automodule:: aion.codecs.config + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/architecture.md b/docs/architecture.md deleted file mode 100644 index f4062cd..0000000 --- a/docs/architecture.md +++ /dev/null @@ -1,411 +0,0 @@ -# AION-1 Architecture - -This document provides a comprehensive overview of AION-1's architecture, explaining how it achieves unified multimodal understanding of astronomical data through innovative tokenization strategies and transformer-based learning. 
- -## Overview - -AION-1 employs a two-stage architecture that elegantly handles the complexity of astronomical data: - -1. **Universal Tokenization**: Modality-specific encoders convert heterogeneous astronomical observations into discrete tokens -2. **Multimodal Masked Modeling**: A unified transformer learns cross-modal relationships through masked token prediction - -This design enables AION-1 to process 39 different data modalities from 5 major astronomical surveys, learning from over 200 million objects. - -## Core Design Principles - -### 1. Purely Observational Learning - -Unlike many scientific ML models, AION-1 is trained exclusively on raw observational data without any labels derived from simulations or physical models. This approach provides: - -- **Model-agnostic representations**: Not tied to specific physical assumptions -- **Flexibility**: Can adapt to changing theoretical models -- **Robustness**: Learns patterns directly from data - -### 2. Arbitrary Modality Combinations - -AION-1 can process any subset of its 39 supported modalities without architectural changes: - -- No fixed input requirements -- Graceful handling of missing data -- Dynamic modality fusion - -### 3. Scalable Token-Based Approach - -By converting all data to tokens, AION-1 achieves: - -- Uniform processing across modalities -- Efficient batching and computation -- Natural handling of variable-length inputs - -## Stage 1: Universal Tokenization - -The tokenization stage addresses a fundamental challenge: how to convert diverse astronomical measurements (images, spectra, scalars) into a common representation suitable for transformer processing. 
- -### Image Tokenization - -AION-1's image tokenizer handles multi-band astronomical images from different surveys with varying: -- Resolution and pixel scales -- Number of channels (4-9 bands) -- Noise characteristics -- Dynamic range - -#### Architecture -``` -# Image tokenizer structure -class ImageCodec: - - Preprocessing: - - Center crop to 96x96 pixels - - Survey-specific rescaling - - Range compression: arcsinh(flux/α) × β - - - Multi-survey projection: - - SubsampledLinear layer (9 → 54 channels) - - Handles variable input bands - - Embeds survey provenance - - - Encoder: MagVit-based architecture - - ResNet backbone with 2 compressions - - Hidden dimensions: 512 - - Bottleneck: 5 dimensions - - - Quantization: Finite Scalar Quantization (FSQ) - - Levels: [8, 5, 5, 5, 5] - - Codebook size: 10,000 -``` - -#### Key Innovations - -1. **Channel Embedding Scheme**: Accommodates images from different surveys with varying band counts in a single model - -2. **Inverse-Variance Weighted Loss**: Leverages known noise properties for optimal reconstruction - ``` - L_NLL = Σ_i 1/2 || Σ_i^(-1/2) (x_i - Decoder(Encoder(x_i))) ||² - ``` - -3. 
**Survey-Aware Processing**: Maintains provenance information through dedicated embeddings - -### Spectrum Tokenization - -Astronomical spectra present unique challenges: -- Wavelength ranges vary by instrument (3500-10400 Å) -- Resolution differences (R = 1500-5500) -- Orders of magnitude variation in amplitude - -#### Architecture -``` -# Spectrum tokenizer structure -class SpectrumCodec: - - Preprocessing: - - Median normalization - - Log-transform median - - Resampling to latent wavelength grid - - - Latent grid: - - Range: 3500-10462.4 Å - - Resolution: 0.8 Å/pixel - - 8704 pixels total - - - Encoder: ConvNeXt V2 - - Depths: [3, 3, 9, 3] - - Dimensions: [96, 192, 384, 768] - - - Quantization: Lookup-Free Quantization (LFQ) - - Embedding dimension: 10 - - Codebook size: 1024 -``` - -#### Spectral Grid Interpolation - -The tokenizer uses a shared latent wavelength grid, enabling joint processing of spectra from different instruments: - -```python -def to_latent(spectrum, observed_wavelength): - # Interpolate observed spectrum to latent grid - return interp1d(observed_wavelength, spectrum, latent_wavelength) -``` - -### Scalar Tokenization - -Scalar quantities (fluxes, shapes, physical parameters) are tokenized using adaptive quantization based on cumulative distribution functions (CDFs). - -#### Types of Scalar Quantizers - -1. **Linear Quantizer**: For uniformly distributed values -2. **Log Quantizer**: For values spanning orders of magnitude -3. **Reservoir Quantizer**: Learns optimal binning from data -4. 
**Compressed Quantizer**: Applies transformations before quantization - -Example scalar modalities: -- Photometric fluxes (g, r, i, z bands) -- Shape parameters (ellipticity, radius) -- Physical properties (redshift, extinction) - -### Token Summary at a Glance - -| Modality | Native input tensor shape | Tokens per object | Quantizer type & levels | Codebook size | -|------------------------------------------------|---------------------------|--------------------|-------------------------|---------------| -| Image (HSC / Legacy Survey, 96 × 96 cut-out) | `(B, N_band, 96, 96)` | 144 *(18×18 grid)* | FSQ `[8,5,5,5,5]` | 10 000 | -| Spectrum (SDSS / DESI) | `(B, 2, λ)` *(flux,ivar)* | 64 + 1 norm token | LFQ `dim=10` | 1 024 | -| Scalar quantity (photometry, shapes, etc.) | `(B,)` | 1 per quantity | Reservoir (linear/log) | 256 (default) | -| Catalog (bounding ellipses) | `(B, N_obj, 5)` | ≤100×5 | Composite (per-field) | mixed | - -These numbers correspond to the default configuration used during pre-training (input budget = 256, output budget = 128 tokens). They can be modified at fine-tune time as long as the total token budget is respected. - -### Catalog Tokenization - -Astronomical catalogs contain lists of objects with varying counts per image. AION-1 linearizes these into sequences: - -``` -# Catalog entry: (X, Y, e1, e2, radius) -# Linearization: Sort by distance from center -# Tokenization: Quantize each component separately -``` - -## Stage 2: Multimodal Masked Modeling - -The second stage uses a transformer encoder-decoder architecture to learn relationships between tokens from different modalities. 
- -### Architecture Details - -``` -class AION(FourM): - # Encoder - - Depth: 12-24 layers (model-dependent) - - Hidden dimension: 768-2048 - - Attention heads: 12-32 - - MLP ratio: 4.0 - - Activation: SwiGLU - - # Decoder - - Same architecture as encoder - - Cross-attention to encoder outputs - - Modality-specific output heads -``` - -### Multimodal Masking Strategy - -AION-1 uses a sophisticated masking strategy that enables learning both within and across modalities: - -1. **Input Token Budget**: Randomly select B tokens across all modalities for input -2. **Output Token Budget**: From remaining tokens, select targets using Beta distribution -3. **Cross-Modal Learning**: Masks ensure model learns to predict any modality from any other - -```python -def mask_multimodal(tokens, num_input=256, num_output=128): - # 1. Select primary modality - primary_mod = random.choice(modalities) - - # 2. Fill input budget - input_tokens = sample_tokens(primary_mod, budget=num_input) - input_tokens += sample_from_other_modalities(remaining_budget) - - # 3. Select outputs (Beta distribution favors fewer tokens) - num_outputs = sample_beta(alpha=0.1, beta=1.0) * num_output - output_tokens = sample_from_remaining(num_outputs) - - return input_tokens, output_tokens -``` - -### Training Objective - -The model optimizes a cross-entropy loss over predicted tokens: - -``` -L = -Σ_t log p(x_t^target | x^observed) -``` - -This simple objective, combined with diverse masking patterns, enables AION-1 to learn rich cross-modal representations. - -## Model Variants - -AION-1 comes in three sizes, each using the same architecture with different dimensions: - -| Model | Parameters | Encoder Layers | Decoder Layers | Hidden Dim | Attention Heads | -|-------|------------|----------------|----------------|------------|-----------------| -| AION-Base | ~300M | 12 | 12 | 768 | 12 | - -> **Note**: Additional model sizes may be released in the future. 
Current model ID: `polymathic-ai/aion-base` - -All models use: -- SwiGLU activation functions -- No bias terms (except in embeddings) -- QK-Norm for training stability -- Rotary position embeddings - -## Data Flow Through AION-1 - -Here's how data flows through the complete pipeline: - -```{mermaid} -graph TD - A[Raw Astronomical Data] --> B[Modality-Specific Preprocessing] - B --> C[Tokenization] - C --> D[Token Embeddings + Position Encoding] - D --> E[Transformer Encoder] - E --> F[Cross-Modal Representations] - F --> G[Transformer Decoder] - G --> H[Modality-Specific Heads] - H --> I[Predictions/Generations] -``` - -### Example: Processing Galaxy Data - -```python -# 1. Input data -galaxy_data = { - 'image': HSC_5band_image, # (5, 96, 96) - 'spectrum': SDSS_spectrum, # (3800,) - 'photometry': flux_measurements # (8,) -} - -# 2. Tokenization -tokens = { - 'image': image_codec.encode(galaxy_data['image']), # → 144 tokens - 'spectrum': spectrum_codec.encode(galaxy_data['spectrum']), # → 64 tokens - 'photometry': scalar_codec.encode(galaxy_data['photometry']) # → 8 tokens -} - -# 3. Embedding and encoding -embeddings = model.embed_inputs(tokens) -encoder_output = model.encode(embeddings) - -# 4. Cross-modal generation/prediction -predictions = model.decode(encoder_output, target_modalities) -``` - -## Key Architectural Innovations - -### 1. Modality Embeddings with Provenance - -Each token receives two embeddings: -- **Token embedding**: Encodes the discrete token value -- **Modality embedding**: Identifies data type AND source survey - -This allows AION-1 to understand that HSC g-band and SDSS g-band images have different characteristics. - -### 2. Flexible Attention Patterns - -The attention mechanism adapts based on input: -- **Encoder**: Full bidirectional attention across all tokens -- **Decoder**: Causal attention within modalities, cross-attention to encoder - -### 3. 
Hierarchical Token Organization - -Tokens are organized hierarchically: -- **Spatial tokens**: Preserve 2D structure for images -- **Sequential tokens**: Maintain order for spectra and catalogs -- **Unordered tokens**: For scalar sets - -## Training Infrastructure - -### Dataset Construction - -AION-1's training leverages pairwise associations between surveys: -- HSC images ↔ SDSS spectra -- SDSS spectra ↔ DESI spectra -- Legacy images ↔ Photometry - -This creates a connected graph enabling transitive learning (e.g., HSC → SDSS → DESI). - -### Optimization Details - -- **Optimizer**: AdamW (β₁=0.9, β₂=0.95, weight decay=0.05) -- **Learning rate**: 2e-4 with cosine decay -- **Warmup**: Linear over first 10% of training -- **Batch size**: 8096 (distributed across GPUs) -- **Training steps**: 205,000 -- **Mixed precision**: bfloat16 - -### Computational Requirements - -Training AION-1 requires substantial computational resources: -- **AION-1-B**: 64 H100 GPUs for 1.5 days -- **AION-1-L**: 100 H100 GPUs for 2.5 days -- **AION-1-XL**: 288 H100 GPUs for 3.5 days - -## Emergent Capabilities - -The architecture enables several emergent behaviors: - -### 1. Zero-Shot Cross-Modal Generation -Despite never seeing direct HSC↔DESI associations during training, AION-1 can generate DESI spectra from HSC images through transitive learning. - -### 2. Flexible Conditioning -Any modality subset can condition generation of any other subset, enabling: -- Super-resolution (low-res → high-res spectra) -- Cross-modal translation (images → spectra) -- Imputation (partial → complete observations) - -### 3. 
Physically Meaningful Representations -The learned embeddings organize objects along interpretable axes: -- Galaxy types (spiral, elliptical, merger) -- Stellar properties (temperature, metallicity) -- Redshift progression - -## Implementation Details - -### Memory Efficiency - -- **Gradient checkpointing**: Trades computation for memory -- **Mixed precision**: bfloat16 for most operations -- **Efficient attention**: Flash Attention 2 implementation - -### Inference Optimization - -- **Token caching**: Reuse encoder outputs for multiple decodings -- **Batch processing**: Process multiple objects simultaneously -- **Quantization**: INT8 inference for deployment - -## Data Provenance & Licensing - -The pre‐training corpus – dubbed *The Multimodal Universe (MMU)* – merges publicly available data products under their respective licences: - -| Survey | Release | Reference | Modalities Used | -|--------|---------|-----------|-----------------| -| Legacy Imaging Survey (DECaLS/BASS/MzLS) | DR10 | Dey et al. 2019 | 4-band images, photometry, catalog scalars | -| Hyper Suprime-Cam (HSC) | PDR3 (Wide+Deep) | Aihara et al. 2019 | 5-band images, photometry, shapes | -| Sloan Digital Sky Survey (SDSS) | DR17 | Eisenstein et al. 2011 | R≈2000 spectra | -| Dark Energy Spectroscopic Instrument (DESI) | EDR | DESI Collab. 2023 | R≈3000 spectra | -| Gaia | DR3 | Gaia Collab. 2022 | Low-res XP spectra, photometry, astrometry | - -All derivative checkpoints released on the Hugging Face Hub are distributed under an MIT licence; users are nevertheless responsible for complying with the upstream survey licences when redistributing raw data. - -## Physical Units & Conventions - -• **Images**: pixel values are calibrated nanomaggies. Exposure time normalisation is survey-specific and automatically handled by the image codec. - -• **Spectra**: flux density in erg s⁻¹ cm⁻² Å⁻¹ (observer frame). Wavelengths are Å, *not* log-λ when inside the model. 
- -• **Photometry / Scalars**: all fluxes in nanomaggies, magnitudes in the AB system. Ellipticities use SDSS convention *(e₁,e₂)*. - -## Known Limitations & Caveats - -1. No ultraviolet (< 3500 Å) or mid-infrared (> 1 µm) spectral support. -2. HSC chip-edge artefacts occasionally propagate into synthetic spectra – crop images if necessary. -3. The model was trained on **96 × 96 px** cut-outs; objects extending beyond that FOV will be truncated. - -## Citation - -If you use AION-1 in a publication, please cite both the codebase and the accompanying paper: - -```bibtex -@article{Francois2025aion, - title = {AION-1: Omnimodal Foundation Model for Astronomical Sciences}, - author = {LASTNAME, Firstname et al.}, - journal = {arXiv e-prints}, - year = 2025, - archivePrefix = {arXiv}, - eprint = {2406.00000} -} -``` - -## Summary - -AION-1's architecture represents a significant advance in multimodal scientific machine learning: - -1. **Universal tokenization** handles arbitrary astronomical data types -2. **Unified transformer** learns cross-modal relationships -3. **Flexible design** adapts to available observations -4. **Emergent understanding** discovers physical relationships - -This architecture provides a foundation for next-generation astronomical analysis, enabling scientists to leverage all available data for their research. 
diff --git a/docs/index.md b/docs/index.md index d414bae..4640e26 100644 --- a/docs/index.md +++ b/docs/index.md @@ -26,84 +26,58 @@ Compared to traditional machine learning approaches in Astronomy, AION-1 stands ## 🚀 Quick Start -Getting started with AION-1 is straightforward: +Assuming you have PyTorch installed, you can install AION trivially with: +```bash +pip install polymathic-aion +``` +Then you can load the pretrained model and start analyzing astronomical data: ```python -# Minimal end-to-end example +import torch from aion import AION from aion.codecs import CodecManager -from aion.modalities import (LegacySurveyImage, LegacySurveyFluxG, -LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ) +from aion.modalities import LegacySurveyImage -# 1) Load a pre-trained checkpoint (300 M parameters) -model = AION.from_pretrained('polymathic-ai/aion-base').to('cuda').eval() -codec_manager = CodecManager(device='cuda') # Manages codecs for each modality +# Load model and codec manager +model = AION.from_pretrained('aion-base').to('cuda') # or 'aion-large', 'aion-xlarge' +codec_manager = CodecManager(device='cuda') -# 2) Prepare demo inputs (96×96 g,r,i,z cut-out and photometry) -# Create image modality +# Prepare your astronomical data (example: Legacy Survey image) image = LegacySurveyImage( - flux=data["legacysurvey_image_flux"], - bands=["DES-G", "DES-R", "DES-I", "DES-Z"], + flux=your_image_tensor, # Shape: [batch, 4, height, width] for g,r,i,z bands + bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] ) -# Create flux modalities -g = LegacySurveyFluxG(value=data["legacysurvey_FLUX_G"]) -r = LegacySurveyFluxR(value=data["legacysurvey_FLUX_R"]) -i = LegacySurveyFluxI(value=data["legacysurvey_FLUX_I"]) -z = LegacySurveyFluxZ(value=data["legacysurvey_FLUX_Z"]) +# Encode data to tokens +tokens = codec_manager.encode(image) -# Encode input modalities into tokens -tokens = codec_manager.encode(image, g, r, i, z) +# Option 1: Extract embeddings for downstream tasks 
+embeddings = model.encode(tokens, num_encoder_tokens=600) -# 3) Generate a redshift distribution from these set of inputs -predictions = model( - tokens, - target_mask={"tok_z": torch.zeros(batch_size, 1)}, - num_encoder_tokens=600 +# Option 2: Generate predictions (e.g., redshift) +from aion.modalities import Z +preds = model( + codec_manager.encode(image), + target_modality=Z, ) -redshift_logits = predictions["tok_z"] # Shape: [batch, sequence, vocab_size] - -# 4) Extract joint embeddings for downstream use -embeddings = model.encode(tokens, num_encoder_tokens=600) # Shape: [batch, seq_len, hidden_dim] ``` -## 📚 Documentation Overview +## 📚 Documentation ```{eval-rst} -.. grid:: 2 2 2 4 +.. grid:: 1 1 1 2 :gutter: 3 - .. grid-item-card:: Installation & Setup - :link: installation.html - :class-card: doc-card - - Environment setup, dependencies, and configuration - - .. grid-item-card:: Model Specifications - :link: architecture.html - :class-card: doc-card - - Deep dive into tokenization, transformers, and trarining data - - .. grid-item-card:: Usage Guide - :link: usage.html - :class-card: doc-card - - Tutorials, examples, and best practices - .. grid-item-card:: API Reference :link: api.html :class-card: doc-card - Complete API documentation and method signatures + Complete API documentation with all classes and methods ``` ```{toctree} :hidden: :maxdepth: 2 -installation -architecture -usage api ``` diff --git a/docs/installation.md b/docs/installation.md deleted file mode 100644 index 6225e4d..0000000 --- a/docs/installation.md +++ /dev/null @@ -1,111 +0,0 @@ -# Installation Guide - -Quick and straightforward installation guide for AION-1. 
- -## System Requirements - -### Hardware Requirements - -**Minimum (CPU only)**: -- 16 GB RAM -- 20 GB free storage - -**Recommended (GPU)**: -- NVIDIA GPU with 8GB+ VRAM -- 32 GB RAM -- 50 GB free storage - -**For Large-Scale Processing**: -- NVIDIA GPU with 24GB+ VRAM (e.g., RTX 4090, A5000+) -- 64GB+ RAM - -### Software Requirements - -- Python 3.10+ -- CUDA 11.8+ (for GPU acceleration) -- Linux, macOS, or Windows - -## Installation - -### Quick Install (Recommended) - -```bash -# Install PyTorch with CUDA support (adjust CUDA version as needed) -pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 - -# Install AION -pip install aion -``` - -### Alternative: CPU-only Installation - -```bash -# Install CPU-only PyTorch -pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu - -# Install AION -pip install aion -``` - -### Development Installation - -```bash -git clone https://github.com/polymathic-ai/aion.git -cd aion -pip install -e ".[torch,dev]" -``` - -## Verification - -Test your installation: - -```python -import torch -from aion import AION -from aion.codecs import CodecManager - -print(f"PyTorch version: {torch.__version__}") -print(f"CUDA available: {torch.cuda.is_available()}") - -# Test model loading (requires internet connection) -try: - model = AION.from_pretrained('polymathic-ai/aion-base') - print("✓ AION model loaded successfully") -except Exception as e: - print(f"✗ Model loading failed: {e}") - -# Test codec manager -try: - codec_manager = CodecManager(device='cuda' if torch.cuda.is_available() else 'cpu') - print("✓ CodecManager initialized successfully") -except Exception as e: - print(f"✗ CodecManager failed: {e}") -``` - -## Troubleshooting - -### Common Issues - -**CUDA out of memory**: -```bash -# Use smaller model or CPU -model = AION.from_pretrained('polymathic-ai/aion-base').to('cpu') -``` - -**HuggingFace connection issues**: -```bash -# Set up HuggingFace cache directory -export 
HF_HOME=/path/to/cache -``` - -**Import errors**: -```bash -# Reinstall with fresh environment -pip uninstall aion torch -pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 -pip install aion -``` - -## Next Steps - -Once installed, try the [Tutorial Notebook](https://colab.research.google.com/github/PolymathicAI/AION/blob/main/notebooks/Tutorial.ipynb) or check the [Usage Guide](usage.md) for examples. diff --git a/docs/usage.md b/docs/usage.md deleted file mode 100644 index 855dbea..0000000 --- a/docs/usage.md +++ /dev/null @@ -1,609 +0,0 @@ -# AION-1 Usage Guide - -This comprehensive guide demonstrates how to use AION-1 for various astronomical analysis tasks, based on the actual working implementation. - -## Table of Contents - -1. [Quick Start](#quick-start) -2. [Loading and Preparing Data](#loading-and-preparing-data) -3. [Basic Workflows](#basic-workflows) -4. [Embedding Extraction](#embedding-extraction) -5. [Similarity Search](#similarity-search) -6. [Property Prediction](#property-prediction) -7. [Performance Tips](#performance-tips) - -## Quick Start - -Here's how to get started with AION-1 in just a few lines: - -```python -import torch -import numpy as np -from aion import AION -from aion.codecs import CodecManager -from aion.modalities import LegacySurveyImage - -# 1. Load model and codec manager -model = AION.from_pretrained('polymathic-ai/aion-base').to('cuda').eval() -codec_manager = CodecManager(device='cuda') - -# 2. Prepare your astronomical data -image = LegacySurveyImage( - flux=torch.tensor(your_image_data, dtype=torch.float32), # Shape: [batch, 4, height, width] - bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] -) - -# 3. Encode to tokens -tokens = codec_manager.encode(image) - -# 4. Extract embeddings for downstream analysis -with torch.no_grad(): - embeddings = model.encode(tokens, num_encoder_tokens=600) - # Shape: [batch, sequence_length, 768] - -# 5. 
Predict redshift distribution -with torch.no_grad(): - predictions = model( - tokens, - target_mask={"tok_z": torch.zeros(batch_size, 1)}, - num_encoder_tokens=600 - ) - redshift_logits = predictions["tok_z"] - redshift_probs = torch.softmax(redshift_logits, dim=-1) -``` - -## Loading and Preparing Data - -### Working with Images - -AION-1 expects multi-band astronomical images with specific formatting: - -```python -import torch -from astropy.io import fits -from aion.modalities import LegacySurveyImage, HSCImage - -# Example 1: Legacy Survey (4-band: g,r,i,z) -def load_legacy_survey_image(fits_path): - """Load and format Legacy Survey FITS data.""" - with fits.open(fits_path) as hdul: - # Assuming bands are in separate extensions - flux_data = np.array([hdul[i].data for i in range(1, 5)]) # 4 bands - - image = LegacySurveyImage( - flux=torch.tensor(flux_data, dtype=torch.float32), - bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] - ) - return image - -# Example 2: HSC (5-band: g,r,i,z,y) -def load_hsc_image(flux_array): - """Load HSC 5-band image data.""" - image = HSCImage( - flux=torch.tensor(flux_array, dtype=torch.float32), - bands=['HSC-G', 'HSC-R', 'HSC-I', 'HSC-Z', 'HSC-Y'] - ) - return image - -# Note: AION-1 automatically crops/pads images to 96x96 pixels -``` - -### Working with Spectra - -Load and prepare spectroscopic observations: - -```python -from aion.modalities import DESISpectrum, SDSSSpectrum - -def load_desi_spectrum(flux, ivar, mask, wavelength): - """Load DESI spectrum data.""" - spectrum = DESISpectrum( - flux=torch.tensor(flux, dtype=torch.float32), - ivar=torch.tensor(ivar, dtype=torch.float32), - mask=torch.tensor(mask, dtype=torch.bool), - wavelength=torch.tensor(wavelength, dtype=torch.float32) - ) - return spectrum - -def load_sdss_spectrum_from_fits(fits_path): - """Load SDSS spectrum from FITS file.""" - with fits.open(fits_path) as hdul: - data = hdul[1].data - wavelength = 10**data['loglam'] # Convert from log wavelength - flux = 
data['flux'] - ivar = data['ivar'] - - # Create mask for bad pixels - mask = (ivar > 0) & (flux > 0) - - spectrum = SDSSSpectrum( - flux=torch.tensor(flux, dtype=torch.float32), - ivar=torch.tensor(ivar, dtype=torch.float32), - mask=torch.tensor(mask, dtype=torch.bool), - wavelength=torch.tensor(wavelength, dtype=torch.float32) - ) - return spectrum -``` - -### Working with Photometric Data - -Prepare scalar measurements like fluxes and shape parameters: - -```python -from aion.modalities import ( - LegacySurveyFluxG, LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ, - Z, GaiaParallax -) - -def create_photometry_modalities(catalog_data): - """Create modalities from catalog measurements.""" - modalities = [] - - # Photometric fluxes - if 'flux_g' in catalog_data: - modalities.append(LegacySurveyFluxG( - value=torch.tensor(catalog_data['flux_g'], dtype=torch.float32) - )) - - if 'flux_r' in catalog_data: - modalities.append(LegacySurveyFluxR( - value=torch.tensor(catalog_data['flux_r'], dtype=torch.float32) - )) - - # Redshift - if 'redshift' in catalog_data: - modalities.append(Z( - value=torch.tensor(catalog_data['redshift'], dtype=torch.float32) - )) - - return modalities -``` - -## Basic Workflows - -### Workflow 1: Embedding Extraction - -Extract learned representations for downstream machine learning: - -```python -def extract_galaxy_embeddings(data_list, model, codec_manager): - """Extract embeddings from a list of galaxy observations.""" - all_embeddings = [] - - # Process in batches for efficiency - batch_size = 32 - for i in range(0, len(data_list), batch_size): - batch = data_list[i:i + batch_size] - - # Encode all modalities in the batch - batch_tokens = codec_manager.encode(*batch) - - # Extract embeddings - with torch.no_grad(): - embeddings = model.encode(batch_tokens, num_encoder_tokens=600) - # Pool over sequence dimension - pooled = embeddings.mean(dim=1) # [batch, 768] - - all_embeddings.append(pooled.cpu().numpy()) - - return 
np.vstack(all_embeddings) - -# Usage example -galaxy_embeddings = extract_galaxy_embeddings( - [image1, image2, image3, ...], - model, - codec_manager -) -``` - -### Workflow 2: Redshift Prediction - -Predict redshift distributions from various input modalities: - -```python -def predict_redshift_distribution(inputs, model, codec_manager): - """Predict redshift probability distribution.""" - # Encode inputs - tokens = codec_manager.encode(*inputs) - - # Predict redshift - with torch.no_grad(): - predictions = model( - tokens, - target_mask={"tok_z": torch.zeros(len(inputs), 1)}, - num_encoder_tokens=600 - ) - - # Convert to probabilities - redshift_logits = predictions["tok_z"] - redshift_probs = torch.softmax(redshift_logits, dim=-1) - - return redshift_probs - -# Example: Predict from photometry -redshift_dist = predict_redshift_distribution( - [flux_g, flux_r, flux_i, flux_z], - model, - codec_manager -) -``` - -### Workflow 3: Reconstruction - -Reconstruct modalities through the encode-decode process: - -```python -def reconstruct_modality(original_modality, model, codec_manager, modality_class, **metadata): - """Reconstruct a modality through encode-decode cycle.""" - # Encode original - tokens = codec_manager.encode(original_modality) - - # Decode back - reconstructed = codec_manager.decode( - tokens, - modality_class, - **metadata - ) - - return reconstructed - -# Example: Reconstruct image -reconstructed_image = reconstruct_modality( - original_image, - model, - codec_manager, - LegacySurveyImage, - bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] -) -``` - -## Embedding Extraction - -### Basic Embedding Extraction - -```python -def get_embeddings(modalities, model, codec_manager, pooling='mean'): - """Extract embeddings with different pooling strategies.""" - tokens = codec_manager.encode(*modalities) - - with torch.no_grad(): - embeddings = model.encode(tokens, num_encoder_tokens=600) - - # Apply pooling - if pooling == 'mean': - return embeddings.mean(dim=1) - 
elif pooling == 'max': - return embeddings.max(dim=1)[0] - elif pooling == 'cls': - return embeddings[:, 0] # First token - else: - return embeddings # Return full sequence - -# Usage -embeddings = get_embeddings([image, spectrum], model, codec_manager) -``` - -### Multi-Modal Embeddings - -Combine embeddings from different modalities: - -```python -def get_multimodal_embeddings(image, spectrum, photometry, model, codec_manager): - """Extract embeddings from multiple modality types.""" - - # Get embeddings from each modality type - image_tokens = codec_manager.encode(image) - spectrum_tokens = codec_manager.encode(spectrum) - photo_tokens = codec_manager.encode(*photometry) - - embeddings = {} - - with torch.no_grad(): - # Image embeddings - img_emb = model.encode(image_tokens, num_encoder_tokens=300) - embeddings['image'] = img_emb.mean(dim=1) - - # Spectrum embeddings - spec_emb = model.encode(spectrum_tokens, num_encoder_tokens=300) - embeddings['spectrum'] = spec_emb.mean(dim=1) - - # Combined embeddings - all_tokens = {**image_tokens, **spectrum_tokens, **photo_tokens} - combined_emb = model.encode(all_tokens, num_encoder_tokens=900) - embeddings['combined'] = combined_emb.mean(dim=1) - - return embeddings -``` - -## Similarity Search - -Implement similarity search using AION embeddings: - -```python -from sklearn.metrics.pairwise import cosine_similarity -from sklearn.neighbors import NearestNeighbors - -class AIONSimilaritySearch: - def __init__(self, model, codec_manager): - self.model = model - self.codec_manager = codec_manager - self.database_embeddings = [] - self.database_objects = [] - self.index = None - - def add_objects(self, objects): - """Add objects to the search database.""" - for obj in objects: - # Extract embedding - tokens = self.codec_manager.encode(*obj['modalities']) - with torch.no_grad(): - emb = self.model.encode(tokens, num_encoder_tokens=600) - emb = emb.mean(dim=1).cpu().numpy() - - self.database_embeddings.append(emb) - 
self.database_objects.append(obj) - - # Build search index - if self.database_embeddings: - embeddings_matrix = np.vstack(self.database_embeddings) - self.index = NearestNeighbors(n_neighbors=10, metric='cosine') - self.index.fit(embeddings_matrix) - - def search(self, query_modalities, k=5): - """Search for similar objects.""" - # Get query embedding - tokens = self.codec_manager.encode(*query_modalities) - with torch.no_grad(): - query_emb = self.model.encode(tokens, num_encoder_tokens=600) - query_emb = query_emb.mean(dim=1).cpu().numpy() - - # Find nearest neighbors - distances, indices = self.index.kneighbors(query_emb, n_neighbors=k) - - results = [] - for i, idx in enumerate(indices[0]): - results.append({ - 'object': self.database_objects[idx], - 'similarity': 1 - distances[0][i], # Convert distance to similarity - 'rank': i + 1 - }) - - return results - -# Usage example -searcher = AIONSimilaritySearch(model, codec_manager) - -# Add objects to database -database_objects = [ - {'modalities': [image1, spectrum1], 'metadata': {'id': 'galaxy_1'}}, - {'modalities': [image2, spectrum2], 'metadata': {'id': 'galaxy_2'}}, - # ... 
more objects -] -searcher.add_objects(database_objects) - -# Search for similar objects -query_galaxy = [query_image, query_spectrum] -similar_objects = searcher.search(query_galaxy, k=10) - -print(f"Found {len(similar_objects)} similar objects:") -for result in similar_objects: - print(f"Rank {result['rank']}: {result['object']['metadata']['id']} " - f"(similarity: {result['similarity']:.3f})") -``` - -## Property Prediction - -Use AION embeddings for various prediction tasks: - -### Redshift Estimation with k-NN - -```python -from sklearn.neighbors import KNeighborsRegressor -from sklearn.model_selection import train_test_split -from sklearn.metrics import mean_absolute_error, r2_score - -def train_redshift_predictor(galaxies_with_redshifts, model, codec_manager): - """Train a k-NN regressor for redshift prediction.""" - - # Extract embeddings and targets - embeddings = [] - redshifts = [] - - for galaxy in galaxies_with_redshifts: - tokens = codec_manager.encode(*galaxy['modalities']) - with torch.no_grad(): - emb = model.encode(tokens, num_encoder_tokens=600) - emb = emb.mean(dim=1).cpu().numpy() - - embeddings.append(emb[0]) # Remove batch dimension - redshifts.append(galaxy['redshift']) - - X = np.array(embeddings) - y = np.array(redshifts) - - # Split data - X_train, X_test, y_train, y_test = train_test_split( - X, y, test_size=0.2, random_state=42 - ) - - # Train k-NN regressor - knn = KNeighborsRegressor(n_neighbors=5) - knn.fit(X_train, y_train) - - # Evaluate - y_pred = knn.predict(X_test) - mae = mean_absolute_error(y_test, y_pred) - r2 = r2_score(y_test, y_pred) - - print(f"Redshift prediction - MAE: {mae:.4f}, R²: {r2:.4f}") - - return knn - -def predict_redshift(new_galaxy, trained_model, model, codec_manager): - """Predict redshift for a new galaxy.""" - tokens = codec_manager.encode(*new_galaxy) - with torch.no_grad(): - emb = model.encode(tokens, num_encoder_tokens=600) - emb = emb.mean(dim=1).cpu().numpy() - - predicted_z = 
trained_model.predict(emb)[0] - return predicted_z -``` - -### Stellar Mass Prediction - -```python -from sklearn.ensemble import RandomForestRegressor - -def train_stellar_mass_predictor(galaxies_with_masses, model, codec_manager): - """Train predictor for stellar mass estimation.""" - - # Similar to redshift prediction but for stellar mass - embeddings = [] - masses = [] - - for galaxy in galaxies_with_masses: - tokens = codec_manager.encode(*galaxy['modalities']) - with torch.no_grad(): - emb = model.encode(tokens, num_encoder_tokens=600) - emb = emb.mean(dim=1).cpu().numpy() - - embeddings.append(emb[0]) - masses.append(np.log10(galaxy['stellar_mass'])) # Log stellar mass - - X = np.array(embeddings) - y = np.array(masses) - - # Train Random Forest - rf = RandomForestRegressor(n_estimators=100, random_state=42) - rf.fit(X, y) - - return rf -``` - -## Performance Tips - -### Batch Processing - -Process multiple objects efficiently: - -```python -def process_batch_efficiently(object_list, model, codec_manager, batch_size=32): - """Process objects in batches for better GPU utilization.""" - results = [] - - for i in range(0, len(object_list), batch_size): - batch = object_list[i:i + batch_size] - - # Group by modality type for efficient encoding - images = [obj for obj in batch if 'image' in obj] - spectra = [obj for obj in batch if 'spectrum' in obj] - - batch_results = [] - - with torch.no_grad(): - # Process images - if images: - image_batch = [obj['image'] for obj in images] - tokens = codec_manager.encode(*image_batch) - embeddings = model.encode(tokens, num_encoder_tokens=600) - batch_results.extend(embeddings.mean(dim=1).cpu().numpy()) - - # Process spectra - if spectra: - spectrum_batch = [obj['spectrum'] for obj in spectra] - tokens = codec_manager.encode(*spectrum_batch) - embeddings = model.encode(tokens, num_encoder_tokens=300) - batch_results.extend(embeddings.mean(dim=1).cpu().numpy()) - - results.extend(batch_results) - - return results -``` - -### 
Memory Management - -Handle large datasets with limited GPU memory: - -```python -def process_large_dataset(dataset, model, codec_manager, max_batch_size=16): - """Process large datasets with automatic memory management.""" - import gc - - current_batch_size = max_batch_size - results = [] - - i = 0 - while i < len(dataset): - try: - batch = dataset[i:i + current_batch_size] - - # Process batch - batch_tokens = codec_manager.encode(*batch) - with torch.no_grad(): - embeddings = model.encode(batch_tokens, num_encoder_tokens=600) - results.append(embeddings.mean(dim=1).cpu()) - - i += current_batch_size - - except torch.cuda.OutOfMemoryError: - # Clear memory and reduce batch size - torch.cuda.empty_cache() - gc.collect() - current_batch_size = max(1, current_batch_size // 2) - print(f"Reduced batch size to {current_batch_size}") - - if current_batch_size == 0: - raise RuntimeError("Cannot process even single example") - - return torch.cat(results, dim=0) -``` - -### Using Mixed Precision - -Speed up inference with automatic mixed precision: - -```python -def extract_embeddings_amp(modalities, model, codec_manager): - """Extract embeddings using automatic mixed precision.""" - from torch.cuda.amp import autocast - - tokens = codec_manager.encode(*modalities) - - with torch.no_grad(): - with autocast(): - embeddings = model.encode(tokens, num_encoder_tokens=600) - - return embeddings.float() # Convert back to float32 -``` - -## Best Practices - -1. **Always use `.eval()` mode** for inference to disable dropout and batch norm updates -2. **Use `torch.no_grad()`** to disable gradient computation and save memory -3. **Process in batches** when possible for better GPU utilization -4. **Pool embeddings appropriately** - mean pooling works well for most tasks -5. **Use consistent device placement** - ensure all tensors are on the same device -6. **Clear GPU cache** periodically when processing large datasets - -## Troubleshooting - -### Common Issues - -1. 
**CUDA out of memory**: Reduce batch size or use gradient checkpointing -2. **Slow processing**: Ensure data is on GPU and use batch processing -3. **Shape mismatches**: Check that tensor dimensions match expected format -4. **Device errors**: Ensure model, data, and codec_manager are on same device - -### Debug Mode - -```python -def debug_tokens(tokens, codec_manager): - """Debug token shapes and contents.""" - print("Token summary:") - for key, tensor in tokens.items(): - print(f" {key}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}") - print(f" range: [{tensor.min().item():.2f}, {tensor.max().item():.2f}]") -``` - -For more advanced examples and the latest updates, see the [Tutorial Notebook](https://colab.research.google.com/github/PolymathicAI/AION/blob/main/notebooks/Tutorial.ipynb).