Source code for pasted._generator

"""
pasted._generator
=================
High-level API:

- :class:`Structure`          — dataclass holding one generated structure.
- :class:`StructureGenerator` — stateful generator (class API).
- :func:`generate`            — convenience functional wrapper.

Internal design
---------------
The generation loop lives in a single private method,
:meth:`StructureGenerator._stream_with_stats`, which returns a
``(structures_iterator, stats_dict)`` pair.  The public methods
:meth:`~StructureGenerator.stream` and
:meth:`~StructureGenerator.generate` are thin wrappers around it:

- ``stream()`` forwards the iterator directly to the caller.
- ``generate()`` exhausts the iterator, reads the populated *stats_dict*,
  and wraps everything in a :class:`GenerationResult`.

This two-layer design eliminates the hidden coupling that previously
existed via the ``_last_run_stats`` instance variable: interrupting
``stream()`` early can no longer leave stale counters for a subsequent
``generate()`` call.

Verbose log output is routed through three focused helpers
(:meth:`~StructureGenerator._log_filter_header`,
:meth:`~StructureGenerator._log_sample_result`,
:meth:`~StructureGenerator._log_summary`) so that the generation loop
itself contains only placement logic.
"""

from __future__ import annotations

import math
import random
import sys
import warnings
from collections import Counter
from collections.abc import Iterator
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import numpy as np

from ._atoms import (
    _Z_TO_SYM,
    ATOMIC_NUMBERS,
    _cov_radius_ang,
    default_element_pool,
    parse_element_spec,
    parse_filter,
    validate_charge_mult,
)
from ._config import GeneratorConfig
from ._io import _fmt, format_xyz, parse_xyz
from ._metrics import compute_all_metrics, passes_filters
from ._placement import (
    _H_COV_RADIUS,
    _PACKING_HARD,
    _PACKING_WARN,
    Vec3,
    _affine_move,
    place_chain,
    place_gas,
    place_maxent,
    place_shell,
    relax_positions,
)

# ---------------------------------------------------------------------------
# Structure dataclass
# ---------------------------------------------------------------------------


