"""Module to generate networkx graphs."""
from jarvis.core.atoms import get_supercell_dims
from jarvis.core.specie import Specie
from jarvis.core.utils import random_colors
import numpy as np
import pandas as pd
from collections import OrderedDict
from jarvis.analysis.structure.neighbors import NeighborsAnalysis
from jarvis.core.specie import chem_data, get_node_attributes
# from jarvis.core.atoms import Atoms
from collections import defaultdict
from typing import List, Tuple, Sequence, Optional
try:
import torch
from tqdm import tqdm
import dgl
except Exception as exp:
print("dgl/torch/tqdm is not installed.", exp)
pass
def canonize_edge(
    src_id, dst_id, src_image, dst_image,
):
    """Return a canonical ``(src_id, dst_id, src_image, dst_image)`` edge.

    The endpoints are ordered so that ``src_id <= dst_id``, and each
    periodic image travels with its vertex. Both images are then translated
    by the source image so the source vertex sits in the (0, 0, 0) cell.
    """
    # Order the endpoints; each image moves together with its vertex.
    if src_id > dst_id:
        (src_id, src_image), (dst_id, dst_image) = (
            (dst_id, dst_image),
            (src_id, src_image),
        )
    # Translate both images so the source lands in the home cell.
    if not np.array_equal(src_image, (0, 0, 0)):
        offset = src_image
        src_image, dst_image = (
            tuple(np.subtract(src_image, offset)),
            tuple(np.subtract(dst_image, offset)),
        )
    assert src_image == (0, 0, 0)
    return src_id, dst_id, src_image, dst_image
def nearest_neighbor_edges(
    atoms=None, cutoff=8, max_neighbors=12, id=None, use_canonize=False,
):
    """Construct a k-nearest-neighbor edge dictionary.

    Args:
        atoms: structure providing ``get_all_neighbors(r=...)``, which
            returns one neighbor list per site whose entries are indexed as
            ``(site, index, distance, image)``, and a ``lattice`` with
            ``a``, ``b`` and ``c`` lengths.
        cutoff: initial search radius. If any site has fewer than
            ``max_neighbors`` neighbors, the radius is raised to the longest
            lattice length (or doubled if it already exceeds that), and this
            repeats until every site has enough neighbors.
        max_neighbors: k. Every neighbor whose distance is at most that of
            the k-th nearest neighbor is kept, so ties in the last shell are
            all included.
        id: structure identifier. It is currently unused and is kept for
            API compatibility.
        use_canonize: if True, key edges by the canonical vertex pair from
            ``canonize_edge``. Otherwise key them by
            ``(site_index, neighbor_index)`` as found.

    Returns:
        ``defaultdict(set)`` mapping ``(src_id, dst_id)`` to the set of
        periodic images (3-tuples) of ``dst_id``. The result is empty for a
        structure with no sites.
    """
    edges = defaultdict(set)
    all_neighbors = atoms.get_all_neighbors(r=cutoff)
    if len(all_neighbors) == 0:
        # No sites: there is nothing to connect (min() below would fail).
        return edges

    # Grow the search radius until every site has at least k neighbors.
    while min(len(nbrs) for nbrs in all_neighbors) < max_neighbors:
        lat = atoms.lattice
        longest = max(lat.a, lat.b, lat.c)
        cutoff = longest if cutoff < longest else 2 * cutoff
        all_neighbors = atoms.get_all_neighbors(r=cutoff)

    # NOTE: this does not guarantee an undirected graph on its own. The
    # original authors observed a case (JVASP-59628) where a site ended up
    # with no incident edges; build_undirected_edgedata adds both
    # directions of every stored edge afterwards.
    for site_idx, neighborlist in enumerate(all_neighbors):
        neighborlist = sorted(neighborlist, key=lambda nbr: nbr[2])
        distances = np.array([nbr[2] for nbr in neighborlist])
        ids = np.array([nbr[1] for nbr in neighborlist])
        images = np.array([nbr[3] for nbr in neighborlist])
        # Keep everything out to the shell of the k-th nearest neighbor.
        max_dist = distances[max_neighbors - 1]
        keep = distances <= max_dist
        for dst, image in zip(ids[keep], images[keep]):
            if use_canonize:
                src_id, dst_id, _, dst_image = canonize_edge(
                    site_idx, dst, (0, 0, 0), tuple(image)
                )
                edges[(src_id, dst_id)].add(dst_image)
            else:
                edges[(site_idx, dst)].add(tuple(image))
    return edges
