# Source code for jarvis.core.graphs

"""Module to generate networkx graphs."""
from collections import OrderedDict
from collections import defaultdict
from typing import List, Tuple, Sequence, Optional

import numpy as np
import pandas as pd

from jarvis.core.atoms import get_supercell_dims
from jarvis.core.specie import Specie
from jarvis.core.utils import random_colors
from jarvis.analysis.structure.neighbors import NeighborsAnalysis
from jarvis.core.specie import chem_data, get_node_attributes

# from jarvis.core.atoms import Atoms

# torch/tqdm/dgl are optional: the networkx-style Graph.from_atoms path works
# without them, while the DGL helpers below need them.
try:
    import torch
    from tqdm import tqdm
    import dgl
except Exception as exp:
    print("dgl/torch/tqdm is not installed.", exp)

def canonize_edge(
    src_id,
    dst_id,
    src_image,
    dst_image,
):
    """Compute canonical edge representation.

    The edge is oriented so that ``src_id <= dst_id`` (swapping the images
    along with the ids), and then both periodic images are translated by the
    same lattice offset so that the source endpoint lies in the (0, 0, 0)
    image.

    Returns:
        ``(src_id, dst_id, src_image, dst_image)`` in canonical form.
    """
    # Orient the edge so the smaller vertex id is the source.
    if src_id > dst_id:
        src_id, dst_id, src_image, dst_image = (
            dst_id,
            src_id,
            dst_image,
            src_image,
        )

    # Translate both endpoints so the source sits in the home cell.
    if not np.array_equal(src_image, (0, 0, 0)):
        offset = src_image
        src_image, dst_image = (
            tuple(np.subtract(src_image, offset)),
            tuple(np.subtract(dst_image, offset)),
        )

    assert src_image == (0, 0, 0)
    return src_id, dst_id, src_image, dst_image
def nearest_neighbor_edges(
    atoms=None,
    cutoff=8,
    max_neighbors=12,
    id=None,
    use_canonize=False,
):
    """Construct k-NN edge list.

    Neighbors come from ``atoms.get_all_neighbors(r=cutoff)``. While any site
    has fewer than ``max_neighbors`` neighbors, the search radius is enlarged:
    to the longest lattice parameter if the cutoff is smaller than it,
    otherwise to twice the current cutoff.

    For each site every neighbor whose distance does not exceed the distance
    of its ``max_neighbors``-th nearest neighbor is kept, so all neighbors in
    the k-th shell (ties) are included.

    NOTE: there is no guarantee that this creates an undirected graph. An
    undirected solution would key nodes by (index, image) and ensure every
    edge has a complementary edge; JVASP-59628 is an example where a site
    ends up with no incident edges.

    Args:
        atoms: structure providing ``get_all_neighbors(r=...)`` and a
            ``lattice`` with ``a``, ``b``, ``c`` attributes.
        cutoff: initial neighbor search radius.
        max_neighbors: number of nearest neighbors (k) to keep per site.
        id: structure identifier; not used by the computation (kept for API
            compatibility).
        use_canonize: if True, edges are stored under the canonical
            ``(src_id, dst_id)`` returned by ``canonize_edge`` (smaller id
            first, source shifted into the (0, 0, 0) image).

    Returns:
        ``defaultdict(set)`` mapping ``(src_id, dst_id)`` to the set of
        destination periodic images (tuples).
    """
    # Neighbor entries are indexed as: [1] neighbor atom index,
    # [2] distance, [3] periodic image.
    while True:
        all_neighbors = atoms.get_all_neighbors(r=cutoff)
        # ``default`` keeps a structure with no sites from raising in min().
        min_nbrs = min(
            (len(neighborlist) for neighborlist in all_neighbors),
            default=max_neighbors,
        )
        if min_nbrs >= max_neighbors:
            break
        # Too few neighbors for some site: grow the search radius and retry.
        lat = atoms.lattice
        longest = max(lat.a, lat.b, lat.c)
        cutoff = longest if cutoff < longest else 2 * cutoff

    # Edge dictionary u -> v keyed by node pair, valued by the set of
    # destination images, so cell-resolved duplicates collapse.
    edges = defaultdict(set)
    for site_idx, neighborlist in enumerate(all_neighbors):
        # sort on distance
        neighborlist = sorted(neighborlist, key=lambda x: x[2])
        distances = np.array([nbr[2] for nbr in neighborlist])
        ids = np.array([nbr[1] for nbr in neighborlist])
        images = np.array([nbr[3] for nbr in neighborlist])

        # keep all edges out to the neighbor shell of the k-th neighbor
        max_dist = distances[max_neighbors - 1]
        keep = distances <= max_dist

        for dst, image in zip(ids[keep], images[keep]):
            if use_canonize:
                src_id, dst_id, _, dst_image = canonize_edge(
                    site_idx, dst, (0, 0, 0), tuple(image)
                )
                edges[(src_id, dst_id)].add(dst_image)
            else:
                edges[(site_idx, dst)].add(tuple(image))
    return edges
