"""
Smart Plant Infrastructure for antupy
This module provides intelligent component management for Plant classes,
enabling automatic caching and selective invalidation of components based
on parameter dependencies.
"""
from __future__ import annotations
from typing import Any, TypeVar, Callable, TYPE_CHECKING
from dataclasses import dataclass, field
import inspect
import hashlib
from antupy import Var
# Import only what we need to avoid circular imports
if TYPE_CHECKING:
from antupy import Var, SimulationOutput
# Type variable for component decorator
T = TypeVar('T')
[docs]
@dataclass
class Plant():
"""
Plant base class with component management.
Features:
- Automatic component recreation on parameter changes
- Dependency tracking between parameters and components
- Efficient caching with hash-based invalidation
- Dynamic parameter introspection
Usage:
class MyPlant(Plant):
# Define parameters as usual
param1: Var = Var(10.0, "m")
param2: Var = Var(5.0, "kg")
@property
def my_component(self) -> SomeComponentClass:
return component(SomeComponentClass(
param1=constraint(self.param1),
param2=constraint(self.param2)
))
@property
def complex_component(self) -> AnotherComponentClass:
return component(AnotherComponentClass(
param1=constraint(self.param1),
computed_param=derived(self._compute_something, self.param1)
))
"""
# Base Plant attributes (preserved for Protocol compatibility)
out: SimulationOutput = field(default_factory=dict)
constraints: list[tuple[str, ...]] = field(default_factory=list)
# Component registry: maps component names to their classes and dependencies
_component_dependencies: dict[str, set[str]] = field(
default_factory=dict, init=False, repr=False
)
# Component instances cache
_component_cache: dict[str, Any] = field(
default_factory=dict, init=False, repr=False
)
# Parameter hash cache for change detection
_param_hash_cache: dict[str, str] = field(
default_factory=dict, init=False, repr=False
)
# Track current property context for component() function
_current_property_context: str | None = field(default=None, init=False, repr=False)
def __post_init__(self):
# Initialize Plant-specific attributes (preserved for Protocol compatibility)
if not hasattr(self, 'out'):
self.out: SimulationOutput = {}
# Note: dependency discovery will happen lazily when first component is accessed
def run_simulation(self, verbose: bool = False) -> SimulationOutput:
"""Run simulation method (preserved for Protocol compatibility)."""
# This method should be implemented by subclasses
# Default implementation returns current out dict
return self.out
def __getattribute__(self, name):
"""Override to track property access for component() function context."""
# Get the attribute normally
attr = object.__getattribute__(self, name)
# Check if this is a property access and we're accessing a component property
cls = object.__getattribute__(self, '__class__')
class_attr = getattr(cls, name, None)
if isinstance(class_attr, property):
# Set context so component() function knows which property is being accessed
object.__setattr__(self, '_current_property_context', name)
try:
# Call the property getter - this should be the actual property call
if class_attr.fget is not None:
result = class_attr.fget(self) # Call property function directly
return result
else:
return attr
finally:
# Clear context after property access
object.__setattr__(self, '_current_property_context', None)
return attr
def _ensure_dependencies_discovered(self):
"""Ensure component dependencies have been discovered (lazy initialization)."""
if not self._component_dependencies:
self._discover_component_dependencies()
def _discover_component_dependencies(self):
"""Auto-discover which parameters affect which components."""
# Clear any existing dependencies first
self._component_dependencies.clear()
# Force component access to trigger dependency discovery
for name in dir(self):
attr = getattr(type(self), name, None)
if isinstance(attr, property):
try:
# Access the property to trigger component() calls and dependency registration
# This will populate _component_dependencies through constraint() and derived() calls
_ = getattr(self, name)
except:
# If property access fails, skip this component
continue
def _get_component_parameters(self, component_class) -> set[str]:
"""Extract parameter names from component __init__ signature."""
sig = inspect.signature(component_class.__init__)
return {
param_name for param_name, param in sig.parameters.items()
if param_name != 'self'
}
def _get_computed_parameters(self, component_name: str) -> set[str]:
"""Find parameters used in computed parameter methods."""
computed_params = set()
# Look for _get_* methods that this component might use
for attr_name in dir(self):
if attr_name.startswith('_get_'):
computed_params.add(attr_name[4:]) # Remove '_get_' prefix
return computed_params
def _compute_params_hash(self, param_names: set[str]) -> str:
"""Compute hash of current parameter values."""
param_values = []
for param_name in sorted(param_names): # Sort for consistent hashing
if hasattr(self, param_name):
value = getattr(self, param_name)
if isinstance(value, Var): # Var object
param_values.append(f"{param_name}={value.gv()}|{value.unit}")
elif isinstance(value, str):
param_values.append(f"{param_name}={value}")
else:
raise TypeError(f"Unsupported parameter type for hashing: {type(value)}")
param_string = "|".join(param_values)
return hashlib.md5(param_string.encode()).hexdigest()
def _invalidate_affected_components(self, changed_params: set[str]):
"""
Invalidate only components affected by changed parameters.
This is the key method that should be called by Parametric._update_parameters()
instead of __post_init__() to enable smart component caching.
Args:
changed_params: Set of parameter names that have changed
"""
self._ensure_dependencies_discovered()
for component_name, dependencies in self._component_dependencies.items():
if changed_params & dependencies: # Intersection check
# Clear component from cache - will be recreated on next access
if component_name in self._component_cache:
del self._component_cache[component_name]
if component_name in self._param_hash_cache:
del self._param_hash_cache[component_name]
def _needs_component_recreation(self, component_name: str) -> bool:
"""Check if component needs recreation."""
# Ensure dependencies are discovered
self._ensure_dependencies_discovered()
if component_name not in self._component_cache:
return True
if component_name not in self._component_dependencies:
return True
# Check parameter hash
dependencies = self._component_dependencies[component_name]
current_hash = self._compute_params_hash(dependencies)
cached_hash = self._param_hash_cache.get(component_name)
return current_hash != cached_hash
def _update_component_hash(self, component_name: str):
"""Update the parameter hash for a component."""
if component_name in self._component_dependencies:
dependencies = self._component_dependencies[component_name]
current_hash = self._compute_params_hash(dependencies)
self._param_hash_cache[component_name] = current_hash
def _build_component_kwargs(self, component_class):
"""Build kwargs for component, with smart parameter mapping."""
sig = inspect.signature(component_class.__init__)
kwargs = {}
for param_name, param in sig.parameters.items():
if param_name == 'self':
continue
if hasattr(self, param_name):
kwargs[param_name] = getattr(self, param_name)
elif hasattr(self, f'_get_{param_name}'):
# Call computed parameter method
kwargs[param_name] = getattr(self, f'_get_{param_name}')()
elif param.default != inspect.Parameter.empty:
# Use default value if available
kwargs[param_name] = param.default
else:
# Required parameter not found
raise TypeError(f"Required parameter '{param_name}' not found for component {component_class.__name__}")
return kwargs
def get_component_cache_stats(self) -> dict[str, Any]:
"""Get statistics about component cache usage."""
return {
'cached_components': list(self._component_cache.keys()),
'component_dependencies': dict(self._component_dependencies),
'cache_size': len(self._component_cache),
'total_dependencies': sum(len(deps) for deps in self._component_dependencies.values())
}
def component(instance: T, component_name: str | None = None) -> T:
"""
Cache-aware component function for use within @property methods.
This function works alongside constraint() and derived() to provide:
1. Component instance pass-through (returns the input instance)
2. Smart caching based on parameter dependencies
3. Automatic cache invalidation when dependencies change
4. Perfect type inference (since it's used within @property)
Usage:
@property
def HSF(self) -> SolarField:
return component(SolarField(
zf=constraint(self.zf),
file_SF=derived(self._file_SF, self.zf)
))
Args:
instance: The component instance to cache and return
component_name: Optional component name (auto-detected from context)
Returns:
The same instance, but now tracked for smart caching
Examples:
@property
def my_component(self) -> SomeComponent:
return component(SomeComponent(
param1=constraint(self.param1),
param2=constraint(self.param2)
))
"""
# Get the current component context (which property is being accessed)
context = _get_current_component_context()
if context is None:
# If no context, just return the instance (no caching available)
return instance
plant_instance, detected_name = context
final_name = component_name or detected_name
# Check if we should use cached version
if _should_use_cached_component(plant_instance, final_name):
cached = _get_cached_component(plant_instance, final_name)
if cached is not None:
return cached
# Cache this new instance and return it
_cache_component(plant_instance, final_name, instance)
return instance
# ============================================================================
# constraint AND derived FUNCTIONS FOR component function
# ============================================================================
def constraint(value: T, param_name: str | None = None) -> T:
"""
Mark a parameter as a direct dependency and return it unchanged.
This function serves two purposes:
1. Returns the input value unchanged (pass-through)
2. Registers the parameter as a dependency for smart caching
Args:
value: The parameter value to mark as a constraint
param_name: Optional parameter name (auto-detected if not provided)
Returns:
The input value unchanged
Example:
@component
def HSF(self) -> SolarField:
zf = constraint(self.zf) # Mark zf as dependency
return SolarField(zf=zf)
"""
# Auto-detect parameter name if not provided
if param_name is None:
try:
# Get the calling frame to extract variable name
frame = inspect.currentframe()
if frame is not None and frame.f_back is not None:
frame = frame.f_back
# Look for pattern like "self.param_name" in the calling code
import re
frame_info = inspect.getframeinfo(frame)
if frame_info.code_context and len(frame_info.code_context) > 0:
source_line = frame_info.code_context[0]
# Try to extract parameter name from "constraint(self.param_name)"
match = re.search(r'constraint\(self\.([a-zA-Z_][a-zA-Z0-9_]*)', source_line)
if match:
param_name = match.group(1)
else:
# Fallback: try to extract from assignment like "zf = constraint(self.zf)"
match = re.search(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*constraint', source_line)
if match:
param_name = match.group(1)
except:
pass # If auto-detection fails, no dependency tracking
# Register dependency if we have a parameter name
if param_name:
_register_dependency(param_name)
return value
def derived(callable_func: Callable, *tracked_vars: Any, param_name: str | None = None) -> Any:
"""
Execute a callable with tracked variables and register dependencies.
This function:
1. Executes the callable with the tracked variables as arguments
2. Registers dependencies on the tracked variables for smart caching
3. Returns the computed result
Args:
callable_func: Function to execute for computing the derived value
*tracked_vars: Variables that this derived value depends on
param_name: Optional parameter name (auto-detected if not provided)
Returns:
Result of calling callable_func(*tracked_vars)
Example:
@component
def HSF(self) -> SolarField:
def _file_SF(zf):
return f'dataset_{zf.gv("m"):.0f}m.csv'
zf = constraint(self.zf)
file_SF = derived(_file_SF, zf) # Depends on zf
return SolarField(zf=zf, file_SF=file_SF)
"""
# Auto-detect parameter name if not provided
if param_name is None:
try:
# Get the calling frame to extract variable name
frame = inspect.currentframe()
if frame is not None and frame.f_back is not None:
frame = frame.f_back
# Look for assignment pattern like "file_SF = derived(...)"
import re
frame_info = inspect.getframeinfo(frame)
if frame_info.code_context and len(frame_info.code_context) > 0:
source_line = frame_info.code_context[0]
match = re.search(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*derived', source_line)
if match:
param_name = match.group(1)
except:
pass # If auto-detection fails, no dependency tracking
# Register this as a derived parameter dependency
if param_name:
_register_dependency(f"derived_{param_name}")
# For each tracked variable, try to register its parameter name as a dependency
context = _get_current_component_context()
if context is not None:
plant_instance, component_name = context
# Use source code inspection to find parameter names in constraint() calls
# This avoids the expensive dir() loop that triggers all component properties
try:
frame = inspect.currentframe()
if frame is not None and frame.f_back is not None:
caller_frame = frame.f_back
frame_info = inspect.getframeinfo(caller_frame)
if frame_info.code_context:
# Look through multiple lines if available (for multi-line derived calls)
source_lines = frame_info.code_context
source_text = ''.join(source_lines)
# Extract constraint parameter names using regex
import re
# Pattern matches: constraint(self.param_name)
constraint_matches = re.findall(r'constraint\(self\.([a-zA-Z_][a-zA-Z0-9_]*)', source_text)
# Register each found parameter name as a dependency
for param_name in constraint_matches:
if param_name: # Ensure param_name is not None or empty
_register_dependency(param_name)
except Exception:
# If source inspection fails, silently continue without dependency tracking
# This is better than the expensive dir() loop
pass
# Execute the callable and return result
try:
return callable_func(*tracked_vars)
except Exception as e:
raise ValueError(f"Error executing derived parameter calculation: {e}")
# Thread-local storage for tracking current component context
from threading import local
_component_context = local()
def _get_current_component_context() -> tuple[object, str] | None:
"""Get the current component being built (plant instance, component name)."""
# Then try to get from calling frame (for component() function)
try:
frame = inspect.currentframe()
# Walk up the call stack to find the property access
current_frame = frame
for i in range(5): # Limit search depth
if current_frame is None:
break
frame_locals = current_frame.f_locals
if 'self' in frame_locals:
plant_instance = frame_locals['self']
if hasattr(plant_instance, '_current_property_context'):
component_name = getattr(plant_instance, '_current_property_context', None)
if component_name:
return (plant_instance, component_name)
current_frame = current_frame.f_back
except Exception:
pass # Fallback gracefully if frame inspection fails
return None
def _should_use_cached_component(plant_instance, component_name: str) -> bool:
"""Check if cached component is still valid."""
if not hasattr(plant_instance, '_component_cache'):
return False
if component_name not in plant_instance._component_cache:
return False
# Check if dependencies changed (using existing hash system)
if hasattr(plant_instance, '_component_dependencies'):
dependencies = plant_instance._component_dependencies.get(component_name, set())
if dependencies:
current_hash = plant_instance._compute_params_hash(dependencies)
cached_hash = plant_instance._param_hash_cache.get(component_name)
return current_hash == cached_hash
# If no dependencies tracked yet, assume valid (will be tracked on first access)
return True
def _get_cached_component(plant_instance, component_name: str):
"""Get cached component if available."""
if hasattr(plant_instance, '_component_cache'):
return plant_instance._component_cache.get(component_name)
return None
def _cache_component(plant_instance, component_name: str, instance):
"""Cache a component instance."""
if not hasattr(plant_instance, '_component_cache'):
plant_instance._component_cache = {}
if not hasattr(plant_instance, '_param_hash_cache'):
plant_instance._param_hash_cache = {}
plant_instance._component_cache[component_name] = instance
# Update hash if dependencies are known
if hasattr(plant_instance, '_component_dependencies'):
dependencies = plant_instance._component_dependencies.get(component_name, set())
if dependencies:
current_hash = plant_instance._compute_params_hash(dependencies)
plant_instance._param_hash_cache[component_name] = current_hash
def _set_component_context(plant_instance: object, component_name: str):
"""Set the current component context for dependency tracking."""
_component_context.current = (plant_instance, component_name)
def _clear_component_context():
"""Clear the current component context."""
_component_context.current = None
def _register_dependency(param_name: str) -> None:
"""Register a dependency for the current component being built."""
context = _get_current_component_context()
if context is None:
# If no component context, this is being called outside component creation
# This is OK - dependency tracking is optional
return
plant_instance, component_name = context
# Initialize dependencies if not present
if not hasattr(plant_instance, '_component_dependencies'):
setattr(plant_instance, '_component_dependencies', {})
dependencies = getattr(plant_instance, '_component_dependencies')
if component_name not in dependencies:
dependencies[component_name] = set()
# Add this parameter as a dependency
dependencies[component_name].add(param_name)
return