def build_undirected_edgedata(
    atoms=None, edges=None,
):
    """Build undirected graph data from an edge dictionary.

    Every stored edge is emitted twice, once per direction, with opposite
    displacement vectors.

    Args:
        atoms: structure providing ``frac_coords`` (indexable by site id)
            and ``lattice.cart_coords`` (fractional -> Cartesian).
        edges: dictionary mapping ``(src_id, dst_id)`` to a set of periodic
            images of ``dst_id``. ``None`` is treated as no edges.

    Returns:
        ``(u, v, r)``. ``u`` and ``v`` are int64 tensors of source and
        destination ids. ``r`` holds the Cartesian displacement vectors from
        ``u`` to ``v`` in the default torch dtype; when there are no edges
        it has shape ``(0, 3)``.
    """
    if edges is None:
        edges = {}
    u, v, r = [], [], []
    for (src_id, dst_id), images in edges.items():
        for dst_image in images:
            # Fractional coordinate of the chosen periodic image of dst.
            dst_coord = atoms.frac_coords[dst_id] + dst_image
            # Cartesian displacement pointing from src -> dst.
            d = atoms.lattice.cart_coords(
                dst_coord - atoms.frac_coords[src_id]
            )
            # Add the edge in both directions.
            u.extend((src_id, dst_id))
            v.extend((dst_id, src_id))
            r.extend((d, -d))
    # Explicit dtypes keep the index tensors integral even when there are no
    # edges (np.array([]) would otherwise default to float64).
    u = torch.tensor(np.asarray(u, dtype=np.int64))
    v = torch.tensor(np.asarray(v, dtype=np.int64))
    r_arr = np.asarray(r, dtype=float)
    if r_arr.size == 0:
        r_arr = r_arr.reshape(0, 3)
    r = torch.tensor(r_arr).type(torch.get_default_dtype())
    return u, v, r