def build_undirected_edgedata(
    atoms=None,
    edges=None,
):
    """Build undirected graph data from edge set.

    Every stored edge ``src -> dst`` is emitted in both directions: once as
    ``src -> dst`` with displacement ``d`` and once as ``dst -> src`` with
    displacement ``-d``.

    Args:
        atoms: structure providing ``frac_coords`` and
            ``lattice.cart_coords``.
        edges: dictionary mapping (src_id, dst_id) to a set of dst_image
            (e.g. the output of ``nearest_neighbor_edges``). ``None`` means
            no edges.

    Returns:
        ``(u, v, r)``: int64 tensors of source/destination node ids and a
        tensor (default dtype) of cartesian displacement vectors from
        ``u`` to ``v``, shaped ``(num_edges, 3)`` — assumes 3-D coordinates.
    """
    if edges is None:
        edges = {}
    u, v, r = [], [], []
    for (src_id, dst_id), images in edges.items():
        for dst_image in images:
            # fractional coordinate for periodic image of dst
            dst_coord = atoms.frac_coords[dst_id] + dst_image
            # cartesian displacement vector pointing from src -> dst
            d = atoms.lattice.cart_coords(
                dst_coord - atoms.frac_coords[src_id]
            )
            # add edges for both directions
            u.extend((src_id, dst_id))
            v.extend((dst_id, src_id))
            r.extend((d, -d))

    # Explicit dtype/shape so an empty edge set still gives integer index
    # tensors and a (0, 3) displacement tensor instead of float/1-D ones.
    u = torch.tensor(np.array(u, dtype=np.int64))
    v = torch.tensor(np.array(v, dtype=np.int64))
    r = torch.tensor(np.array(r, dtype=float).reshape(-1, 3)).type(
        torch.get_default_dtype()
    )
    return u, v, r
