# Source code for gg.sites

""" Recognizing sites to apply modifier on """

from typing import Optional, Callable, List, Union
from itertools import product
from pandas import DataFrame
import numpy as np
from ase import Atoms
from ase.constraints import FixAtoms
from ase.neighborlist import NeighborList, natural_cutoffs
import networkx as nx
from gg.utils_graph import atoms_to_graph

# SciPy is an optional dependency. SCIPY_INST records whether the Voronoi
# import succeeded, so code that needs Voronoi can check the flag first.
try:
    from scipy.spatial import Voronoi

    SCIPY_INST = True
except ImportError:
    SCIPY_INST = False


class Sites:
    """Base class for site finders; holds the graph settings and the last graph built."""

    def __init__(
        self,
        max_bond_ratio: Optional[float] = 1.2,
        max_bond: Optional[float] = 0,
        contact_error: Optional[float] = 0.3,
    ):
        """
        Args: All the variables help in making graphs

            max_bond_ratio (float, optional): While making bonds how much error is allowed.
            Forwarded to atoms_to_graph. Defaults to 1.2.

            max_bond (float, optional): Fixed bond distance to use, any distance above is
            ignored. Forwarded to atoms_to_graph. Defaults to 0. If 0, it is ignored.

            contact_error (float, optional): Error allowed if atoms are too close to each
            other. Only stored on the instance here. Defaults to 0.3.

        """
        self.max_bond_ratio = max_bond_ratio
        self.max_bond = max_bond
        self.contact_error = contact_error
        # No graph exists until get_graph() is called (or one is assigned).
        self.graph = None

    @property
    def graph(self) -> nx.Graph:
        """
        Returns:
            nx.Graph: The graph currently stored on the instance (None if none was set).
        """
        return self.g

    @graph.setter
    def graph(self, value):
        # The graph is kept in the plain attribute ``g``.
        self.g = value

    def get_graph(
        self, atoms: Atoms, self_interaction: bool = False, bothways: bool = True
    ) -> nx.Graph:
        """
        Build the graph of ``atoms``, store it on the instance and return it.

        Args:
            atoms (Atoms): Structure to convert into a graph.
            self_interaction (bool, optional): Forwarded to ase NeighborList.
            Defaults to False.
            bothways (bool, optional): Forwarded to ase NeighborList. Defaults to True.

        Returns:
            nx.Graph: Result of atoms_to_graph for ``atoms``.
        """
        cutoffs = natural_cutoffs(atoms)
        neighbor_list = NeighborList(
            cutoffs, self_interaction=self_interaction, bothways=bothways
        )
        neighbor_list.update(atoms)
        self.graph = atoms_to_graph(
            atoms,
            neighbor_list,
            max_bond_ratio=self.max_bond_ratio,
            max_bond=self.max_bond,
        )
        return self.graph

    def get_sites(self, atoms: Atoms) -> list:
        """
        Subclasses must override this to return the sites found in ``atoms``.

        Returns:
            list: Sites of interest.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class RuleSites(Sites):
    """A subclass of Sites that uses multiple rules to identify sites in an atomic structure."""

    def __init__(
        self,
        index_parsers: Optional[
            Union[Callable[[Atoms], list], List[Callable[[Atoms], list]]]
        ] = None,
        combine_rules: str = "union",
        max_bond_ratio: Optional[float] = 1.2,
        max_bond: Optional[float] = 0,
        contact_error: Optional[float] = 0.3,
    ):
        """
        Args:
            index_parsers (Union[Callable[[Atoms], list], List[Callable[[Atoms], list]]], optional):
            A single rule or a list of rules (functions) that take an Atoms object and
            return the indices of the sites of interest.
            Defaults to a function that returns all indices.

            combine_rules (str, optional): How to combine the results of multiple rules.
            Options are:
                - "union": Combine results using set union (default).
                - "intersection": Combine results using set intersection.
            Matching is case-insensitive and any non-empty prefix (e.g. "u", "i",
            "inter") is accepted. Defaults to "union".

            max_bond_ratio (float, optional): While making bonds, how much error is allowed.
            Defaults to 1.2.

            max_bond (float, optional): Fixed bond distance to use, any distance above is
            ignored. Defaults to 0. If 0, it is ignored.

            contact_error (float, optional): Error allowed if atoms are too close to each
            other. Defaults to 0.3.

        Raises:
            ValueError: If combine_rules is not a recognised option.
        """
        super().__init__(max_bond_ratio, max_bond, contact_error)

        if index_parsers is None:
            # Default rule: every atom is a candidate site.
            self.index_parsers = [lambda atoms: list(range(len(atoms)))]
        elif callable(index_parsers):
            # A single rule is wrapped so get_sites can always iterate.
            self.index_parsers = [index_parsers]
        else:
            self.index_parsers = list(index_parsers)

        # Stored as the one-letter code "u" or "i".
        self.combine_rules = self._normalise_combine_rules(combine_rules)

    @staticmethod
    def _normalise_combine_rules(combine_rules: str) -> str:
        """Map a user supplied option onto "u" (union) or "i" (intersection)."""
        message = "combine_rules must be 'union' or 'intersection'."
        if not isinstance(combine_rules, str):
            raise ValueError(message)
        key = combine_rules.strip().lower()
        if key:
            # Looking only at the first letter would accept e.g. "invalid";
            # require the whole string to be a prefix of a real option.
            for option in ("union", "intersection"):
                if option.startswith(key):
                    return option[0]
        raise ValueError(message)

    def get_sites(self, atoms: Atoms) -> list:
        """
        Args:
            atoms (Atoms): The atomic structure to analyze.

        Returns:
            list: Sorted list of the unique indices produced by combining all rules.
            Empty if there are no rules or nothing is selected.
        """
        combined = None
        for parser in self.index_parsers:
            found = set(parser(atoms))
            if combined is None:
                combined = found
            elif self.combine_rules == "u":
                combined |= found  # Union
            elif self.combine_rules == "i":
                combined &= found  # Intersection
        # ``combined`` is still None when there are no parsers at all.
        return sorted(combined) if combined else []


# Rules Defined
def _split_indices_by_symbol(atoms: Atoms, rem_symbols: Optional[List[str]]) -> tuple:
    """Return (kept, removed) index lists of ``atoms``; removed atoms have a symbol in rem_symbols."""
    if not rem_symbols:
        return list(range(len(atoms))), []
    kept, removed = [], []
    for i, atom in enumerate(atoms):
        (removed if atom.symbol in rem_symbols else kept).append(i)
    return kept, removed


def get_unconstrained_sites(atoms: Atoms) -> list:
    """
    Returns a list of indices of atoms that are not constrained.

    Only FixAtoms constraints are taken into account.
    """
    constrained_indices = set()
    for constraint in atoms.constraints:
        if isinstance(constraint, FixAtoms):
            constrained_indices.update(constraint.index)
    return [i for i in range(len(atoms)) if i not in constrained_indices]


def get_tagged_sites(atoms: Atoms, tag: int = -1) -> list:
    """
    Returns a list of indices of atoms that have the specified tag.
    """
    # Read the tag array once instead of building an Atom object per index.
    return [i for i, atom_tag in enumerate(atoms.get_tags()) if atom_tag == tag]


def get_com_sites(
    atoms: Atoms,
    fraction: float = 1.0,
    direction: str = "above",
    axis: Union[str, int] = "z",
) -> list:
    """
    Select atoms lying beyond a threshold measured from the centre of mass (COM).

    Args:
        atoms (Atoms): The ASE Atoms object.
        fraction (float, optional): Fraction of the axis-distance from the COM to consider.
            Must be between 0 and 1. With 1.0 every atom strictly above (below) the COM
            is selected; smaller values move the threshold towards the outermost atom.
            Defaults to 1.0.
        direction (str, optional): Whether to return atoms "above", "below", or "both".
            Above/below are evaluated along the selected axis, i.e., above corresponds
            to the +axis direction relative to the COM and below to -axis.
            Defaults to "above".
        axis (str or int, optional): Axis to use: "x", "y", "z" or 0, 1, 2.
            Defaults to "z".

    Returns:
        list: Sorted atom indices based on the specified direction.

    Raises:
        ValueError: If fraction, direction or axis is invalid.
    """
    if not 0 <= fraction <= 1:
        raise ValueError("Fraction must be between 0 and 1.")

    if direction not in ["above", "below", "both"]:
        raise ValueError("Direction must be 'above', 'below', or 'both'.")

    axis_map = {"x": 0, "y": 1, "z": 2}
    if isinstance(axis, str):
        if axis not in axis_map:
            raise ValueError("Axis must be one of 'x', 'y', 'z', 0, 1, or 2.")
        axis_index = axis_map[axis]
    elif isinstance(axis, (int, np.integer)):
        if axis not in axis_map.values():
            raise ValueError("Axis must be one of 'x', 'y', 'z', 0, 1, or 2.")
        axis_index = int(axis)
    else:
        raise ValueError("Axis must be one of 'x', 'y', 'z', 0, 1, or 2.")

    if len(atoms) == 0:
        return []

    # COM coordinate and per-atom coordinates along the chosen axis
    com_axis = atoms.get_center_of_mass()[axis_index]
    axis_values = atoms.get_positions()[:, axis_index]

    # Thresholds slide from the COM (fraction=1) to the outermost atom (fraction=0)
    threshold_axis_above = com_axis + (1 - fraction) * (axis_values.max() - com_axis)
    threshold_axis_below = com_axis - (1 - fraction) * (com_axis - axis_values.min())

    is_above = axis_values > threshold_axis_above
    is_below = axis_values < threshold_axis_below

    if direction == "above":
        selected = is_above
    elif direction == "below":
        selected = is_below
    else:
        selected = is_above | is_below
    return np.flatnonzero(selected).tolist()


def get_surface_sites_by_coordination(
    atoms: Atoms,
    max_coord: dict,
    max_bond_ratio: float = 1.2,
    max_bond: float = 0,
    self_interaction: bool = False,
    bothways: bool = True,
) -> list:
    """
    Identifies surface sites based on coordination numbers using a graph-based approach.

    An atom is reported when its number of graph neighbours is lower than the
    maximum coordination given for its element.

    Args:
        atoms (Atoms): The ASE Atoms object.
        max_coord (Dict[str, int]): Dictionary of maximum coordination numbers for each
            element.
        max_bond_ratio (float, optional): Tolerance for bond distances. Defaults to 1.2.
        max_bond (float, optional): Maximum bond distance. Defaults to 0 (no limit).
        self_interaction (bool, optional): Whether to include self-interactions in the
            graph. Defaults to False.
        bothways (bool, optional): Whether to consider bonds in both directions.
            Defaults to True.

    Returns:
        list: List of atom indices identified as surface sites, sorted by coordination
        number and then by z-coordinate (both ascending).

    Raises:
        RuntimeError: If max_coord has no entry for an element present in atoms.
    """
    # Validate max_coord
    for sym in atoms.symbols:
        if sym not in max_coord:
            raise RuntimeError(f"Incomplete max_coord: Missing {sym}")

    # Create the graph
    nl = NeighborList(
        natural_cutoffs(atoms), self_interaction=self_interaction, bothways=bothways
    )
    nl.update(atoms)
    graph = atoms_to_graph(atoms, nl, max_bond_ratio=max_bond_ratio, max_bond=max_bond)

    # Calculate coordination numbers and keep the under-coordinated atoms
    sites = []
    for node in graph.nodes():
        coord = len(graph[node])  # Coordination number
        index = graph.nodes[node]["index"]
        diff_coord = max_coord[atoms[index].symbol] - coord
        if diff_coord > 0:
            sites.append(
                {
                    "ind": index,
                    "coord": coord,
                    "diff_coord": diff_coord,
                    "z_coord": atoms[index].position[2],
                }
            )

    # Convert to DataFrame for easier sorting
    df = DataFrame(sites)
    if df.empty:
        return []

    # Sort by coordination number and z-coordinate
    df = df.sort_values(by=["coord", "z_coord"])
    return df["ind"].to_list()


def get_surface_sites_by_voronoi_pbc(
    atoms: Atoms, rem_symbols: Optional[List[str]] = None
) -> List[int]:
    """
    Identifies surface atoms using Voronoi tessellation with periodic boundary conditions.

    Args:
        atoms (Atoms): ASE Atoms object with PBC settings.
        rem_symbols (List[str], optional): Chemical symbols to leave out of the
            tessellation. Defaults to None (use every atom).

    Returns:
        List[int]: Sorted indices of surface atoms. The indices refer to the ``atoms``
        object that was passed in, also when rem_symbols is used. If SciPy is not
        installed, all indices are returned.
    """
    if not SCIPY_INST:
        print("Scipy isnt installed; get_surface_sites_by_voronoi_pbc wont work")
        return list(range(len(atoms)))

    # kept[j] is the index in ``atoms`` of the j-th point used in the tessellation
    kept, _ = _split_indices_by_symbol(atoms, rem_symbols)
    if not kept:
        return []

    cell = atoms.cell
    positions = atoms.get_positions()[kept]

    # Offsets per dimension: [-1, 0, 1] where periodic, [0] otherwise
    # (e.g. 3x3x1 images for a 2D-periodic slab)
    offsets = [[-1, 0, 1] if periodic else [0] for periodic in atoms.pbc]

    extended_positions = []
    tags = []
    original_indices = []

    # Replicate atoms in neighboring cells based on PBC
    for n_x, n_y, n_z in product(*offsets):
        shift = n_x * cell[0] + n_y * cell[1] + n_z * cell[2]
        for i, pos in enumerate(positions):
            extended_positions.append(pos + shift)
            tags.append((n_x, n_y, n_z))
            original_indices.append(kept[i])  # Track index in the input structure

    # Compute Voronoi tessellation
    vor = Voronoi(extended_positions)

    surface_indices = set()
    for i, region_idx in enumerate(vor.point_region):
        # Only atoms of the central cell are reported
        if tags[i] != (0, 0, 0):
            continue
        # An unbounded region (-1 among its vertices) marks a surface atom
        if -1 in vor.regions[region_idx]:
            surface_indices.add(original_indices[i])

    return sorted(surface_indices)


def get_surface_by_normals(
    atoms: Atoms,
    rem_symbols: Optional[List[str]] = None,
    surface_normal: float = 0.5,
    tolerance: float = 1e-5,
    normalize_final: bool = True,
    self_interaction: bool = False,
    bothways: bool = True,
) -> List:
    """
    Detect surface atoms from surface normals, considering periodic boundaries.

    The normal of an atom is the sum of the vectors pointing from each neighbour
    to the atom.

    Args:
        atoms (Atoms): ASE Atoms object.
        rem_symbols (List[str], optional): Chemical symbols to leave out before the
            neighbour list is built. Defaults to None.
        surface_normal (float): Threshold for identifying surface atoms based on normal
            magnitude.
        tolerance (float): Minimum magnitude of the stored normal for an atom to be
            reported.
        normalize_final (bool): Whether to normalize the stored normals.
        self_interaction (bool): Forwarded to ase NeighborList.
        bothways (bool): Forwarded to ase NeighborList.

    Returns:
        list: Indices of detected surface atoms. The indices refer to the ``atoms``
        object that was passed in, also when rem_symbols is used.
    """
    # kept[j] is the index in ``atoms`` of atom j in the reduced copy
    kept, removed = _split_indices_by_symbol(atoms, rem_symbols)
    atoms2 = atoms.copy()
    if removed:
        del atoms2[removed]

    # Create the neighbour list
    nl = NeighborList(
        natural_cutoffs(atoms2), self_interaction=self_interaction, bothways=bothways
    )
    nl.update(atoms2)

    positions = atoms2.positions
    cell = atoms2.cell
    normals = np.zeros((len(atoms2), 3), dtype=float)

    for index in range(len(atoms2)):
        normal = np.zeros(3, dtype=float)
        atom_pos = positions[index]
        for neighbor, offset in zip(*nl.get_neighbors(index)):
            neighbor_pos = positions[neighbor] + np.dot(offset, cell)
            normal += atom_pos - neighbor_pos  # Vector sum of neighbor directions

        # Store normal vector if it's above threshold
        magnitude = np.linalg.norm(normal)
        if magnitude > surface_normal:
            normals[index, :] = normal / magnitude if normalize_final else normal

    # Identify surface atoms based on normal magnitude, mapped back to input indices
    return [
        kept[index]
        for index in range(len(atoms2))
        if np.linalg.norm(normals[index]) > tolerance
    ]