class Graph(object):
    """Generate a graph object.

    Holds node ids, node features, a list of ``(u, v)`` edges with
    per-edge attributes, and an optional colour map and labels.
    """

    # Element-property names used as node features for ``features="basic"``
    # (11 values per node). "coulmn" is passed verbatim as a lookup key to
    # ``Specie.element_property``; presumably it matches the property
    # table's own spelling, so do not "fix" it here.
    _BASIC_FEATURES = (
        "Z",
        "coulmn",
        "row",
        "X",
        "atom_rad",
        "nsvalence",
        "npvalence",
        "ndvalence",
        "nfvalence",
        "first_ion_en",
        "elec_aff",
    )

    def __init__(
        self,
        nodes=None,
        node_attributes=None,
        edges=None,
        edge_attributes=None,
        color_map=None,
        labels=None,
    ):
        """
        Initialize the graph object.

        Args:
            nodes: IDs of the graph nodes as integer array.
            node_attributes: node features as multi-dimensional array.
            edges: connectivity as ``(u, v)`` pairs, where u is the source
                node ID and v the destination node ID.
            edge_attributes: one attribute per edge, e.g. a euclidean
                distance.
            color_map: optional per-node colours.
            labels: optional labels.

        List-valued arguments left as ``None`` become fresh empty lists, so
        instances never share a mutable default.
        """
        self.nodes = [] if nodes is None else nodes
        self.node_attributes = (
            [] if node_attributes is None else node_attributes
        )
        self.edges = [] if edges is None else edges
        self.edge_attributes = (
            [] if edge_attributes is None else edge_attributes
        )
        self.color_map = color_map
        self.labels = labels

    @staticmethod
    def atom_dgl_multigraph(
        atoms=None,
        neighbor_strategy="k-nearest",
        cutoff=8.0,
        max_neighbors=12,
        atom_features="cgcnn",
        max_attempts=3,
        id: Optional[str] = None,
        compute_line_graph: bool = True,
        use_canonize: bool = False,
    ):
        """Obtain a DGLGraph for Atoms object.

        Args:
            atoms: jarvis.core.Atoms object.
            neighbor_strategy: only "k-nearest" is implemented.
            cutoff: initial neighbor search radius.
            max_neighbors: k for the k-nearest-neighbor search.
            atom_features: feature set passed to ``get_node_attributes``.
            max_attempts: currently unused; kept for API compatibility.
            id: structure identifier forwarded to the edge builder.
            compute_line_graph: also build the line graph whose edges carry
                bond-angle cosines in ``lg.edata["h"]``.
            use_canonize: forwarded to ``nearest_neighbor_edges``.

        Returns:
            ``(g, lg)`` if ``compute_line_graph`` else ``g``. Node features
            are in ``g.ndata["atom_features"]`` and displacement vectors in
            ``g.edata["r"]``.

        Raises:
            ValueError: for an unsupported ``neighbor_strategy``.
        """
        if neighbor_strategy == "k-nearest":
            edges = nearest_neighbor_edges(
                atoms=atoms,
                cutoff=cutoff,
                max_neighbors=max_neighbors,
                id=id,
                use_canonize=use_canonize,
            )
        else:
            raise ValueError("Not implemented yet", neighbor_strategy)
        u, v, r = build_undirected_edgedata(atoms, edges)

        # One feature row per atom, in atoms.elements order.
        sps_features = np.array(
            [
                list(get_node_attributes(s, atom_features=atom_features))
                for s in atoms.elements
            ]
        )
        node_features = torch.tensor(sps_features).type(
            torch.get_default_dtype()
        )
        # Pass num_nodes explicitly so an atom without incident edges still
        # gets a node (DGL would otherwise infer max(u, v) + 1).
        g = dgl.graph((u, v), num_nodes=node_features.shape[0])
        g.ndata["atom_features"] = node_features
        g.edata["r"] = r
        if compute_line_graph:
            # Atomistic line graph: nodes are bonds, edges are bond pairs,
            # with bond-angle cosines as edge features.
            lg = g.line_graph(shared=True)
            lg.apply_edges(compute_bond_cosines)
            return g, lg
        return g

    @staticmethod
    def _element_property_table(elements, names):
        """Return a float array with rows ``[Specie(el).element_property(n) for n in names]``."""
        rows = []
        for el in elements:
            sp = Specie(el)
            rows.append([sp.element_property(name) for name in names])
        return np.array(rows, dtype="float")

    @staticmethod
    def from_atoms(
        atoms=None,
        get_prim=False,
        zero_diag=False,
        node_atomwise_angle_dist=False,
        node_atomwise_rdf=False,
        features="basic",
        enforce_c_size=10.0,
        max_n=100,
        max_cut=5.0,
        verbose=False,
        make_colormap=True,
    ):
        """
        Get a distance-thresholded Graph (convertible to networkx).

        The structure is expanded to a supercell; it is optionally reduced to
        its primitive cell first. Every ordered atom pair whose distance is
        positive and below ``max_cut`` becomes an edge weighted by that
        distance.

        Args:
            atoms: jarvis.core.Atoms object.
            get_prim: use ``atoms.get_primitive_atoms`` first.
            zero_diag: force the diagonal of the distance matrix to zero,
                so it contributes no edges.
            node_atomwise_angle_dist: append
                ``NeighborsAnalysis.atomwise_angle_dist()`` to node features.
            node_atomwise_rdf: append
                ``NeighborsAnalysis.atomwise_radial_dist()`` to node features.
            features: Node features.
                'atomic_number': graph with atomic numbers only.
                'cfid': 438 chemical descriptors from CFID.
                'basic': 11 element properties (``Graph._BASIC_FEATURES``).
                'atomic_fraction': graph with atomic fractions
                in 103 elements.
                list/tuple/array: element property names passed to
                ``Specie.element_property``. See: jarvis/core/specie.py
            enforce_c_size: minimum size of the simulation cell in Angst.
            max_n: passed to ``NeighborsAnalysis``.
            max_cut: distances >= max_cut are zeroed in the adjacency
                matrix (no edge); also passed to ``NeighborsAnalysis``.
            verbose: passed to ``NeighborsAnalysis``.
            make_colormap: build a random per-species colour list. When
                False, ``color_map`` is None.

        Returns:
            Graph

        Raises:
            ValueError: if ``features`` is not one of the options above.
        """
        if get_prim:
            atoms = atoms.get_primitive_atoms
        dim = get_supercell_dims(atoms=atoms, enforce_c_size=enforce_c_size)
        atoms = atoms.make_supercell(dim)
        adj = np.array(atoms.raw_distance_matrix.copy())
        # Zero out edges with bond length greater than the threshold.
        adj[adj >= max_cut] = 0
        if zero_diag:
            np.fill_diagonal(adj, 0.0)
        nodes = np.arange(atoms.num_atoms)

        # Check sequence types first: comparing a numpy array to a string
        # would be element-wise and ambiguous in an ``if``.
        if isinstance(features, (list, tuple, np.ndarray)):
            node_attributes = Graph._element_property_table(
                atoms.elements, features
            )
        elif features == "atomic_number":
            node_attributes = np.array(
                [[np.array(Specie(i).Z)] for i in atoms.elements],
                dtype="float",
            )
        elif features == "atomic_fraction":
            # NOTE(review): this yields one row per entry of
            # atomic_fraction_array rather than one per atom; confirm this
            # is intended.
            fracs = atoms.composition.atomic_fraction_array
            node_attributes = np.array([np.array([float(i)]) for i in fracs])
        elif features == "basic":
            node_attributes = Graph._element_property_table(
                atoms.elements, Graph._BASIC_FEATURES
            )
        elif features == "cfid":
            node_attributes = np.array(
                [np.array(Specie(i).get_descrp_arr) for i in atoms.elements],
                dtype="float",
            )
        else:
            raise ValueError("Please check the input options.", features)

        if node_atomwise_rdf or node_atomwise_angle_dist:
            nbr = NeighborsAnalysis(
                atoms, max_n=max_n, verbose=verbose, max_cut=max_cut
            )
            if node_atomwise_rdf:
                node_attributes = np.array(
                    np.concatenate(
                        (node_attributes, nbr.atomwise_radial_dist()), axis=1
                    ),
                    dtype="float",
                )
            if node_atomwise_angle_dist:
                node_attributes = np.array(
                    np.concatenate(
                        (node_attributes, nbr.atomwise_angle_dist()), axis=1
                    ),
                    dtype="float",
                )

        # Edge list in row-major order (same order as a nested i/j loop).
        src, dst = np.nonzero(adj > 0)
        uv = [(int(i), int(j)) for i, j in zip(src, dst)]
        edge_attributes = adj[src, dst]

        color_map = None
        if make_colormap:
            sps = atoms.uniq_species
            color_dict = random_colors(number_of_colors=len(sps))
            new_colors = {sps[i]: c for i, c in color_dict.items()}
            color_map = [new_colors[el] for el in atoms.elements]

        return Graph(
            nodes=nodes,
            edges=uv,
            node_attributes=np.array(node_attributes),
            edge_attributes=np.array(edge_attributes),
            color_map=color_map,
        )

    def to_networkx(self):
        """Get a networkx ``DiGraph`` with edge weights from edge_attributes."""
        import networkx as nx

        graph = nx.DiGraph()
        graph.add_nodes_from(self.nodes)
        # Add all edges first, so edges without a matching attribute (zip
        # truncates) still appear; then attach weights.
        graph.add_edges_from(self.edges)
        for edge, weight in zip(self.edges, self.edge_attributes):
            graph.add_edge(edge[0], edge[1], weight=weight)
        return graph

    @property
    def num_nodes(self):
        """Return number of nodes in the graph."""
        return len(self.nodes)

    @property
    def num_edges(self):
        """Return number of edges in the graph."""
        return len(self.edges)

    @classmethod
    def from_dict(cls, d=None):
        """Construct class from a dictionary produced by ``to_dict``.

        ``color_map`` and ``labels`` are optional keys; the others are
        required (KeyError if missing).
        """
        if d is None:
            d = {}
        return cls(
            nodes=d["nodes"],
            edges=d["edges"],
            node_attributes=d["node_attributes"],
            edge_attributes=d["edge_attributes"],
            color_map=d.get("color_map"),
            labels=d.get("labels"),
        )

    def to_dict(self):
        """Provide dictionary representation of the Graph object."""
        info = OrderedDict()
        info["nodes"] = np.array(self.nodes).tolist()
        info["edges"] = np.array(self.edges).tolist()
        info["node_attributes"] = np.array(self.node_attributes).tolist()
        info["edge_attributes"] = np.array(self.edge_attributes).tolist()
        info["color_map"] = np.array(self.color_map).tolist()
        info["labels"] = np.array(self.labels).tolist()
        return info

    def __repr__(self):
        """Provide representation during print statements."""
        return "Graph({})".format(self.to_dict())

    @property
    def adjacency_matrix(self):
        """Provide the (num_nodes, num_nodes) weighted adjacency matrix."""
        A = np.zeros((self.num_nodes, self.num_nodes))
        for edge, a in zip(self.edges, self.edge_attributes):
            # Index element-wise: edges restored via from_dict are lists,
            # and A[[i, j]] would assign whole rows instead of one entry.
            A[edge[0], edge[1]] = a
        return A