class Graph(object):
    """Generate a graph object."""

    def __init__(
        self,
        nodes=None,
        node_attributes=None,
        edges=None,
        edge_attributes=None,
        color_map=None,
        labels=None,
    ):
        """
        Initialize the graph object.

        Args:
            nodes: IDs of the graph nodes as integer array.

            node_attributes: node features as multi-dimensional array.

            edges: connectivity as (u, v) pairs where u is
            the source index and v the destination index.

            edge_attributes: attribute for each edge, e.g.
            euclidean distances.

            color_map: optional per-node colors.

            labels: optional labels stored with the graph.

        List arguments that are omitted default to a fresh empty list.
        """
        # None sentinels instead of [] defaults: a shared default list would
        # leak mutations from one Graph instance into every other one.
        self.nodes = [] if nodes is None else nodes
        self.node_attributes = (
            [] if node_attributes is None else node_attributes
        )
        self.edges = [] if edges is None else edges
        self.edge_attributes = (
            [] if edge_attributes is None else edge_attributes
        )
        self.color_map = color_map
        self.labels = labels

    @staticmethod
    def atom_dgl_multigraph(
        atoms=None,
        neighbor_strategy="k-nearest",
        cutoff=8.0,
        max_neighbors=12,
        atom_features="cgcnn",
        max_attempts=3,
        id: Optional[str] = None,
        compute_line_graph: bool = True,
        use_canonize: bool = False,
    ):
        """Obtain a DGLGraph for Atoms object.

        Args:
            atoms: structure to convert.
            neighbor_strategy: only "k-nearest" is implemented; any other
                value raises ValueError.
            cutoff: initial neighbor search radius.
            max_neighbors: k for the k-nearest-neighbor edge list.
            atom_features: feature set name passed to get_node_attributes.
            max_attempts: currently unused.
            id: structure identifier forwarded to nearest_neighbor_edges.
            compute_line_graph: also build the line graph (nodes are bonds,
                edges are bond pairs) with bond angle cosines as edge
                features under ``"h"``.
            use_canonize: forwarded to nearest_neighbor_edges.

        Returns:
            ``g`` with ``ndata["atom_features"]`` and ``edata["r"]``
            (displacement vectors), or ``(g, lg)`` when compute_line_graph
            is True.
        """
        if neighbor_strategy == "k-nearest":
            edges = nearest_neighbor_edges(
                atoms=atoms,
                cutoff=cutoff,
                max_neighbors=max_neighbors,
                id=id,
                use_canonize=use_canonize,
            )
        else:
            raise ValueError("Not implemented yet", neighbor_strategy)

        u, v, r = build_undirected_edgedata(atoms, edges)

        # one feature row per atom, in the order of atoms.elements
        sps_features = np.array(
            [
                list(get_node_attributes(s, atom_features=atom_features))
                for s in atoms.elements
            ]
        )
        node_features = torch.tensor(sps_features).type(
            torch.get_default_dtype()
        )
        g = dgl.graph((u, v))
        g.ndata["atom_features"] = node_features
        g.edata["r"] = r

        if compute_line_graph:
            # atomistic line graph with bond angle cosines as edge features
            lg = g.line_graph(shared=True)
            lg.apply_edges(compute_bond_cosines)
            return g, lg
        return g

    @staticmethod
    def from_atoms(
        atoms=None,
        get_prim=False,
        zero_diag=False,
        node_atomwise_angle_dist=False,
        node_atomwise_rdf=False,
        features="basic",
        enforce_c_size=10.0,
        max_n=100,
        max_cut=5.0,
        verbose=False,
        make_colormap=True,
    ):
        """
        Build a Graph (convertible with ``to_networkx``) from an Atoms object.

        Nodes are the atoms of a supercell of ``atoms``. Edges are all
        ordered atom pairs (i, j) whose distance d satisfies
        0 < d < ``max_cut``, with d as the edge attribute.

        Args:
            atoms: jarvis.core.Atoms object.

            get_prim: use the primitive cell before making the supercell.

            zero_diag: zero the diagonal of the distance matrix.

            node_atomwise_angle_dist, node_atomwise_rdf: append the
            atom-wise angle / radial distributions from NeighborsAnalysis
            to the node features.

            features: Node features.
            'atomic_number': graph with atomic numbers only.
            'atomic_fraction': one value per entry of
            ``atoms.composition.atomic_fraction_array``.
            'cfid': CFID chemical descriptors (Specie.get_descrp_arr).
            'basic': 11 element properties.
            list: element property names for Specie.element_property.

            enforce_c_size: minimum size of the simulation cell in Angst.

            max_n, verbose: passed to NeighborsAnalysis.

            max_cut: distances greater than or equal to this are set to
            zero in the adjacency matrix (no edge).

            make_colormap: assign a random color per species; otherwise
            ``color_map`` is None.

        Raises:
            ValueError: if ``features`` is not one of the options above.
        """
        if get_prim:
            atoms = atoms.get_primitive_atoms
        dim = get_supercell_dims(atoms=atoms, enforce_c_size=enforce_c_size)
        atoms = atoms.make_supercell(dim)

        adj = np.array(atoms.raw_distance_matrix.copy())

        # zero out edges with bond length greater than threshold
        adj[adj >= max_cut] = 0

        if zero_diag:
            np.fill_diagonal(adj, 0.0)
        nodes = np.arange(atoms.num_atoms)

        if features == "atomic_number":
            node_attributes = np.array(
                [[np.array(Specie(i).Z)] for i in atoms.elements],
                dtype="float",
            )
        elif features == "atomic_fraction":
            # NOTE(review): one row per composition entry, not per atom —
            # confirm this is meant to line up with ``nodes``.
            fracs = atoms.composition.atomic_fraction_array
            node_attributes = np.array([np.array([float(i)]) for i in fracs])
        elif features == "basic":
            feats = [
                "Z",
                "coulmn",
                "row",
                "X",
                "atom_rad",
                "nsvalence",
                "npvalence",
                "ndvalence",
                "nfvalence",
                "first_ion_en",
                "elec_aff",
            ]
            node_attributes = np.array(
                [
                    [Specie(i).element_property(j) for j in feats]
                    for i in atoms.elements
                ],
                dtype="float",
            )
        elif features == "cfid":
            node_attributes = np.array(
                [np.array(Specie(i).get_descrp_arr) for i in atoms.elements],
                dtype="float",
            )
        elif isinstance(features, list):
            node_attributes = np.array(
                [
                    [Specie(i).element_property(j) for j in features]
                    for i in atoms.elements
                ],
                dtype="float",
            )
        else:
            # Previously this only printed and then failed later on an
            # unbound ``node_attributes``; fail early with a clear error.
            raise ValueError(
                "Please check the input options: {}".format(features)
            )

        if node_atomwise_rdf or node_atomwise_angle_dist:
            nbr = NeighborsAnalysis(
                atoms, max_n=max_n, verbose=verbose, max_cut=max_cut
            )
        if node_atomwise_rdf:
            node_attributes = np.concatenate(
                (node_attributes, nbr.atomwise_radial_dist()), axis=1
            )
            node_attributes = np.array(node_attributes, dtype="float")
        if node_atomwise_angle_dist:
            node_attributes = np.concatenate(
                (node_attributes, nbr.atomwise_angle_dist()), axis=1
            )
            node_attributes = np.array(node_attributes, dtype="float")

        # construct edge list; np.nonzero yields indices in row-major order,
        # i.e. the same (ii, jj) order as a nested loop over the matrix
        rows, cols = np.nonzero(adj > 0)
        uv = list(zip(rows.tolist(), cols.tolist()))
        edge_attributes = adj[rows, cols]

        color_map = None
        if make_colormap:
            sps = atoms.uniq_species
            color_dict = random_colors(number_of_colors=len(sps))
            new_colors = {sps[i]: j for i, j in color_dict.items()}
            color_map = [new_colors[i] for i in atoms.elements]

        return Graph(
            nodes=nodes,
            edges=uv,
            node_attributes=np.array(node_attributes),
            edge_attributes=np.array(edge_attributes),
            color_map=color_map,
        )

    def to_networkx(self):
        """Get networkx representation.

        Returns a ``networkx.DiGraph`` with all nodes and edges; each edge
        that has a matching entry in ``edge_attributes`` stores it under
        ``weight``.
        """
        import networkx as nx

        graph = nx.DiGraph()
        graph.add_nodes_from(self.nodes)
        graph.add_edges_from(self.edges)
        for edge, weight in zip(self.edges, self.edge_attributes):
            graph.add_edge(edge[0], edge[1], weight=weight)
        return graph

    @property
    def num_nodes(self):
        """Return number of nodes in the graph."""
        return len(self.nodes)

    @property
    def num_edges(self):
        """Return number of edges in the graph."""
        return len(self.edges)

    @classmethod
    def from_dict(cls, d=None):
        """Construct class from a dictionary (as produced by ``to_dict``).

        ``color_map`` and ``labels`` are optional and default to None.
        """
        if d is None:
            d = {}
        return cls(
            nodes=d["nodes"],
            edges=d["edges"],
            node_attributes=d["node_attributes"],
            edge_attributes=d["edge_attributes"],
            color_map=d.get("color_map"),
            labels=d.get("labels"),
        )

    def to_dict(self):
        """Provide dictionary representation of the Graph object."""
        info = OrderedDict()
        info["nodes"] = np.array(self.nodes).tolist()
        info["edges"] = np.array(self.edges).tolist()
        info["node_attributes"] = np.array(self.node_attributes).tolist()
        info["edge_attributes"] = np.array(self.edge_attributes).tolist()
        info["color_map"] = np.array(self.color_map).tolist()
        info["labels"] = np.array(self.labels).tolist()
        return info

    def __repr__(self):
        """Provide representation during print statements."""
        return "Graph({})".format(self.to_dict())

    @property
    def adjacency_matrix(self):
        """Provide adjacency_matrix of graph.

        Returns a ``(num_nodes, num_nodes)`` array with ``A[u, v]`` set to
        the attribute of edge (u, v) and zeros elsewhere.
        """
        A = np.zeros((self.num_nodes, self.num_nodes))
        for edge, a in zip(self.edges, self.edge_attributes):
            # Index the (u, v) element explicitly: with list-valued edges
            # (e.g. after to_dict/from_dict) ``A[edge]`` would fancy-index
            # whole rows instead.
            A[edge[0], edge[1]] = a
        return A
