jarvis.core.graphs

Module to generate NetworkX and DGL graphs.

Classes

Graph

Generate a graph object.

Standardize

Standardize atom_features: subtract mean and divide by std.

StructureDataset

Dataset of crystal DGLGraphs.

Functions

canonize_edge(src_id, dst_id, src_image, dst_image)

Compute canonical edge representation.

nearest_neighbor_edges([atoms, cutoff, max_neighbors, ...])

Construct k-NN edge list.

build_undirected_edgedata([atoms, edges])

Build undirected graph data from edge set.

prepare_dgl_batch(batch[, device, non_blocking])

Send batched dgl crystal graph to device.

prepare_line_graph_batch(batch[, device, non_blocking])

Send line graph batch to device.

compute_bond_cosines(edges)

Compute bond angle cosines from bond displacement vectors.

Module Contents

jarvis.core.graphs.canonize_edge(src_id, dst_id, src_image, dst_image)

Compute canonical edge representation.

Sort vertex ids and shift periodic images so that the first vertex is in the (0,0,0) image.
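A minimal pure-Python sketch of this canonicalization, assuming periodic images are integer 3-tuples. The name canonize_edge_sketch is hypothetical; it illustrates the two steps described above and is not the library's code.

```python
from typing import Tuple

Image = Tuple[int, int, int]


def canonize_edge_sketch(
    src_id: int, dst_id: int, src_image: Image, dst_image: Image
) -> Tuple[int, int, Image, Image]:
    """Hypothetical stand-in for canonize_edge (illustration only)."""
    # Order the vertex ids so the smaller id is always the source;
    # swap the images together with the ids so the bond itself is unchanged.
    if dst_id < src_id:
        src_id, dst_id = dst_id, src_id
        src_image, dst_image = dst_image, src_image
    # Translate both endpoints by -src_image so the source sits in (0, 0, 0);
    # the relative offset between the two endpoints is preserved.
    dst_image = tuple(d - s for d, s in zip(dst_image, src_image))
    src_image = (0, 0, 0)
    return src_id, dst_id, src_image, dst_image


# Edge 3 -> 1 from image (0, 0, 0) to (1, 0, 0) is the same bond as
# 1 -> 3 from (0, 0, 0) to (-1, 0, 0).
print(canonize_edge_sketch(3, 1, (0, 0, 0), (1, 0, 0)))  # (1, 3, (0, 0, 0), (-1, 0, 0))
```

Because both directed views of a bond map to the same tuple, storing canonical edges in a set or dict removes duplicates automatically.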

jarvis.core.graphs.nearest_neighbor_edges(atoms=None, cutoff=8, max_neighbors=12, id=None, use_canonize=True)

Construct k-NN edge list.
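A simplified sketch of a periodic k-NN edge list, assuming fractional coordinates and a lattice given as three row vectors. It only searches the 27 images in {-1, 0, 1}^3 (adequate only when the cutoff is smaller than the cell), does not apply canonize_edge, and stores edges as (src_id, dst_id) -> set of dst_image, the layout build_undirected_edgedata expects. The function name is hypothetical, and the library's handling of ties or of too few neighbors within the cutoff may differ.

```python
import itertools
import math
from collections import defaultdict


def knn_edges_sketch(frac_coords, lattice, cutoff=8.0, max_neighbors=12):
    """Hypothetical k-NN edge builder (illustration only)."""

    def to_cart(f):
        # Rows of `lattice` are the lattice vectors: r = sum_k f[k] * a_k.
        return tuple(sum(f[k] * lattice[k][d] for k in range(3)) for d in range(3))

    edges = defaultdict(set)
    for i, fi in enumerate(frac_coords):
        ri = to_cart(fi)
        candidates = []
        for j, fj in enumerate(frac_coords):
            # Only the 27 images in {-1, 0, 1}^3 are searched here.
            for image in itertools.product((-1, 0, 1), repeat=3):
                if j == i and image == (0, 0, 0):
                    continue  # an atom is not its own neighbor
                rj = to_cart([fj[k] + image[k] for k in range(3)])
                dist = math.dist(ri, rj)
                if dist <= cutoff:
                    candidates.append((dist, j, image))
        # Keep the max_neighbors closest candidates (ties broken by id, then image).
        candidates.sort()
        for _, j, image in candidates[:max_neighbors]:
            edges[(i, j)].add(image)
    return dict(edges)


# Simple cubic cell with a = 1 Angstrom and one atom: six neighbors at 1 Angstrom.
cubic = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(knn_edges_sketch([[0.0, 0.0, 0.0]], cubic, cutoff=1.01))
```

In the single-atom cell every neighbor is a periodic image of the atom itself, so all edges share the key (0, 0) and differ only in their dst_image.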

jarvis.core.graphs.build_undirected_edgedata(atoms=None, edges={})

Build undirected graph data from edge set.

edges: dictionary mapping (src_id, dst_id) to a set of dst_image offsets.

r: Cartesian displacement vector from src to dst.
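A plain-Python sketch of turning that edge set into per-edge lists (u, v, r), assuming Cartesian atom coordinates and a lattice given as three row vectors. Each bond is stored in both directions, with the reverse edge carrying the negated displacement, which is what makes the resulting data undirected. The function name is hypothetical; the library works on a jarvis.core.Atoms object instead.

```python
def undirected_edgedata_sketch(cart_coords, lattice, edges):
    """Hypothetical illustration: expand an edge set into (u, v, r) lists."""
    u, v, r = [], [], []
    for (src, dst), images in edges.items():
        for image in sorted(images):
            # Cartesian translation of the periodic image: sum_k image[k] * a_k.
            shift = [sum(image[k] * lattice[k][d] for k in range(3)) for d in range(3)]
            # Displacement from src (home cell) to the imaged dst.
            disp = [cart_coords[dst][d] + shift[d] - cart_coords[src][d] for d in range(3)]
            # Store both directions; the reverse edge gets the negated vector.
            for a, b, vec in ((src, dst, disp), (dst, src, [-x for x in disp])):
                u.append(a)
                v.append(b)
                r.append(vec)
    return u, v, r


coords = [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]]
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(undirected_edgedata_sketch(coords, identity, {(0, 1): {(0, 0, 0)}}))
```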

class jarvis.core.graphs.Graph(nodes=[], node_attributes=[], edges=[], edge_attributes=[], color_map=None, labels=None)

Bases: object

Generate a graph object.

nodes = []
node_attributes = []
edges = []
edge_attributes = []
color_map = None
labels = None
static atom_dgl_multigraph(atoms=None, neighbor_strategy='k-nearest', cutoff=8.0, max_neighbors=12, atom_features='cgcnn', max_attempts=3, id: str | None = None, compute_line_graph: bool = True, use_canonize: bool = True)

Obtain a DGLGraph for an Atoms object.
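The compute_line_graph flag refers to a line graph: one node per edge of the crystal graph, with bond a->b connected to bond b->c. A pure-Python sketch of that construction on directed edge lists follows; the function name and the backtracking option are illustrative assumptions. DGL has its own dgl.line_graph, and the library may build the line graph through DGL rather than anything like this.

```python
def line_graph_sketch(src, dst, backtracking=True):
    """Hypothetical illustration: line graph of a directed multigraph.

    Node k of the line graph is edge k (src[k] -> dst[k]) of the input graph.
    """
    lg_edges = []
    for i in range(len(src)):
        for j in range(len(src)):
            # Edge i = (a -> b) feeds edge j = (b -> c) when they share atom b.
            if dst[i] != src[j]:
                continue
            # Optionally drop the "backtrack" a -> b -> a.
            if not backtracking and dst[j] == src[i]:
                continue
            lg_edges.append((i, j))
    return lg_edges


# Chain 0 -> 1 -> 2: only bond 0 (0 -> 1) feeds bond 1 (1 -> 2).
print(line_graph_sketch([0, 1], [1, 2]))  # [(0, 1)]
```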

static from_atoms(atoms=None, get_prim=False, zero_diag=False, node_atomwise_angle_dist=False, node_atomwise_rdf=False, features='basic', enforce_c_size=10.0, max_n=100, max_cut=5.0, verbose=False, make_colormap=True)

Get a NetworkX graph. Requires NetworkX to be installed.

Args:

atoms: jarvis.core.Atoms object.

rcut: cut-off beyond which distances are set to zero in the adjacency matrix.

features: Node features. One of:

'atomic_number': graph with atomic numbers only.
'cfid': 438 chemical descriptors from CFID.
'cgcnn': one-hot encoded, 92 features.
'basic': 10 features.
'atomic_fraction': graph with atomic fractions in 103 elements.
array: array with CFID chemical descriptor names. See: jarvis/core/specie.py

enforce_c_size: minimum size of the simulation cell, in Angstrom.
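The rcut rule can be read as a distance-weighted adjacency matrix with long distances zeroed. A sketch under simplifying assumptions: plain Cartesian coordinates, no periodic images, and no enforce_c_size handling. The function name is hypothetical.

```python
import math


def distance_adjacency_sketch(coords, rcut):
    """Hypothetical illustration of a cut-off distance adjacency matrix.

    Entry (i, j) is the distance between atoms i and j, or 0.0 when that
    distance exceeds rcut. Periodic images are ignored here.
    """
    n = len(coords)
    adj = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            d = math.dist(coords[i], coords[j])
            adj[i][j] = d if d <= rcut else 0.0
    return adj


a = distance_adjacency_sketch([(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 0.0, 0.0)], rcut=2.0)
print(a[0])  # [0.0, 1.0, 0.0]
```

Note that a zero entry is ambiguous: it marks both the diagonal and pairs farther apart than rcut.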

to_networkx()

Get the NetworkX representation.

property num_nodes

Return number of nodes in the graph.

property num_edges

Return number of edges in the graph.

classmethod from_dict(d={})

Construct class from a dictionary.

to_dict()

Provide dictionary representation of the Graph object.

__repr__()

Provide a string representation for print statements.

property adjacency_matrix

Provide the adjacency matrix of the graph.
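A minimal stand-in with the same shape as the Graph container (nodes, edges, num_nodes, num_edges, adjacency_matrix, and a to_dict/from_dict round trip). It assumes edges are (i, j) index pairs and the adjacency matrix is 0/1; the class name MiniGraph is hypothetical, and the library's actual edge and matrix formats may differ.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class MiniGraph:
    """Hypothetical stand-in for the Graph container (illustration only)."""

    nodes: List[int] = field(default_factory=list)
    edges: List[Tuple[int, int]] = field(default_factory=list)

    @property
    def num_nodes(self) -> int:
        return len(self.nodes)

    @property
    def num_edges(self) -> int:
        return len(self.edges)

    @property
    def adjacency_matrix(self) -> List[List[int]]:
        # 1 where a directed edge i -> j exists, 0 elsewhere.
        n = self.num_nodes
        a = [[0] * n for _ in range(n)]
        for i, j in self.edges:
            a[i][j] = 1
        return a

    def to_dict(self) -> dict:
        # Lists of plain lists serialize cleanly to JSON.
        return {"nodes": list(self.nodes), "edges": [list(e) for e in self.edges]}

    @classmethod
    def from_dict(cls, d: dict) -> "MiniGraph":
        return cls(nodes=list(d["nodes"]), edges=[tuple(e) for e in d["edges"]])


g = MiniGraph(nodes=[0, 1, 2], edges=[(0, 1), (1, 2)])
print(g.num_nodes, g.num_edges, g.adjacency_matrix)
```

Using default_factory avoids sharing one mutable list between instances, a pitfall of list literals such as nodes=[] as default arguments.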

class jarvis.core.graphs.Standardize(mean: torch.Tensor, std: torch.Tensor)

Bases: torch.nn.Module

Standardize atom_features: subtract mean and divide by std.

mean
std
forward(g: dgl.DGLGraph)

Apply standardization to atom_features.
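The arithmetic of this transform in plain Python, for illustration only: rows are atoms, columns are feature dimensions, and each column is shifted by its mean and scaled by its std. The real module is a torch.nn.Module applied to a DGLGraph; the function name is hypothetical. A zero std raises ZeroDivisionError here, and how the library treats that case is not stated in this reference.

```python
def standardize_sketch(features, mean, std):
    """Hypothetical illustration: (x - mean) / std applied column by column.

    `features` has one row per atom; `mean` and `std` hold one value per column.
    """
    return [[(x - m) / s for x, m, s in zip(row, mean, std)] for row in features]


print(standardize_sketch([[1.0, 10.0], [3.0, 30.0]], [2.0, 20.0], [1.0, 10.0]))
# [[-1.0, -1.0], [1.0, 1.0]]
```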

jarvis.core.graphs.prepare_dgl_batch(batch: Tuple[dgl.DGLGraph, torch.Tensor], device=None, non_blocking=False)

Send batched dgl crystal graph to device.

jarvis.core.graphs.prepare_line_graph_batch(batch: Tuple[Tuple[dgl.DGLGraph, dgl.DGLGraph], torch.Tensor], device=None, non_blocking=False)

Send line graph batch to device.

Note: the batch is a nested tuple, with the graph and line graph grouped together.
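A duck-typed sketch of both device helpers, assuming every element exposes a torch-style .to(device, non_blocking=...) method, as torch.Tensor does. Whether DGLGraph.to forwards non_blocking is an assumption, and FakeTensor and the *_sketch names are hypothetical stand-ins so the example runs without torch or DGL. The line-graph version follows the documented nested-tuple layout.

```python
class FakeTensor:
    """Hypothetical stand-in for anything with a torch-style .to() method."""

    def __init__(self, name, device="cpu"):
        self.name, self.device = name, device

    def to(self, device, non_blocking=False):
        # Return a copy tagged with the new device (non_blocking is ignored).
        return FakeTensor(self.name, device)


def prepare_dgl_batch_sketch(batch, device=None, non_blocking=False):
    g, t = batch  # (graph, target)
    return g.to(device, non_blocking=non_blocking), t.to(device, non_blocking=non_blocking)


def prepare_line_graph_batch_sketch(batch, device=None, non_blocking=False):
    (g, lg), t = batch  # ((graph, line graph), target): the nested tuple
    moved = (g.to(device, non_blocking=non_blocking), lg.to(device, non_blocking=non_blocking))
    return moved, t.to(device, non_blocking=non_blocking)


(g, lg), t = prepare_line_graph_batch_sketch(
    ((FakeTensor("g"), FakeTensor("lg")), FakeTensor("y")), device="cuda:0"
)
print(g.device, lg.device, t.device)  # cuda:0 cuda:0 cuda:0
```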

jarvis.core.graphs.compute_bond_cosines(edges)

Compute bond angle cosines from bond displacement vectors.
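In a line graph each node is a bond, and each edge joins bond a->b to bond b->c. Negating the first displacement makes both vectors start at the shared atom b, giving cos(theta) = (r_ba . r_bc) / (|r_ba| |r_bc|). A pure-Python sketch for a single bond pair; the function name is hypothetical, and the clamp to [-1, 1] (a guard against round-off) is an assumption rather than a description of the library.

```python
import math


def bond_cosine_sketch(r_ab, r_bc):
    """Hypothetical illustration: cosine of the angle a-b-c at atom b.

    r_ab: displacement from a to b; r_bc: displacement from b to c.
    """
    r_ba = [-x for x in r_ab]  # flip so both vectors start at b
    dot = sum(p * q for p, q in zip(r_ba, r_bc))
    cos = dot / (math.hypot(*r_ba) * math.hypot(*r_bc))
    return max(-1.0, min(1.0, cos))  # clamp round-off outside [-1, 1]


print(bond_cosine_sketch([1.0, 0.0, 0.0], [1.0, 0.0, 0.0]))  # -1.0 (linear a-b-c)
```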

class jarvis.core.graphs.StructureDataset(df: pandas.DataFrame, graphs: Sequence[dgl.DGLGraph], target: str, atom_features='atomic_number', transform=None, line_graph=False, classification=False, id_tag='jid')

Bases: torch.utils.data.Dataset

Dataset of crystal DGLGraphs.

df
graphs
target
line_graph = False
labels
ids
transform = None
prepare_batch
static _get_attribute_lookup(atom_features: str = 'cgcnn')

Build a lookup array indexed by atomic number.
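Indexing a table by atomic number turns per-atom feature assignment into a single gather. A sketch with a hypothetical feature function and table size; in the dataset the rows presumably come from the chosen atom_features scheme (for example 'cgcnn'), which is not reproduced here.

```python
def attribute_lookup_sketch(feature_fn, max_z=100):
    """Hypothetical illustration: row z holds the feature vector of element Z = z."""
    n_feat = len(feature_fn(1))
    table = [[0.0] * n_feat for _ in range(max_z + 1)]  # row 0 stays zero (no Z = 0)
    for z in range(1, max_z + 1):
        table[z] = list(feature_fn(z))
    return table


# Toy feature function (hypothetical): [Z, Z mod 2].
lookup = attribute_lookup_sketch(lambda z: [float(z), float(z % 2)], max_z=10)
atomic_numbers = [8, 1, 1]  # e.g. a water molecule
node_features = [lookup[z] for z in atomic_numbers]
print(node_features)  # [[8.0, 0.0], [1.0, 1.0], [1.0, 1.0]]
```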

__len__()

Get length.

__getitem__(idx)

Get StructureDataset sample.

setup_standardizer(ids)

Atom-wise feature standardization transform.
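Because the transform is atom-wise, the statistics are pooled over atoms rather than over graphs. A sketch, assuming ids selects the training samples so that validation and test data do not leak into the statistics. The names are hypothetical, and whether the library uses a population or a sample standard deviation is not stated in this reference (population std is used below).

```python
import statistics


def fit_atomwise_stats_sketch(features_by_id, ids):
    """Hypothetical illustration: per-column mean/std over all atoms of the selected graphs."""
    # Pool every atom row from the chosen samples (e.g. training ids only).
    pooled = [row for i in ids for row in features_by_id[i]]
    columns = list(zip(*pooled))
    mean = [statistics.fmean(c) for c in columns]
    std = [statistics.pstdev(c) for c in columns]  # population std; an assumption
    return mean, std


features_by_id = {"a": [[1.0], [3.0]], "b": [[100.0]]}
print(fit_atomwise_stats_sketch(features_by_id, ids=["a"]))  # ([2.0], [1.0])
```

Sample "b" is excluded from the statistics even though it is much larger, which is the point of fitting on a chosen subset of ids.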

static collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]])

Dataloader helper to batch graphs across samples.
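Batching graphs across samples means building one large graph whose connected components are the individual crystals. DGL provides dgl.batch for this, and the collate helpers presumably rely on it, though that is an assumption. The plain-Python sketch below (hypothetical name and sample layout) shows the node-id offsetting that such batching performs.

```python
def collate_sketch(samples):
    """Hypothetical illustration: merge small graphs into one disjoint graph.

    Each sample is (num_nodes, edge_list, label). Node ids of later graphs are
    shifted by the number of nodes that came before them.
    """
    src, dst, batch_num_nodes, labels = [], [], [], []
    offset = 0
    for num_nodes, edges, label in samples:
        for u, v in edges:
            src.append(u + offset)
            dst.append(v + offset)
        batch_num_nodes.append(num_nodes)  # lets the batch be split again later
        labels.append(label)
        offset += num_nodes
    return (src, dst, batch_num_nodes), labels


batched, labels = collate_sketch([(2, [(0, 1)], 1.0), (3, [(0, 2), (1, 2)], 2.0)])
print(batched, labels)  # ([0, 2, 3], [1, 4, 4], [2, 3]) [1.0, 2.0]
```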

static collate_line_graph(samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor]])

Dataloader helper to batch graphs across samples.