class Standardize(torch.nn.Module):
    """Feature-wise z-scoring of ``g.ndata["atom_features"]``."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        """Keep the per-feature mean and standard deviation for ``forward``."""
        super().__init__()
        self.mean, self.std = mean, std

    def forward(self, g: dgl.DGLGraph):
        """Return a local view of ``g`` whose atom_features are standardized."""
        local = g.local_var()
        raw = local.ndata.pop("atom_features")
        local.ndata["atom_features"] = (raw - self.mean) / self.std
        return local
def prepare_dgl_batch(
    batch: Tuple[dgl.DGLGraph, torch.Tensor], device=None, non_blocking=False
):
    """Move a ``(graph, target)`` batch onto ``device`` and return it."""
    graph, target = batch
    moved_graph = graph.to(device, non_blocking=non_blocking)
    moved_target = target.to(device, non_blocking=non_blocking)
    return moved_graph, moved_target
def prepare_line_graph_batch(
    batch: Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor],
    device=None,
    non_blocking=False,
):
    """Send line graph batch to device.

    Accepts the flat ``(g, lg, t)`` triple returned by
    ``StructureDataset.collate_line_graph``. For compatibility with the
    previously annotated form, it also accepts a nested ``((g, lg), t)``
    pair.

    Returns:
        ``((g, lg), t)``: the graph and line graph grouped together, each
        moved to ``device``.
    """
    if len(batch) == 2:
        (g, lg), t = batch
    else:
        g, lg, t = batch
    return (
        (
            g.to(device, non_blocking=non_blocking),
            lg.to(device, non_blocking=non_blocking),
        ),
        t.to(device, non_blocking=non_blocking),
    )
# def prepare_batch(batch, device=None):
# """Send tuple to device, including DGLGraphs."""
# return tuple(x.to(device) for x in batch)
def compute_bond_cosines(edges):
    """Return ``{"h": cos(theta)}`` for each line-graph edge.

    A line-graph edge joins bonds ``a -> b`` and ``b -> c``. The first
    displacement is negated so both vectors start at the shared atom
    (``a <- b -> c``), and then
    cos(theta) = ba . bc / (||ba|| ||bc||), clamped to [-1, 1].
    """
    bond_ba = -edges.src["r"]
    bond_bc = edges.dst["r"]
    dot = (bond_ba * bond_bc).sum(dim=1)
    norms = bond_ba.norm(dim=1) * bond_bc.norm(dim=1)
    # Clamp guards against round-off pushing |cos| slightly above 1.
    return {"h": (dot / norms).clamp(-1, 1)}
class StructureDataset(torch.utils.data.Dataset):
    """Dataset of crystal DGLGraphs."""

    def __init__(
        self,
        df: pd.DataFrame,
        graphs: Sequence[dgl.DGLGraph],
        target: str,
        atom_features="atomic_number",
        transform=None,
        line_graph=False,
        classification=False,
        id_tag="jid",
    ):
        """Pytorch Dataset for atomistic graphs.

        `df`: pandas dataframe from e.g. jarvis.db.figshare.data
        `graphs`: DGLGraph representations corresponding to rows in `df`
        `target`: key for label column in `df`
        `atom_features`: feature set used to build the per-element lookup
            table that replaces each graph's atom_features.
        `transform`: optional callable applied to each graph in __getitem__.
        `line_graph`: also precompute line graphs with bond-angle cosines.
        `classification`: store labels as a flat long tensor.
        `id_tag`: column of `df` holding sample identifiers.

        Note: the graphs are modified in place. Their incoming
        ``atom_features`` (assumed to hold atomic numbers) are moved to
        ``ndata["atomic_number"]`` and replaced by looked-up features.
        """
        self.df = df
        self.graphs = graphs
        self.target = target
        self.line_graph = line_graph
        self.ids = self.df[id_tag]
        # Go through a plain list so the tensor follows row order regardless
        # of the DataFrame's index labels (e.g. after filtering).
        self.labels = torch.tensor(self.df[target].tolist()).type(
            torch.get_default_dtype()
        )
        self.transform = transform

        features = self._get_attribute_lookup(atom_features)
        # Replace atomic numbers in g.ndata["atom_features"] with the
        # selected node representation.
        for g in graphs:
            z = g.ndata.pop("atom_features")
            g.ndata["atomic_number"] = z
            z = z.type(torch.IntTensor).squeeze()
            f = torch.tensor(features[z]).type(torch.FloatTensor)
            if g.num_nodes() == 1:
                # squeeze() collapsed the single node to a scalar index.
                f = f.unsqueeze(0)
            g.ndata["atom_features"] = f

        self.prepare_batch = prepare_dgl_batch
        if line_graph:
            self.prepare_batch = prepare_line_graph_batch
            print("building line graphs")
            self.line_graphs = []
            for g in tqdm(graphs):
                lg = g.line_graph(shared=True)
                lg.apply_edges(compute_bond_cosines)
                self.line_graphs.append(lg)

        if classification:
            self.labels = self.labels.view(-1).long()
            print("Classification dataset.", self.labels)

    @staticmethod
    def _get_attribute_lookup(atom_features: str = "cgcnn"):
        """Build a lookup array indexed by atomic number.

        Row ``Z`` holds the features of the element with atomic number Z.
        Rows of elements without features (and row 0) stay zero.
        """
        max_z = max(v["Z"] for v in chem_data.values())
        # Get the feature width by referencing carbon.
        template = get_node_attributes("C", atom_features)
        features = np.zeros((1 + max_z, len(template)))
        for element, v in chem_data.items():
            x = get_node_attributes(element, atom_features)
            if x is not None:
                features[v["Z"], :] = x
        return features

    def __len__(self):
        """Get length."""
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Get StructureDataset sample: (g, label) or (g, lg, label)."""
        g = self.graphs[idx]
        label = self.labels[idx]
        if self.transform:
            g = self.transform(g)
        if self.line_graph:
            return g, self.line_graphs[idx], label
        return g, label

    def setup_standardizer(self, ids):
        """Atom-wise feature standardization transform.

        Statistics are computed over the atom features of the graphs whose
        positions are listed in ``ids`` (e.g. the training split).
        """
        wanted = {int(i) for i in ids}  # O(1) membership per graph
        x = torch.cat(
            [
                g.ndata["atom_features"]
                for idx, g in enumerate(self.graphs)
                if idx in wanted
            ]
        )
        self.atom_feature_mean = x.mean(0)
        std = x.std(0)
        # Constant columns (std == 0, e.g. unused one-hot slots) or a single
        # row (std is NaN) would make standardized features inf/NaN; divide
        # those columns by 1 instead.
        self.atom_feature_std = torch.where(std > 0, std, torch.ones_like(std))
        self.transform = Standardize(
            self.atom_feature_mean, self.atom_feature_std
        )

    @staticmethod
    def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]):
        """Dataloader helper to batch graphs across `samples`."""
        graphs, labels = map(list, zip(*samples))
        batched_graph = dgl.batch(graphs)
        # Mirror collate_line_graph: stack non-scalar (multi-output) labels.
        if len(labels[0].size()) > 0:
            return batched_graph, torch.stack(labels)
        return batched_graph, torch.tensor(labels)

    @staticmethod
    def collate_line_graph(
        samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor]]
    ):
        """Dataloader helper to batch graphs and line graphs across `samples`."""
        graphs, line_graphs, labels = map(list, zip(*samples))
        batched_graph = dgl.batch(graphs)
        batched_line_graph = dgl.batch(line_graphs)
        if len(labels[0].size()) > 0:
            return batched_graph, batched_line_graph, torch.stack(labels)
        return batched_graph, batched_line_graph, torch.tensor(labels)