class Standardize(torch.nn.Module):
    """Standardize atom_features: subtract mean and divide by std."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        """Register featurewise mean and standard deviation.

        Both are stored as plain attributes (not registered buffers).
        """
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, g: dgl.DGLGraph):
        """Apply standardization to atom_features.

        Operates on ``g.local_var()``, so the feature replacement is made on
        a local view of the graph. Features whose std is exactly zero are
        divided by 1 instead, to avoid producing NaN/inf.
        """
        g = g.local_var()
        h = g.ndata.pop("atom_features")
        std = torch.as_tensor(self.std)
        # zero-variance features (constant over the fitting set) would
        # otherwise turn into 0/0 = NaN and poison downstream training
        std = torch.where(std == 0, torch.ones_like(std), std)
        g.ndata["atom_features"] = (h - self.mean) / std
        return g
def prepare_dgl_batch(
    batch: Tuple[dgl.DGLGraph, torch.Tensor], device=None, non_blocking=False
):
    """Send batched dgl crystal graph to device.

    Args:
        batch: ``(graph, target)`` pair, e.g. as returned by
            ``StructureDataset.collate``.
        device: destination passed to ``.to``.
        non_blocking: forwarded to ``.to``.

    Returns:
        ``(graph, target)`` with both moved to ``device``.
    """
    g, t = batch
    batch = (
        g.to(device, non_blocking=non_blocking),
        t.to(device, non_blocking=non_blocking),
    )
    return batch
def prepare_line_graph_batch(
    batch: Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor],
    device=None,
    non_blocking=False,
):
    """Send line graph batch to device.

    The input is a flat ``(graph, line_graph, target)`` triple, the layout
    returned by ``StructureDataset.collate_line_graph``. The output nests the
    two graphs together: ``((graph, line_graph), target)``, with every
    element moved to ``device``.
    """
    g, lg, t = batch
    batch = (
        (
            g.to(device, non_blocking=non_blocking),
            lg.to(device, non_blocking=non_blocking),
        ),
        t.to(device, non_blocking=non_blocking),
    )
    return batch
# def prepare_batch(batch, device=None):
#     """Send tuple to device, including DGLGraphs."""
#     return tuple(x.to(device) for x in batch)
def compute_bond_cosines(edges):
    """Compute bond angle cosines from bond displacement vectors.

    A line-graph edge joins bonds ``a -> b`` and ``b -> c``. The source bond
    is negated so both vectors point away from the shared atom
    (``a <- b -> c``), and the cosine of the angle between them is
    ``ba . bc / (||ba|| ||bc||)``, clamped to [-1, 1].

    Returns:
        dict with key ``"h"`` holding one cosine per line-graph edge.
    """
    b_to_a = -edges.src["r"]
    b_to_c = edges.dst["r"]
    dot = (b_to_a * b_to_c).sum(dim=1)
    lengths = b_to_a.norm(dim=1) * b_to_c.norm(dim=1)
    return {"h": (dot / lengths).clamp(-1, 1)}
class StructureDataset(torch.utils.data.Dataset):
    """Dataset of crystal DGLGraphs."""

    def __init__(
        self,
        df: pd.DataFrame,
        graphs: Sequence[dgl.DGLGraph],
        target: str,
        atom_features="atomic_number",
        transform=None,
        line_graph=False,
        classification=False,
        id_tag="jid",
    ):
        """Pytorch Dataset for atomistic graphs.

        `df`: pandas dataframe with one row per structure; must contain the
        `target` and `id_tag` columns.
        `graphs`: DGLGraph representations corresponding to rows in `df`.
        Each graph is expected to hold atomic numbers in
        ``ndata["atom_features"]``; these are moved to
        ``ndata["atomic_number"]`` and replaced by the selected features.
        `target`: key for label column in `df`.
        `line_graph`: also build line graphs with bond angle cosines.
        `classification`: cast labels to a flat long tensor.
        """
        self.df = df
        self.graphs = graphs
        self.target = target
        self.line_graph = line_graph

        self.ids = self.df[id_tag]
        self.labels = torch.tensor(self.df[target]).type(
            torch.get_default_dtype()
        )

        self.transform = transform

        features = self._get_attribute_lookup(atom_features)

        # load selected node representation
        # assume graphs contain atomic number in g.ndata["atom_features"]
        for g in graphs:
            z = g.ndata.pop("atom_features")
            g.ndata["atomic_number"] = z
            z = z.type(torch.IntTensor).squeeze()
            f = torch.tensor(features[z]).type(torch.FloatTensor)
            # squeeze() drops the node axis for single-atom graphs
            if g.num_nodes() == 1:
                f = f.unsqueeze(0)
            g.ndata["atom_features"] = f

        self.prepare_batch = prepare_dgl_batch
        if line_graph:
            self.prepare_batch = prepare_line_graph_batch

            print("building line graphs")
            self.line_graphs = []
            for g in tqdm(graphs):
                lg = g.line_graph(shared=True)
                lg.apply_edges(compute_bond_cosines)
                self.line_graphs.append(lg)

        if classification:
            self.labels = self.labels.view(-1).long()
            print("Classification dataset.", self.labels)

    @staticmethod
    def _get_attribute_lookup(atom_features: str = "cgcnn"):
        """Build a lookup array indexed by atomic number.

        Row ``z`` holds the ``atom_features`` vector for the element with
        atomic number ``z``; elements without features stay all-zero.
        """
        max_z = max(v["Z"] for v in chem_data.values())

        # get feature shape (referencing Carbon)
        template = get_node_attributes("C", atom_features)

        features = np.zeros((1 + max_z, len(template)))

        for element, v in chem_data.items():
            z = v["Z"]
            x = get_node_attributes(element, atom_features)

            if x is not None:
                features[z, :] = x

        return features

    def __len__(self):
        """Get length."""
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Get StructureDataset sample.

        Returns ``(g, label)``, or ``(g, line_graph, label)`` when the
        dataset was built with ``line_graph=True``.
        """
        g = self.graphs[idx]
        label = self.labels[idx]

        if self.transform:
            g = self.transform(g)

        if self.line_graph:
            return g, self.line_graphs[idx], label

        return g, label

    def setup_standardizer(self, ids):
        """Atom-wise feature standardization transform.

        Computes per-feature mean and std of ``atom_features`` over all atoms
        of the graphs whose dataset index is in ``ids`` and installs a
        ``Standardize`` transform used by ``__getitem__``.
        """
        x = torch.cat(
            [
                g.ndata["atom_features"]
                for idx, g in enumerate(self.graphs)
                if idx in ids
            ]
        )
        self.atom_feature_mean = x.mean(0)
        self.atom_feature_std = x.std(0)

        self.transform = Standardize(
            self.atom_feature_mean, self.atom_feature_std
        )

    @staticmethod
    def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]):
        """Dataloader helper to batch graphs cross `samples`.

        Scalar labels are gathered with ``torch.tensor``; non-scalar labels
        are stacked (same rule as ``collate_line_graph``).
        """
        graphs, labels = map(list, zip(*samples))
        batched_graph = dgl.batch(graphs)
        if len(labels[0].size()) > 0:
            return batched_graph, torch.stack(labels)
        return batched_graph, torch.tensor(labels)

    @staticmethod
    def collate_line_graph(
        samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor]]
    ):
        """Dataloader helper to batch graphs cross `samples`.

        Returns a flat ``(batched_graph, batched_line_graph, labels)`` triple.
        """
        graphs, line_graphs, labels = map(list, zip(*samples))
        batched_graph = dgl.batch(graphs)
        batched_line_graph = dgl.batch(line_graphs)
        if len(labels[0].size()) > 0:
            return batched_graph, batched_line_graph, torch.stack(labels)
        return batched_graph, batched_line_graph, torch.tensor(labels)