[docs] @dataclass class Structure: """A single generated atomic structure with its computed disorder metrics. Attributes ---------- atoms: Element symbols, one per atom. positions: Cartesian coordinates in Å, one ``(x, y, z)`` tuple per atom. charge: Total system charge. mult: Spin multiplicity 2S+1. metrics: Computed disorder metrics (see :data:`pasted._atoms.ALL_METRICS`). mode: Placement mode used (``"gas"``, ``"chain"``, ``"shell"``, ``"maxent"``, or ``"opt_<method>"`` for optimizer results). sample_index: 1-based index within the batch of structures that passed filters. center_sym: Element symbol of the shell center atom (shell mode only). seed: Random seed used for generation (``None`` if unseeded). Properties ---------- comp: Read-only composition string derived from :attr:`atoms`, sorted in alphabetical order by element symbol, e.g. ``'C5N2O3'``. Computed on access; not stored as a field. .. note:: The sort order is **alphabetical** (``sorted()`` on symbol strings), not Hill order (C first, H second, then alphabetical). Structures containing only C, H, N, O will look identical to Hill order, but others — e.g. ``['Na', 'C', 'H']`` → ``'CH2Na'`` — differ. Examples -------- Access the composition string directly:: s = generate(n_atoms=10, charge=0, mult=1, mode="gas", region="sphere:8", elements="6,7,8", n_samples=5, seed=0)[0] print(s.comp) # e.g. 'C4N3O3' print(repr(s)) # Structure(n=10, comp='C4N3O3', mode='gas', H_total=…) """ atoms: list[str] positions: list[Vec3] charge: int mult: int metrics: dict[str, float] mode: str sample_index: int = 0 center_sym: str | None = None seed: int | None = None # ------------------------------------------------------------------ # # XYZ output # # ------------------------------------------------------------------ #
[docs] def to_xyz(self, prefix: str = "") -> str: """Serialise to extended XYZ format. Parameters ---------- prefix: Custom prefix for the comment line. When omitted the standard ``"sample=N mode=M …"`` string is generated automatically. Returns ------- Multi-line string (no trailing newline). """ if not prefix: prefix = f"sample={self.sample_index} mode={self.mode}" if self.mode == "shell" and self.center_sym: prefix += f" center={self.center_sym}(Z={ATOMIC_NUMBERS[self.center_sym]})" if self.seed is not None: prefix += f" seed={self.seed}" return format_xyz( self.atoms, self.positions, self.charge, self.mult, self.metrics, prefix, )
[docs] def write_xyz(self, path: str | Path, *, append: bool = True) -> None: """Write this structure to an XYZ file. Parameters ---------- path: Output file path. append: If ``True`` (default) the file is opened in append mode so that multiple structures can be written in sequence. Use ``append=False`` to overwrite. """ mode = "a" if append else "w" with Path(path).open(mode) as fh: fh.write(self.to_xyz() + "\n")
# ------------------------------------------------------------------ # # Dunder helpers # # ------------------------------------------------------------------ #
[docs] def __len__(self) -> int: return len(self.atoms)
@property def n(self) -> int: """Number of atoms in the structure.""" return len(self.atoms) # ------------------------------------------------------------------ # # XYZ import # # ------------------------------------------------------------------ #
[docs] @classmethod def from_xyz( cls, source: str | Path, *, frame: int = 0, recompute_metrics: bool = True, cutoff: float | None = None, n_bins: int = 20, w_atom: float = 0.5, w_spatial: float = 0.5, cov_scale: float = 1.0, ) -> Structure: """Load a :class:`Structure` from an XYZ file or string. Supports both plain XYZ and PASTED extended XYZ (with ``charge=``, ``mult=``, and metric tokens on the comment line). When *recompute_metrics* is ``True`` (default), all disorder metrics are recomputed from the loaded geometry so that the returned structure is fully usable as optimizer input or for filtering. Parameters ---------- source: Path to an XYZ file **or** a raw XYZ string. frame: Zero-based frame index when *source* contains multiple concatenated structures (default: first frame). recompute_metrics: Recompute all disorder metrics after loading. Set to ``False`` to skip the recomputation and return the structure with whatever metric values were embedded in the extended XYZ comment (or an empty dict for plain XYZ). cutoff: Distance cutoff (Å) for metric computation. Auto-computed from the element pool when ``None``. n_bins: Histogram bins for ``H_spatial`` / ``RDF_dev`` (default: 20). w_atom: Weight of ``H_atom`` in ``H_total`` (default: 0.5). w_spatial: Weight of ``H_spatial`` in ``H_total`` (default: 0.5). cov_scale: Minimum distance scale factor used for metrics (default: 1.0). Returns ------- Structure Raises ------ FileNotFoundError When *source* looks like a file path (no newlines) but the path does not exist on disk. IsADirectoryError When *source* is a path that points to a directory rather than a regular file. ValueError When the XYZ content cannot be parsed, or *frame* is out of range. Examples -------- Load and immediately use as optimizer initial structure:: from pasted import Structure, StructureOptimizer s = Structure.from_xyz("my_structure.xyz") opt = StructureOptimizer( n_atoms=len(s), charge=s.charge, mult=s.mult, objective={"H_total": 1.0}, elements=[sym for sym in set(s.atoms)], max_steps=2000, seed=42, ) result = opt.run(initial=s) """ # Determine whether *source* looks like a file path or raw XYZ text. # Heuristic: a string containing no newlines is treated as a path; # a multi-line string is treated as raw XYZ content. _looks_like_path = not isinstance(source, str) or ( "\n" not in str(source) and str(source).strip() ) p = Path(source) if _looks_like_path else None if p is not None: # *source* looks like a file path — enforce explicit errors. if not p.exists(): raise FileNotFoundError(f"XYZ file not found: {p!s}") if not p.is_file(): raise IsADirectoryError(f"Expected a file path, but {p!s} is a directory.") text = p.read_text() else: text = str(source) frames = parse_xyz(text) if not frames: raise ValueError("No frames found in XYZ source.") if frame < 0 or frame >= len(frames): raise ValueError( f"frame={frame} out of range; source contains {len(frames)} frame(s)." ) atoms, positions, charge, mult, embedded_metrics = frames[frame] if recompute_metrics: if cutoff is None: radii = np.array([_cov_radius_ang(a) for a in atoms]) # O(N) approximation: median(r_i + r_j) ≈ 2 × median(r_i). # Avoids O(N² log N) pair enumeration for large structures. cutoff = cov_scale * 1.5 * float(np.median(radii)) * 2.0 metrics = compute_all_metrics( atoms, positions, n_bins, w_atom, w_spatial, cutoff, cov_scale ) else: metrics = embedded_metrics return cls( atoms=list(atoms), positions=list(positions), charge=charge, mult=mult, metrics=metrics, mode="loaded_xyz", )
@property def comp(self) -> str: """Alphabetically-sorted composition string derived from :attr:`atoms`. Elements are sorted in ascending alphabetical order by symbol and counts above one are appended as a suffix, e.g. ``'C5N2O3'``. Single-atom elements are written without a count, e.g. ``'C'`` rather than ``'C1'``. .. note:: The sort order is **alphabetical** (Python ``sorted()``), **not** Hill order (which would place C first, H second, then all other elements alphabetically). For structures containing only C, H, N, O the two orderings coincide, but elements such as Na, Fe, or Ar will appear at their alphabetical position rather than after H. For example ``['Na', 'C', 'H', 'H']`` yields ``'CH2Na'`` (alphabetical) rather than ``'CH2Na'`` (which happens to match Hill here) but ``['Ar', 'C', 'H']`` yields ``'ArCH2'`` (alphabetical) not ``'CH2Ar'`` (Hill). This property is computed on each access and is not persisted as a dataclass field. Returns ------- str Compact composition label, e.g. ``'C5N2O3'``. Examples -------- :: s.comp # 'C5N2O3' s.comp in repr(s) # True """ counts = Counter(self.atoms) return "".join(f"{sym}{n}" if n > 1 else sym for sym, n in sorted(counts.items()))
[docs] def __repr__(self) -> str: h_total = self.metrics.get("H_total", float("nan")) return f"Structure(n={len(self)}, comp={self.comp!r}, mode={self.mode!r}, H_total={h_total:.3f})"
# --------------------------------------------------------------------------- # GenerationResult # ---------------------------------------------------------------------------
[docs] @dataclass class GenerationResult: """Return value of :func:`generate` and :meth:`StructureGenerator.generate`. Behaves like a ``list[Structure]`` in all normal usage (indexing, iteration, ``len``, boolean test, ``for s in result``) while also carrying metadata about how many attempts were made and why samples were rejected. This metadata is especially useful when integrating PASTED into automated pipelines such as ASE or high-throughput workflows, where a silent empty list would be indistinguishable from a successful run that just produced no results. Attributes ---------- structures: Structures that passed all filters. n_attempted: Total placement attempts made. n_passed: Number of structures that passed all filters (equals ``len(structures)`` unless the caller mutates the list). n_rejected_parity: Attempts rejected by the charge/multiplicity parity check. n_rejected_filter: Attempts rejected by user-supplied metric filters. n_success_target: The ``n_success`` value that was in effect during generation (``None`` when not set). Examples -------- Drop-in replacement for ``list[Structure]``:: result = generate(n_atoms=10, charge=0, mult=1, mode="gas", region="sphere:8", elements="6,7,8", n_samples=20, seed=0) for s in result: # iterates like a list print(s.to_xyz()) print(len(result)) # number that passed Inspect rejection metadata:: if result.n_rejected_parity > 0: print(f"{result.n_rejected_parity} samples failed parity check") print(result.summary()) Notes ----- ``GenerationResult`` is a :func:`~dataclasses.dataclass`; downstream code should treat it as immutable. The ``structures`` field is a plain ``list`` and may be sorted or sliced freely. """ structures: list[Structure] = field(default_factory=list) n_attempted: int = 0 n_passed: int = 0 n_rejected_parity: int = 0 n_rejected_filter: int = 0 n_success_target: int | None = None # ------------------------------------------------------------------ # # list-compatible interface # # ------------------------------------------------------------------ #
[docs] def __len__(self) -> int: return len(self.structures)
[docs] def __iter__(self) -> Iterator[Structure]: return iter(self.structures)
def __getitem__(self, index: int | slice) -> Structure | list[Structure]: if isinstance(index, slice): return self.structures[index] return self.structures[index] def __bool__(self) -> bool: return bool(self.structures) def __add__(self, other: GenerationResult) -> GenerationResult: """Merge two :class:`GenerationResult` objects into one. Combines structures and accumulates all counters so that batch workflows can collect results across multiple calls and treat them as a single result:: r1 = generate(..., n_samples=20, seed=0) r2 = generate(..., n_samples=20, seed=1) combined = r1 + r2 print(len(combined)) # up to 40 print(combined.summary()) Structure order is preserved: all structures from *self* appear before those from *other*. Returning :data:`NotImplemented` for non-``GenerationResult`` operands lets Python fall back to the reflected ``__radd__`` of the right-hand side, following the standard Python operator protocol. Parameters ---------- other: Another :class:`GenerationResult` to merge into this one. Returns ------- GenerationResult New result containing all structures from both operands with all counters summed. ``n_success_target`` is taken from *self* when set, otherwise from *other*. Raises ------ NotImplemented Returned (not raised) when *other* is not a :class:`GenerationResult`, allowing Python to try the reflected operation on the right-hand operand. """ if not isinstance(other, GenerationResult): return NotImplemented return GenerationResult( structures=self.structures + other.structures, n_attempted=self.n_attempted + other.n_attempted, n_passed=self.n_passed + other.n_passed, n_rejected_parity=self.n_rejected_parity + other.n_rejected_parity, n_rejected_filter=self.n_rejected_filter + other.n_rejected_filter, n_success_target=self.n_success_target if self.n_success_target is not None else other.n_success_target, )
[docs] def __repr__(self) -> str: return ( f"GenerationResult(" f"passed={self.n_passed}, " f"attempted={self.n_attempted}, " f"rejected_parity={self.n_rejected_parity}, " f"rejected_filter={self.n_rejected_filter})" )
# ------------------------------------------------------------------ # # Metadata helpers # # ------------------------------------------------------------------ #
[docs] def summary(self) -> str: """Return a human-readable one-line summary of the generation run. Returns ------- str E.g. ``"passed=5 attempted=20 rejected_parity=2 rejected_filter=13"``. """ parts = [ f"passed={self.n_passed}", f"attempted={self.n_attempted}", f"rejected_parity={self.n_rejected_parity}", f"rejected_filter={self.n_rejected_filter}", ] if self.n_success_target is not None: parts.append(f"n_success_target={self.n_success_target}") return " ".join(parts)
# --------------------------------------------------------------------------- # StructureGenerator # ---------------------------------------------------------------------------
[docs] class StructureGenerator: """Generate random atomic structures with disorder metrics. All parameters use Python snake_case names that correspond 1-to-1 with their CLI ``--flag`` counterparts. Parameters ---------- n_atoms: Number of atoms per structure (before optional H augmentation). charge: Total system charge (applied to every structure). mult: Spin multiplicity 2S+1. mode: Placement mode: ``"gas"`` (default), ``"chain"``, ``"shell"``, or ``"maxent"``. region: Bounding-region spec: ``"sphere:R"`` | ``"box:L"`` | ``"box:LX,LY,LZ"``. **Required when** *mode* **is** ``"gas"`` **or** ``"maxent"``; ignored for ``"chain"`` and ``"shell"`` (those modes use their own geometry parameters such as *shell_radius* and *bond_range*). Example: ``region="sphere:8"`` places atoms inside an 8 Å-radius sphere. branch_prob: [chain] Branching probability (default: 0.3). chain_persist: [chain] Directional persistence ∈ [0, 1] (default: 0.5). chain_bias: [chain] Global-axis drift strength ∈ [0, 1] (default: 0.0). The direction of the first bond becomes the bias axis; each subsequent step is blended toward that axis before normalization. 0.0 → no bias (backwards-compatible); higher values produce more elongated structures with larger ``shape_aniso``. bond_range: [chain / shell tails] Bond-length range in Å (default: ``(1.2, 1.6)``). center_z: [shell] Atomic number of center atom. ``None`` → random per sample. coord_range: [shell] Coordination-number range (default: ``(4, 8)``). shell_radius: [shell] Shell-radius range in Å (default: ``(1.8, 2.5)``). elements: Element pool. Three forms are accepted: * **Atomic-number spec string** — a comma-separated list of integers and/or integer ranges, e.g. ``"6,7,8"`` (C, N, O) or ``"1-30"`` (H to Zn) or ``"1-10,26,28"`` (H–Ne plus Fe and Ni). Ranges are inclusive. **Symbol strings such as** ``"C,N,O"`` **are not accepted** and will raise :exc:`ValueError`; use the numeric form ``"6,7,8"`` or pass a list instead. * **Explicit list of element symbols** — e.g. ``["C", "N", "O"]`` or ``["Cr", "Mn", "Fe", "Co", "Ni"]``. Symbols must be valid two-character-or-less IUPAC symbols recognised by PASTED. * ``None`` — all Z = 1–106 (default). element_fractions: Relative sampling weights for elements in the pool, as a ``{symbol: weight}`` dict (e.g. ``{"C": 0.5, "N": 0.3, "O": 0.2}``). Weights are *relative* — they are normalized internally and need not sum to 1. Elements absent from the dict receive a weight of 1.0. When ``None`` (default), every element in the pool is sampled with equal probability. element_min_counts: Minimum number of atoms per element guaranteed in every generated structure (e.g. ``{"C": 2, "N": 1}``). The required atoms are placed first; remaining slots are filled by weighted random sampling. ``None`` (default) → no lower bounds. The sum of all minimum counts must not exceed ``n_atoms``. element_max_counts: Maximum number of atoms allowed per element (e.g. ``{"N": 5, "O": 3}``). Elements that have reached their cap are excluded from sampling for the remaining slots. ``None`` (default) → no upper bounds. .. note:: When both *element_min_counts* and *element_max_counts* are given, each element's min must be ≤ its max. .. note:: The automatic hydrogen augmentation step (``add_hydrogen=True``) runs *after* the constrained sampling and may temporarily exceed *element_max_counts* for H. Set ``add_hydrogen=False`` if H count limits are critical. cov_scale: Minimum-distance scale factor: ``d_min(i,j) = cov_scale × (r_i + r_j)`` using Pyykkö (2009) single-bond covalent radii. Default: ``1.0``. relax_cycles: Maximum repulsion-relaxation iterations (default: 1500). add_hydrogen: Automatically append H atoms when H is in the pool but the sampled composition contains none (default: ``True``). affine_strength: Global dimensionless scale of the affine transformation applied to every generated structure **before** :func:`relax_positions` (default: ``0.0`` = disabled). When > 0 a random stretch/compress + shear is applied once per structure, creating more anisotropic initial geometries before the repulsion-relaxation step. Practical range: 0.05–0.4. At 0.1 the structure is stretched / compressed by up to ±10 % along a random axis and sheared by up to ±5 %. Works identically across all placement modes (``gas``, ``chain``, ``shell``, ``maxent``). ``0.0`` preserves the behavior of all versions prior to v0.2.3. Use *affine_stretch*, *affine_shear*, and *affine_jitter* to override individual operation strengths independently. affine_stretch: Strength of the stretch/compress operation only ∈ (0, 1). When ``None`` (default) *affine_strength* is used. Set to ``0.0`` to disable stretching while keeping shear and jitter active. affine_shear: Strength of the shear operation only ∈ (0, 1). When ``None`` (default) *affine_strength* is used. Set to ``0.0`` to disable shearing while keeping stretch and jitter active. affine_jitter: Per-atom jitter scale ∈ (0, 1) relative to the move step. When ``None`` (default) *affine_strength* is used. For :class:`StructureGenerator` the move step is always ``0.0``, so jitter is never applied during generation regardless of this value; the parameter exists for symmetry with :class:`StructureOptimizer`. n_samples: Maximum number of placement attempts (default: 1). Use ``0`` to allow unlimited attempts (only valid when *n_success* is also set, otherwise a :exc:`ValueError` is raised). n_success: Target number of structures that must pass all filters before generation stops (default: ``None``). - ``None`` → generate exactly *n_samples* attempts and return all that passed (original behavior). - ``N > 0`` with ``n_samples > 0`` → stop as soon as *N* structures pass **or** *n_samples* attempts are exhausted, whichever comes first. Returns the structures collected so far with a warning if fewer than *N* were found. - ``N > 0`` with ``n_samples = 0`` → unlimited attempts; stop only when *N* structures have passed. seed: Random seed for reproducibility (``None`` → non-deterministic). n_bins: Histogram bins for ``H_spatial`` and ``RDF_dev`` (default: 20). w_atom: Weight of ``H_atom`` in ``H_total`` (default: 0.5). w_spatial: Weight of ``H_spatial`` in ``H_total`` (default: 0.5). cutoff: Distance cutoff in Å for Steinhardt and graph metrics. ``None`` → auto-computed as ``cov_scale × 1.5 × median(r_i + r_j)`` over the element pool. filters: Filter strings of the form ``"METRIC:MIN:MAX"`` (use ``"-"`` for an open bound). Only structures satisfying *all* filters are returned. verbose: Print progress and statistics to *stderr* (default: ``False``). The CLI always passes ``True``; library callers usually leave it off. Examples -------- Class API (config-based, recommended):: from pasted import GeneratorConfig, StructureGenerator cfg = GeneratorConfig( n_atoms=12, charge=0, mult=1, mode="gas", region="sphere:9", elements="1-30", n_samples=50, seed=42, filters=["H_total:2.0:-"], ) gen = StructureGenerator(cfg) structures = gen.generate() for s in structures: print(s) Functional API (keyword-based, backward-compatible):: from pasted import generate structures = generate( n_atoms=12, charge=0, mult=1, mode="chain", elements="6,7,8", n_samples=20, seed=0, ) """ def __init__( self, config: GeneratorConfig | None = None, *, n_atoms: int | None = None, **kwargs: Any, ) -> None: """Construct a :class:`StructureGenerator`. Two calling conventions are accepted: **Config-based (recommended):** gen = StructureGenerator(GeneratorConfig(n_atoms=12, charge=0, mult=1, ...)) **Keyword-based (backward-compatible):** gen = StructureGenerator(n_atoms=12, charge=0, mult=1, ...) When *config* is given all other arguments are ignored. When *config* is ``None`` a :class:`GeneratorConfig` is built from the keyword arguments; ``n_atoms``, ``charge``, and ``mult`` are required. """ if config is None: if n_atoms is None: raise TypeError( "StructureGenerator requires either a GeneratorConfig as the " "first argument, or keyword arguments including n_atoms=, " "charge=, and mult=." ) config = GeneratorConfig(n_atoms=n_atoms, **kwargs) self._cfg = config cfg = config # local alias for brevity inside __init__ # ── Mode / region validation ───────────────────────────────────── if cfg.mode not in ("gas", "chain", "shell", "maxent"): raise ValueError( f"mode must be 'gas', 'chain', 'shell', or 'maxent'; got {cfg.mode!r}" ) if cfg.mode in ("gas", "maxent") and cfg.region is None: raise ValueError( f"region is required when mode={cfg.mode!r}. " 'Pass e.g. region="sphere:8" (radius 8 Å) or ' 'region="box:10" (10×10×10 Å box).' ) # ── n_samples / n_success validation ──────────────────────────── if cfg.n_samples == 0 and cfg.n_success is None: raise ValueError( "n_samples=0 (unlimited) requires n_success to be set; " "otherwise generation would run forever." ) if cfg.n_success is not None and cfg.n_success < 1: raise ValueError(f"n_success must be >= 1; got {cfg.n_success}.") # ── Element pool ──────────────────────────────────────────────── if cfg.elements is None: self._element_pool: list[str] = default_element_pool() elif isinstance(cfg.elements, str): self._element_pool = parse_element_spec(cfg.elements) else: self._element_pool = list(cfg.elements) # ── Element fractions ──────────────────────────────────────────── if cfg.element_fractions is not None: unknown = set(cfg.element_fractions) - set(self._element_pool) if unknown: raise ValueError( f"element_fractions contains symbols not in the element pool: " f"{sorted(unknown)}" ) weights = [float(cfg.element_fractions.get(sym, 1.0)) for sym in self._element_pool] if any(w < 0 for w in weights): raise ValueError("element_fractions weights must be non-negative.") total = sum(weights) if total == 0: raise ValueError("element_fractions weights must not all be zero.") self._element_weights: list[float] = [w / total for w in weights] else: n = len(self._element_pool) self._element_weights = [1.0 / n] * n # ── Element min/max counts ─────────────────────────────────────── if cfg.element_min_counts is not None: unknown_min = set(cfg.element_min_counts) - set(self._element_pool) if unknown_min: raise ValueError( f"element_min_counts contains symbols not in the element pool: " f"{sorted(unknown_min)}" ) if any(v < 0 for v in cfg.element_min_counts.values()): raise ValueError("element_min_counts values must be non-negative.") total_min = sum(cfg.element_min_counts.values()) if total_min > cfg.n_atoms: raise ValueError( f"Sum of element_min_counts ({total_min}) exceeds n_atoms ({cfg.n_atoms})." ) if cfg.element_max_counts is not None: unknown_max = set(cfg.element_max_counts) - set(self._element_pool) if unknown_max: raise ValueError( f"element_max_counts contains symbols not in the element pool: " f"{sorted(unknown_max)}" ) if any(v < 0 for v in cfg.element_max_counts.values()): raise ValueError("element_max_counts values must be non-negative.") if cfg.element_min_counts is not None and cfg.element_max_counts is not None: for sym in cfg.element_min_counts: lo = cfg.element_min_counts[sym] hi = cfg.element_max_counts.get(sym, lo) if lo > hi: raise ValueError( f"element_min_counts[{sym!r}]={lo} > element_max_counts[{sym!r}]={hi}." ) self._element_min_counts: dict[str, int] = dict(cfg.element_min_counts or {}) self._element_max_counts: dict[str, int] = dict(cfg.element_max_counts or {}) # ── Density validation (v0.4.4) ───────────────────────────────── # _validate_density returns an effective region (possibly auto-scaled # when the user-supplied region would exceed packing thresholds). if cfg.mode in ("gas", "maxent") and cfg.region is not None: self._effective_region: str = self._validate_density( cfg.n_atoms, cfg.region ) else: self._effective_region = cfg.region or "" # ── v0.4.5 hot-loop caches ────────────────────────────────────── self._target_parity: int = (cfg.mult - 1) % 2 self._charge_parity: int = cfg.charge & 1 # Pre-classify pool elements by Z parity (used by _adjust_parity strategy 3) self._odd_pool: list[str] = [ s for s in self._element_pool if ATOMIC_NUMBERS[s] % 2 == 1 ] self._even_pool: list[str] = [ s for s in self._element_pool if ATOMIC_NUMBERS[s] % 2 == 0 ] # Pre-parse region geometry for _add_h_fast volume cap _rv = self._region_volume(self._effective_region) _h_vol: float = (4.0 / 3.0) * math.pi * _H_COV_RADIUS ** 3 if _rv is not None and self._element_pool: _mean_pool_r = ( sum(_cov_radius_ang(s) for s in self._element_pool) / len(self._element_pool) ) self._h_region_vol: float = _rv[1] self._h_heavy_vol_per_atom: float = (4.0 / 3.0) * math.pi * _mean_pool_r ** 3 self._h_region_known: bool = True else: self._h_region_vol = 0.0 self._h_heavy_vol_per_atom = 0.0 self._h_region_known = False self._h_vol_const: float = _h_vol # ── Filters ───────────────────────────────────────────────────── self._filters: list[tuple[str, float, float]] = [ parse_filter(f) for f in (cfg.filters or []) ] # ── Cutoff ────────────────────────────────────────────────────── self._cutoff: float = self._resolve_cutoff(cfg.cutoff) # ── Shell center ───────────────────────────────────────────────── self._fixed_center_sym: str | None = None if cfg.mode == "shell" and cfg.center_z is not None: if cfg.center_z not in _Z_TO_SYM: raise ValueError(f"center_z={cfg.center_z}: unknown atomic number.") sym = _Z_TO_SYM[cfg.center_z] if sym not in self._element_pool: raise ValueError(f"center_z={cfg.center_z} ({sym}) is not in the element pool.") self._fixed_center_sym = sym if cfg.verbose: self._log(f"[pool] {len(self._element_pool)} elements in pool") if cfg.mode == "shell": if self._fixed_center_sym: self._log( f"[shell] center fixed: {self._fixed_center_sym} " f"(Z={ATOMIC_NUMBERS[self._fixed_center_sym]})" ) else: self._log("[shell] center: random per sample (chaos mode)") # ------------------------------------------------------------------ # # Internal helpers # # ------------------------------------------------------------------ # def _log(self, msg: str) -> None: """Print *msg* to stderr when verbose mode is active.""" print(msg, file=sys.stderr) # ------------------------------------------------------------------ # # Parity adjustment (v0.4.4) # # ------------------------------------------------------------------ # def _adjust_parity(self, atoms: list[str], rng: random.Random) -> list[str]: """Nudge atom list so charge/mult parity is satisfied (v0.4.5: optimised). Optimisations vs v0.4.4: * Z-sum parity via XOR accumulation — avoids ``dict.get`` on every atom and discards the full integer sum (only the low bit matters). * Strategy-3 candidate lists come from ``self._odd_pool`` / ``self._even_pool`` cached at construction time instead of a per-call list comprehension over the pool. * Strategy-3 index selection uses ``rng.randrange`` so only a constant number of atoms are examined in the typical case (mixed- parity pool → first candidate always succeeds). Strategy (priority order, unchanged from v0.4.4): 1. Already correct → return unchanged. 2. H in pool → add or remove one H to flip electron-count parity. 3. No H → swap one atom with a pool element of opposite-parity Z. Changes at most one atom so element-fraction distribution is minimally perturbed. """ # ── XOR bit trick: zp = z_sum % 2 without computing z_sum ───────── zp = 0 for a in atoms: zp ^= ATOMIC_NUMBERS[a] & 1 if (zp ^ self._charge_parity) == self._target_parity: return atoms # ── Strategy 2: H available ─────────────────────────────────────── if "H" in self._element_pool: if "H" not in atoms: return [*atoms, "H"] elif rng.random() < 0.5: idx = next(i for i, a in enumerate(atoms) if a == "H") return atoms[:idx] + atoms[idx + 1:] else: return [*atoms, "H"] # ── Strategy 3: swap one atom — O(1) amortised ─────────────────── # Pre-cached pools eliminate the per-call list comprehension. # rng.randrange picks a random starting atom; in a mixed-parity pool # the first candidate always succeeds, giving O(1) average cost. n = len(atoms) if not n: return atoms start = rng.randrange(n) for offset in range(n): i = (start + offset) % n needed = (ATOMIC_NUMBERS[atoms[i]] & 1) ^ 1 cands = self._odd_pool if needed == 1 else self._even_pool if cands: result = list(atoms) result[i] = rng.choice(cands) return result return atoms # pool has only one parity class # ------------------------------------------------------------------ # # Fast hydrogen augmentation (v0.4.5) # # ------------------------------------------------------------------ # def _add_h_fast(self, atoms: list[str], rng: random.Random) -> list[str]: """Hydrogen augmentation using construction-time caches (v0.4.5 fast path). Semantically equivalent to :func:`~pasted._placement.add_hydrogen` called with ``region=self._effective_region``, ``charge``, and ``mult`` from the config, but replaces two per-sample O(N) scans with O(1) cache lookups: * Region-volume string parsing → ``self._h_region_vol`` (pre-parsed). * Mean covalent radius of placed atoms → pool mean stored in ``self._h_heavy_vol_per_atom`` (exact in expectation since atoms are drawn from the same pool; acts as a proxy for the per-sample mean with negligible error for large samples). The Z-sum parity check still scans *atoms* (unavoidable), but uses the XOR bit trick to avoid computing the full integer sum (~2× faster than the ``dict.get`` accumulation in v0.4.4). """ if "H" in atoms: return atoms n = len(atoms) # ── Volume cap (O(1) with cached geometry) ───────────────────────── if self._h_region_known and n > 0: available = self._h_region_vol * _PACKING_HARD - n * self._h_heavy_vol_per_atom max_h = max(0, int(available / self._h_vol_const)) else: max_h = 1 + round(n * 1.2) # ── Raw sample (same distribution as v0.4.3 / add_hydrogen) ──────── n_h = min(1 + round(rng.random() * n * 1.2), max_h) # ── Parity adjustment — XOR bit trick ─────────────────────────────── # (z_sum - charge + n_h) % 2 == target_parity # ⟺ (zp ^ charge_parity ^ (n_h & 1)) == target_parity zp = 0 for a in atoms: zp ^= ATOMIC_NUMBERS[a] & 1 if (zp ^ self._charge_parity ^ (n_h & 1)) != self._target_parity: if n_h > 0: n_h -= 1 elif n_h < max_h: n_h += 1 return atoms + ["H"] * max(0, n_h) # ------------------------------------------------------------------ # # Density validation (v0.4.4) # # ------------------------------------------------------------------ # @staticmethod def _region_volume(region: str) -> tuple[str, float] | None: """Return (shape, volume_ų) for a region spec, or None if unknown.""" if region.startswith("sphere:"): r = float(region.split(":")[1]) return "sphere", (4 / 3) * math.pi * r ** 3 if region.startswith("box:"): dims = list(map(float, region.split(":")[1].split(","))) if len(dims) == 1: dims *= 3 return "box", dims[0] * dims[1] * dims[2] return None @staticmethod def _recommend_region(n_atoms: int, mean_r: float, shape: str, target_packing: float = 0.45) -> str: """Return a region string at *target_packing* for *n_atoms* atoms.""" atom_vol = (4 / 3) * math.pi * mean_r ** 3 V = n_atoms * atom_vol / target_packing margin = 1.05 if shape == "sphere": r = ((3 * V) / (4 * math.pi)) ** (1 / 3) * margin return f"sphere:{r:.1f}" L = V ** (1 / 3) * margin return f"box:{L:.1f}" def _validate_density(self, n_atoms: int, region: str) -> str: """Check packing fraction and auto-scale the region when it is too high. If the packing fraction of *n_atoms* atoms inside *region* would exceed a safety threshold the region is **automatically enlarged** to bring the packing fraction back to the target value (``_PACKING_TARGET = 0.45``), and a :class:`UserWarning` is emitted. This prevents a silent failure in :func:`relax_positions` without forcing the caller to manually calculate a safe region size. Two thresholds govern the behaviour: Warn limit (0.50) Practical threshold above which relax_positions slows significantly due to FlatCellList cell over-population. Region is auto-scaled; UserWarning emitted. Hard limit (0.64) Random-close-packing limit for monodisperse spheres. relax_positions cannot converge above this density. Region is auto-scaled with a stronger UserWarning. Both sphere and box region specs are supported; unrecognised specs are returned unchanged (no scaling is attempted). Parameters ---------- n_atoms: Number of atoms to be placed. region: Bounding-region spec (``"sphere:R"`` | ``"box:L"`` | ``"box:LX,LY,LZ"``). Returns ------- str The effective region spec to use for placement. Equal to *region* when the packing fraction is within limits; otherwise the auto-scaled spec string. """ result = self._region_volume(region) if result is None: return region shape, region_vol = result mean_r = float(np.mean([_cov_radius_ang(s) for s in self._element_pool])) atom_vol = (4 / 3) * math.pi * mean_r ** 3 pf = n_atoms * atom_vol / region_vol if pf > _PACKING_HARD: scaled = self._recommend_region(n_atoms, mean_r, shape) warnings.warn( f"[pasted v0.4.4] Packing fraction {pf:.0%} exceeds the physical " f"limit ({_PACKING_HARD:.0%}, random-close-packing) for " f"{region!r} with {n_atoms:,} atoms. " f"relax_positions cannot converge at this density. " f"Region auto-scaled to {scaled!r} (target packing: 45%).", UserWarning, stacklevel=3, ) return scaled if pf > _PACKING_WARN: scaled = self._recommend_region(n_atoms, mean_r, shape) warnings.warn( f"[pasted v0.4.4] Packing fraction {pf:.0%} exceeds recommended " f"limit ({_PACKING_WARN:.0%}) for {region!r} with {n_atoms:,} atoms. " f"relax_positions may be slow. " f"Region auto-scaled to {scaled!r} (target packing: 45%).", UserWarning, stacklevel=3, ) return scaled return region # ------------------------------------------------------------------ # # Verbose logging helpers # # ------------------------------------------------------------------ # def _log_filter_header(self) -> None: """Log the active filter bounds to stderr (verbose mode only). Emits a single ``[filter]`` line listing every active filter in the form ``METRIC in [lo, hi]``. Does nothing when there are no filters or verbose mode is off. """ if self._cfg.verbose and self._filters: self._log( "[filter] " + ", ".join(f"{m} in [{lo:.4g},{hi:.4g}]" for m, lo, hi in self._filters) ) def _log_sample_result( self, i: int, width: int, denom: str, flag: str, *, metrics: dict[str, float] | None = None, msg: str | None = None, ) -> None: """Log one sample outcome to stderr (verbose mode only). Parameters ---------- i: Zero-based attempt index. width: Field width for left-padding the attempt counter. denom: Denominator string (e.g. ``"20"`` or ``"∞"``). flag: Short status tag: ``"PASS"``, ``"skip"``, ``"invalid"``, or ``"warn"``. metrics: Metric dict to append when *flag* is ``"PASS"`` or ``"skip"``. msg: Free-form message to append (used for ``"invalid"`` and ``"warn"``). """ if not self._cfg.verbose: return prefix = f"[{i + 1:>{width}}/{denom}:{flag}]" if metrics is not None: self._log(prefix + " " + " ".join(f"{k}={_fmt(v)}" for k, v in metrics.items())) elif msg is not None: self._log(f"{prefix} {msg}") else: self._log(prefix) def _log_summary( self, n_attempted: int, n_passed: int, n_invalid: int, n_rejected_filter: int, ) -> None: """Log the end-of-run summary line to stderr (verbose mode only). Parameters ---------- n_attempted: Total placement attempts made. n_passed: Number of structures that passed all filters. n_invalid: Attempts rejected by the charge/multiplicity parity check. n_rejected_filter: Attempts rejected by metric filters. """ if not self._cfg.verbose: return self._log( f"[summary] attempted={n_attempted} passed={n_passed} " f"rejected_parity={n_invalid} rejected_filter={n_rejected_filter}" ) def _resolve_cutoff(self, override: float | None) -> float: if override is not None: if self._cfg.verbose: self._log(f"[cutoff] {override:.3f} Å (user-specified)") return override radii = np.array([_cov_radius_ang(s) for s in self._element_pool]) # O(N) approximation: median(r_i + r_j) ≈ 2 × median(r_i). # The element pool is at most 106 elements, so O(N²) would be fast here # too; we still use the O(N) form for consistency with _metrics.py and # to match the formula documented in architecture.md (v0.2.6). median_sum = float(np.median(radii)) * 2.0 cutoff = self._cfg.cov_scale * 1.5 * median_sum if self._cfg.verbose: self._log( f"[cutoff] {cutoff:.3f} Å (auto: cov_scale={self._cfg.cov_scale} × 1.5 × " f"median(r_i+r_j)≈{median_sum:.3f} Å)" ) return cutoff def _sample_atoms(self, rng: random.Random) -> list[str]: """Sample *n_atoms* element symbols respecting fractions and count bounds. Algorithm --------- 1. If no fractions/min/max are configured, falls back to the original uniform ``rng.choice`` per atom (preserves seed parity). 2. Otherwise: place the guaranteed minimum-count atoms first (``element_min_counts``), fill remaining slots by weighted random sampling (``element_fractions``), excluding elements that have reached their ``element_max_counts`` cap, then shuffle. Raises ------ RuntimeError When the constraints cannot be satisfied (e.g. all remaining elements are capped and there are still slots to fill). """ pool = self._element_pool min_c = self._element_min_counts max_c = self._element_max_counts n = len(pool) uniform = n > 0 and all(abs(w - 1.0 / n) < 1e-12 for w in self._element_weights) # Fast path: uniform weights, no bounds → identical to original behavior if uniform and not min_c and not max_c: return [rng.choice(pool) for _ in range(self._cfg.n_atoms)] weights = self._element_weights # ── Step 1: fill guaranteed minimum counts ────────────────────── counts: dict[str, int] = {sym: min_c.get(sym, 0) for sym in pool} atoms: list[str] = [] for sym in pool: atoms.extend([sym] * counts[sym]) remaining = self._cfg.n_atoms - len(atoms) # ── Step 2: weighted sampling for remaining slots ──────────────── for _ in range(remaining): # Build eligible pool (not yet capped) eligible: list[str] = [] eligible_w: list[float] = [] for sym, w in zip(pool, weights, strict=True): cap = max_c.get(sym, None) if cap is None or counts.get(sym, 0) < cap: eligible.append(sym) eligible_w.append(w) if not eligible: raise RuntimeError( "element_max_counts constraints cannot be satisfied: " "all elements are capped before n_atoms is reached." ) # Normalise eligible weights and do a weighted choice total_w = sum(eligible_w) cum: list[float] = [] acc = 0.0 for w in eligible_w: acc += w / total_w cum.append(acc) r = rng.random() chosen = eligible[-1] for sym, c in zip(eligible, cum, strict=True): if r <= c: chosen = sym break counts[chosen] = counts.get(chosen, 0) + 1 atoms.append(chosen) # ── Step 3: shuffle so forced atoms don't cluster at front ─────── rng.shuffle(atoms) return atoms # ------------------------------------------------------------------ # # Public properties # # ------------------------------------------------------------------ # @property def element_pool(self) -> list[str]: """A copy of the resolved element pool (list of symbols).""" return list(self._element_pool) @property def cutoff(self) -> float: """Distance cutoff in Å used for Steinhardt and graph metrics.""" return self._cutoff @property def config(self) -> GeneratorConfig: """The :class:`GeneratorConfig` that was used to construct this generator.""" return self._cfg def __getattr__(self, name: str) -> Any: """Proxy attribute access to ``_cfg`` for all :class:`GeneratorConfig` fields. This allows code written against the old kwargs-based API (e.g. ``gen.n_atoms``, ``gen.seed``) to continue working without modification after the migration to config-based construction. """ # Avoid infinite recursion on _cfg itself during __init__ if name == "_cfg": raise AttributeError(name) try: return getattr(self._cfg, name) except AttributeError: raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'" ) from None # ------------------------------------------------------------------ # # Generation # # ------------------------------------------------------------------ # # ------------------------------------------------------------------ # # Internal placement dispatch # # ------------------------------------------------------------------ # def _place_one( self, atoms_list: list[str], rng: random.Random, ) -> tuple[list[str], list[Vec3], str | None]: """Run the mode-specific placement and return (atoms, positions, center_sym). Raises ------ RuntimeError, ValueError Propagated from the underlying placement functions. """ bond_lo, bond_hi = self._cfg.bond_range shell_lo, shell_hi = self._cfg.shell_radius coord_lo, coord_hi = self._cfg.coord_range center_sym: str | None = None if self._cfg.mode == "gas": # Use _effective_region (auto-scaled if density was too high) atoms_out, positions = place_gas( atoms_list, self._effective_region, rng, ) elif self._cfg.mode == "chain": atoms_out, positions = place_chain( atoms_list, bond_lo, bond_hi, self._cfg.branch_prob, self._cfg.chain_persist, rng, chain_bias=self._cfg.chain_bias, ) elif self._cfg.mode == "maxent": # Use _effective_region (auto-scaled if density was too high) atoms_out, positions = place_maxent( atoms_list, self._effective_region, self._cfg.cov_scale, rng, maxent_steps=self._cfg.maxent_steps, maxent_lr=self._cfg.maxent_lr, maxent_cutoff_scale=self._cfg.maxent_cutoff_scale, trust_radius=self._cfg.trust_radius, convergence_tol=self._cfg.convergence_tol, seed=self._cfg.seed, ) else: # shell center_sym = ( self._fixed_center_sym if self._fixed_center_sym is not None else rng.choice(atoms_list) ) atoms_out, positions = place_shell( atoms_list, center_sym, coord_lo, coord_hi, shell_lo, shell_hi, bond_lo, bond_hi, rng, ) return atoms_out, positions, center_sym # ------------------------------------------------------------------ # # Generation # # ------------------------------------------------------------------ # def _stream_with_stats(self) -> tuple[Iterator[Structure], dict[str, int]]: """Run the generation loop and expose both structures and run statistics. Returns a ``(structures_iterator, stats_dict)`` pair. The *stats_dict* is a mutable mapping that is populated **in-place** as the iterator is consumed; callers that need statistics must exhaust the iterator before reading from it. This is the single source of truth for all generation logic. :meth:`stream` and :meth:`generate` are thin wrappers around it, which eliminates the hidden coupling that previously existed via the ``_last_run_stats`` instance variable. Returns ------- tuple[Iterator[Structure], dict[str, int]] ``(it, stats)`` where *it* yields passing structures and *stats* is populated with ``n_attempted``, ``n_passed``, ``n_rejected_parity``, and ``n_rejected_filter`` once *it* is exhausted. """ stats: dict[str, int] = {} def _inner() -> Iterator[Structure]: rng = random.Random(self._cfg.seed) self._log_filter_header() do_add_h = ("H" in self._element_pool) and self._cfg.add_hydrogen n_passed = n_invalid = n_attempted = n_rejected_filter = 0 unlimited = self._cfg.n_samples == 0 denom = "∞" if unlimited else str(self._cfg.n_samples) width = len(denom) while True: # Stop conditions if not unlimited and n_attempted >= self._cfg.n_samples: break if self._cfg.n_success is not None and n_passed >= self._cfg.n_success: break i = n_attempted n_attempted += 1 atoms_list = self._sample_atoms(rng) if do_add_h: atoms_list = self._add_h_fast(atoms_list, rng) # _adjust_parity may add/remove/swap atoms, which would # violate explicit element_min_counts / element_max_counts # set by the caller. Skip when such constraints are active. if not self._element_min_counts and not self._element_max_counts: atoms_list = self._adjust_parity(atoms_list, rng) ok, val_msg = validate_charge_mult(atoms_list, self._cfg.charge, self._cfg.mult) if not ok: n_invalid += 1 self._log_sample_result(i, width, denom, "invalid", msg=val_msg) continue try: atoms_out, positions, center_sym = self._place_one(atoms_list, rng) except (RuntimeError, ValueError) as exc: if self._cfg.verbose: self._log(f"[ERROR] sample {i + 1}: {exc}") raise # ── Optional affine transform (applied once, before relax) ── if self._cfg.affine_strength > 0.0: positions = _affine_move( positions, 0.0, self._cfg.affine_strength, rng, affine_stretch=self._cfg.affine_stretch, affine_shear=self._cfg.affine_shear, affine_jitter=self._cfg.affine_jitter, ) positions, converged = relax_positions( atoms_out, positions, self._cfg.cov_scale, self._cfg.relax_cycles, seed=self._cfg.seed, ) if not converged: self._log_sample_result( i, width, denom, "warn", msg=( f"relax_positions did not converge in {self._cfg.relax_cycles} cycles." ), ) metrics = compute_all_metrics( atoms_out, positions, self._cfg.n_bins, self._cfg.w_atom, self._cfg.w_spatial, self._cutoff, self._cfg.cov_scale, ) passed = passes_filters(metrics, self._filters) self._log_sample_result( i, width, denom, "PASS" if passed else "skip", metrics=metrics, ) if not passed: n_rejected_filter += 1 continue n_passed += 1 yield Structure( atoms=atoms_out, positions=positions, charge=self._cfg.charge, mult=self._cfg.mult, metrics=metrics, mode=self._cfg.mode, sample_index=n_passed, center_sym=center_sym if self._cfg.mode == "shell" else None, seed=self._cfg.seed, ) n_skip = n_attempted - n_passed - n_invalid self._log_summary(n_attempted, n_passed, n_invalid, n_skip) # ── warnings.warn for noteworthy outcomes ────────────────────── # Fires regardless of verbose so that downstream consumers # (ASE, HT pipelines) receive machine-visible signals even when # PASTED is not in verbose mode. # # Parity warnings fire only when n_passed == 0 (complete failure). # Partial parity rejection where some structures still passed is # expected behavior for mixed-element pools and does not require # a warning — the verbose summary line already reports the counts. if n_invalid > 0 and n_passed == 0: if n_rejected_filter == 0: # Pure parity failure — no attempt reached the filter stage. warnings.warn( f"All {n_attempted} attempt(s) were rejected by the charge/" f"multiplicity parity check ({n_invalid} invalid). " f"No structures were generated. " f"Check that your element pool can satisfy " f"charge={self._cfg.charge}, mult={self._cfg.mult}.", UserWarning, stacklevel=4, ) else: # Mixed failure: some attempts failed parity AND some failed # filters. Report both causes so users don't only debug the # element pool. warnings.warn( f"{n_invalid} of {n_attempted} attempt(s) were rejected by the " f"charge/multiplicity parity check, and the remaining " f"{n_rejected_filter} that passed parity were rejected by metric " f"filters. No structures were generated. " f"Check your element pool (charge={self._cfg.charge}, " f"mult={self._cfg.mult}) and relax --filter thresholds.", UserWarning, stacklevel=4, ) if n_passed == 0 and n_invalid == 0: warnings.warn( f"No structures passed the metric filters after " f"{n_attempted} attempt(s) " f"({n_skip} rejected by filters). " f"Try relaxing the --filter thresholds or increasing n_samples.", UserWarning, stacklevel=4, ) elif ( self._cfg.n_success is not None and n_passed < self._cfg.n_success and not unlimited ): warnings.warn( f"Attempt budget exhausted ({n_attempted} attempts) before " f"reaching n_success={self._cfg.n_success}; " f"only {n_passed} structure(s) collected. " f"Increase n_samples or relax filters.", UserWarning, stacklevel=4, ) # Populate the shared stats dict now that the loop is complete. stats.update( { "n_attempted": n_attempted, "n_passed": n_passed, "n_rejected_parity": n_invalid, "n_rejected_filter": n_rejected_filter, } ) return _inner(), stats
[docs] def stream(self) -> Iterator[Structure]: """Generate structures one by one, yielding each that passes all filters. Unlike :meth:`generate`, structures are yielded immediately as they pass, so callers can write output or stop early without waiting for all attempts to complete. Respects both *n_samples* (maximum attempts) and *n_success* (target number of passing structures): - If *n_success* is set, the iterator stops as soon as that many structures have been yielded — even if *n_samples* attempts have not been exhausted. - If *n_samples* is ``0`` (unlimited), the iterator runs until *n_success* structures have been yielded. - If *n_samples* attempts are exhausted before *n_success* is reached, a warning is emitted to *stderr* and the iterator ends. Each call creates a fresh :class:`random.Random` seeded with ``self._cfg.seed``, so repeated calls with the same seed are reproducible. Yields ------ Structure Each structure that passed all filters, in generation order. Examples -------- Write structures to a file as they are found:: gen = StructureGenerator( n_atoms=12, charge=0, mult=1, mode="gas", region="sphere:9", elements="1-30", n_success=10, n_samples=500, seed=42, ) for s in gen.stream(): s.write_xyz("out.xyz") """ it, _ = self._stream_with_stats() return it
[docs] def generate(self) -> GenerationResult: """Generate structures and return a :class:`GenerationResult`. Collects all structures yielded by the internal generation loop, attaches generation metadata (attempt counts, rejection breakdowns), and returns a :class:`GenerationResult` that behaves like a ``list[Structure]`` in all normal usage while also carrying the diagnostics needed for automated pipelines. Run statistics (``n_attempted``, ``n_passed``, etc.) are obtained directly from :meth:`_stream_with_stats` rather than via a shared instance variable, so there is no hidden coupling between :meth:`stream` and :meth:`generate`. Calling one does not affect the other, and partial iteration of :meth:`stream` cannot leave stale counters for a subsequent :meth:`generate` call. :class:`GenerationResult` supports the full ``list`` interface (indexing, iteration, ``len``, ``bool``) so existing code that does ``result[0]`` or ``for s in result`` continues to work without modification. Warnings are also emitted via :func:`warnings.warn` (category :class:`UserWarning`) when: - Any attempts are rejected by the charge/multiplicity parity check. - No structures pass the metric filters. - The attempt budget is exhausted before ``n_success`` is reached. Each call creates a fresh :class:`random.Random` seeded with ``self._cfg.seed``, so repeated calls with the same seed are reproducible. Returns ------- GenerationResult Wraps the list of passing structures together with generation metadata. Use ``result.structures`` for the raw list or ``result.summary()`` for a one-line diagnostic string. Examples -------- Drop-in list usage:: result = gen.generate() for s in result: print(s.to_xyz()) Metadata access:: result = gen.generate() if result.n_rejected_parity > 0: print(result.summary()) """ it, stats = self._stream_with_stats() structures = list(it) # exhausts the iterator, populating stats return GenerationResult( structures=structures, n_attempted=stats.get("n_attempted", len(structures)), n_passed=stats.get("n_passed", len(structures)), n_rejected_parity=stats.get("n_rejected_parity", 0), n_rejected_filter=stats.get("n_rejected_filter", 0), n_success_target=self._cfg.n_success, )
# ------------------------------------------------------------------ # # Iteration support # # ------------------------------------------------------------------ #
[docs] def __iter__(self) -> Iterator[Structure]: """Iterate over generated structures (delegates to :meth:`stream`).""" return self.stream()
[docs] def __repr__(self) -> str: return ( f"StructureGenerator(" f"n_atoms={self._cfg.n_atoms}, mode={self._cfg.mode!r}, " f"charge={self._cfg.charge:+d}, mult={self._cfg.mult}, " f"n_samples={self._cfg.n_samples}, " f"n_success={self._cfg.n_success}, " f"pool_size={len(self._element_pool)})" )
# --------------------------------------------------------------------------- # Functional API # ---------------------------------------------------------------------------
[docs] def read_xyz( source: str | Path, *, recompute_metrics: bool = True, cutoff: float | None = None, n_bins: int = 20, w_atom: float = 0.5, w_spatial: float = 0.5, cov_scale: float = 1.0, ) -> list[Structure]: """Read one or more structures from an XYZ file or string. Convenience wrapper around :meth:`Structure.from_xyz` that reads **all frames** from a (possibly multi-frame) XYZ source and returns them as a list. Both plain XYZ and PASTED extended XYZ are supported. Parameters ---------- source: Path to an XYZ file **or** a raw XYZ string. recompute_metrics: Recompute all disorder metrics after loading each structure (default: ``True``). cutoff: Distance cutoff (Å) for metric computation. Auto-computed from each structure's element pool when ``None``. n_bins: Histogram bins for ``H_spatial`` / ``RDF_dev`` (default: 20). w_atom: Weight of ``H_atom`` in ``H_total`` (default: 0.5). w_spatial: Weight of ``H_spatial`` in ``H_total`` (default: 0.5). cov_scale: Minimum distance scale factor used for metrics (default: 1.0). Returns ------- list[Structure] One :class:`Structure` per frame, in file order. Raises ------ FileNotFoundError When *source* looks like a file path (no newlines) but the path does not exist on disk. IsADirectoryError When *source* is a path that points to a directory. ValueError When the XYZ content cannot be parsed. Examples -------- Load a PASTED output file and pass the first structure to the optimizer:: from pasted import read_xyz, StructureOptimizer structs = read_xyz("results.xyz") opt = StructureOptimizer( n_atoms=len(structs[0]), charge=structs[0].charge, mult=structs[0].mult, objective={"H_total": 1.0}, elements=list(set(structs[0].atoms)), max_steps=3000, seed=42, ) result = opt.run(initial=structs[0]) Compose with :class:`GenerationResult` via ``+``:: from pasted import read_xyz, generate existing = generate(n_atoms=10, charge=0, mult=1, mode="gas", region="sphere:9", elements="6,7,8", n_samples=5, seed=0) loaded = read_xyz("previous_run.xyz") # loaded is a list[Structure]; wrap manually if needed: from pasted import GenerationResult all_structs = existing + GenerationResult(structures=loaded, n_passed=len(loaded), n_attempted=len(loaded)) """ p = Path(source) if not isinstance(source, str) or "\n" not in str(source) else None if p is not None: # *source* looks like a file path — enforce explicit errors, matching # Structure.from_xyz() behavior (raises FileNotFoundError / IsADirectoryError # rather than silently falling through to parse the path string as XYZ). if not p.exists(): raise FileNotFoundError(f"XYZ file not found: {p!s}") if not p.is_file(): raise IsADirectoryError(f"Expected a file path, but {p!s} is a directory.") text = p.read_text() else: text = str(source) frames = parse_xyz(text) result: list[Structure] = [] for atoms, positions, charge, mult, embedded_metrics in frames: if recompute_metrics: cut = cutoff if cut is None: radii = np.array([_cov_radius_ang(a) for a in atoms]) # O(N) approximation: median(r_i + r_j) ≈ 2 × median(r_i). cut = cov_scale * 1.5 * float(np.median(radii)) * 2.0 metrics = compute_all_metrics( atoms, positions, n_bins, w_atom, w_spatial, cut, cov_scale ) else: metrics = embedded_metrics result.append( Structure( atoms=list(atoms), positions=list(positions), charge=charge, mult=mult, metrics=metrics, mode="loaded_xyz", ) ) return result
[docs] def generate( config: GeneratorConfig | None = None, *, n_atoms: int | None = None, charge: int | None = None, mult: int | None = None, **kwargs: Any, ) -> GenerationResult: """Create a :class:`StructureGenerator` and immediately call :meth:`~StructureGenerator.generate`. Two calling conventions are supported: **Config-based (recommended for new code):** Pass a :class:`GeneratorConfig` as the first positional argument. Provides full mypy / IDE type-checking on every field:: from pasted import generate, GeneratorConfig cfg = GeneratorConfig(n_atoms=10, charge=0, mult=1, mode="gas", region="sphere:8", elements="6,7,8", n_samples=20, seed=0) result = generate(cfg) **Keyword-based (backward-compatible, original API):** Pass all parameters as keyword arguments. ``n_atoms``, ``charge``, and ``mult`` are required; all others are optional:: result = generate( n_atoms=10, charge=0, mult=1, mode="gas", region="sphere:8", elements="6,7,8", n_samples=20, seed=0, ) Both forms may not be mixed: if *config* is given, all other keyword arguments are ignored. Parameters ---------- config: A fully-populated :class:`GeneratorConfig` instance. When given, all other keyword arguments are ignored. n_atoms: Number of atoms per structure (**required** when *config* is ``None``). charge: Total system charge (**required** when *config* is ``None``). mult: Spin multiplicity 2S+1 (**required** when *config* is ``None``). **kwargs: Any optional :class:`GeneratorConfig` field, e.g. ``mode``, ``region``, ``elements``, ``n_samples``, ``seed``, ``filters``, ``affine_strength``, … Ignored when *config* is provided. Returns ------- GenerationResult A list-compatible object containing the structures that passed all filters plus metadata about the generation run. """ if config is not None: return StructureGenerator(config).generate() # Backward-compatible kwargs path if n_atoms is None or charge is None or mult is None: raise TypeError( "generate() requires n_atoms, charge, and mult when no GeneratorConfig is given." ) cfg = GeneratorConfig(n_atoms=n_atoms, charge=charge, mult=mult, **kwargs) return StructureGenerator(cfg).generate()