"""
if __name__ == "__main__":
from jarvis.core.atoms import Atoms
from jarvis.db.figshare import get_jid_data
atoms = Atoms.from_dict(get_jid_data("JVASP-664")["atoms"])
g = Graph.from_atoms(
atoms=atoms,
features="basic",
get_prim=True,
zero_diag=True,
node_atomwise_angle_dist=True,
node_atomwise_rdf=True,
)
g = Graph.from_atoms(
atoms=atoms,
features="cfid",
get_prim=True,
zero_diag=True,
node_atomwise_angle_dist=True,
node_atomwise_rdf=True,
)
g = Graph.from_atoms(
atoms=atoms,
features="atomic_number",
get_prim=True,
zero_diag=True,
node_atomwise_angle_dist=True,
node_atomwise_rdf=True,
)
g = Graph.from_atoms(atoms=atoms, features="basic")
g = Graph.from_atoms(
atoms=atoms, features=["Z", "atom_mass", "max_oxid_s"]
)
g = Graph.from_atoms(atoms=atoms, features="cfid")
# print(g)
d = g.to_dict()
g = Graph.from_dict(d)
num_nodes = g.num_nodes
num_edges = g.num_edges
print(num_nodes, num_edges)
assert num_nodes == 48
assert num_edges == 2304
assert len(g.adjacency_matrix) == 2304
# graph, color_map = get_networkx_graph(atoms)
# nx.draw(graph, node_color=color_map, with_labels=True)
# from jarvis.analysis.structure.neighbors import NeighborsAnalysis
"""