""" if __name__ == "__main__": from jarvis.core.atoms import Atoms from jarvis.db.figshare import get_jid_data atoms = Atoms.from_dict(get_jid_data("JVASP-664")["atoms"]) g = Graph.from_atoms( atoms=atoms, features="basic", get_prim=True, zero_diag=True, node_atomwise_angle_dist=True, node_atomwise_rdf=True, ) g = Graph.from_atoms( atoms=atoms, features="cfid", get_prim=True, zero_diag=True, node_atomwise_angle_dist=True, node_atomwise_rdf=True, ) g = Graph.from_atoms( atoms=atoms, features="atomic_number", get_prim=True, zero_diag=True, node_atomwise_angle_dist=True, node_atomwise_rdf=True, ) g = Graph.from_atoms(atoms=atoms, features="basic") g = Graph.from_atoms( atoms=atoms, features=["Z", "atom_mass", "max_oxid_s"] ) g = Graph.from_atoms(atoms=atoms, features="cfid") # print(g) d = g.to_dict() g = Graph.from_dict(d) num_nodes = g.num_nodes num_edges = g.num_edges print(num_nodes, num_edges) assert num_nodes == 48 assert num_edges == 2304 assert len(g.adjacency_matrix) == 2304 # graph, color_map = get_networkx_graph(atoms) # nx.draw(graph, node_color=color_map, with_labels=True) # from jarvis.analysis.structure.neighbors import NeighborsAnalysis """