Module to generate networkx graphs.

Module Contents



Graph

Generate a graph object.


Standardize

Standardize atom_features: subtract mean and divide by std.


StructureDataset

Dataset of crystal DGLGraphs.


canonize_edge(src_id, dst_id, src_image, dst_image)

Compute canonical edge representation.

nearest_neighbor_edges([atoms, cutoff, max_neighbors, ...])

Construct k-NN edge list.

build_undirected_edgedata([atoms, edges])

Build undirected graph data from edge set.

prepare_dgl_batch(batch[, device, non_blocking])

Send batched dgl crystal graph to device.

prepare_line_graph_batch(batch[, device, non_blocking])

Send line graph batch to device.


Compute bond angle cosines from bond displacement vectors.

jarvis.core.graphs.canonize_edge(src_id, dst_id, src_image, dst_image)[source]

Compute canonical edge representation.

Sort vertex ids and shift periodic images so that the first vertex is in the (0, 0, 0) image.
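The canonicalization described above can be sketched in plain Python. This is a minimal illustration, not the library's source: the name canonize_edge_sketch is hypothetical, and it assumes vertex ids are integers and periodic images are 3-tuples of integers.

```python
def canonize_edge_sketch(src_id, dst_id, src_image, dst_image):
    """Hypothetical sketch: return an edge with sorted ids and src in image (0, 0, 0)."""
    # Sort vertex ids so that src_id <= dst_id, swapping the images along with them.
    if dst_id < src_id:
        src_id, dst_id = dst_id, src_id
        src_image, dst_image = dst_image, src_image

    # Shift both periodic images by the source image so the source lands in (0, 0, 0).
    shift = tuple(src_image)
    src_image = tuple(s - t for s, t in zip(src_image, shift))
    dst_image = tuple(d - t for d, t in zip(dst_image, shift))
    return src_id, dst_id, src_image, dst_image


# The same physical bond written in either direction maps to one representation.
forward = canonize_edge_sketch(1, 3, (0, 0, 0), (-1, 0, 0))
backward = canonize_edge_sketch(3, 1, (0, 0, 0), (1, 0, 0))
```

Because both directions collapse to the same tuple, such a representation can serve as a dictionary or set key when deduplicating edges.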

jarvis.core.graphs.nearest_neighbor_edges(atoms=None, cutoff=8, max_neighbors=12, id=None, use_canonize=False)[source]

Construct k-NN edge list.
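As a rough illustration of what a k-NN edge list over a periodic cell looks like, the sketch below keeps the max_neighbors closest neighbors of each atom within cutoff. It is written under stated assumptions and is not the library's implementation: the name knn_edges_sketch is hypothetical, inputs are plain lists (fractional coordinates and a 3x3 lattice with row vectors) rather than a jarvis Atoms object, and only the 27 adjacent periodic images are searched, which is adequate only when cutoff does not exceed the cell dimensions.

```python
import itertools
import math


def knn_edges_sketch(frac_coords, lattice, cutoff=8.0, max_neighbors=12):
    """Hypothetical sketch: map (src, dst) to the set of dst periodic images."""

    def to_cart(frac):
        # Row-vector convention: cart_j = sum_k frac_k * lattice[k][j]
        return [sum(frac[k] * lattice[k][j] for k in range(3)) for j in range(3)]

    images = list(itertools.product((-1, 0, 1), repeat=3))
    edges = {}
    for src, f_src in enumerate(frac_coords):
        candidates = []
        for dst, f_dst in enumerate(frac_coords):
            for image in images:
                if src == dst and image == (0, 0, 0):
                    continue  # an atom is not its own neighbor
                delta = [f_dst[k] + image[k] - f_src[k] for k in range(3)]
                dist = math.sqrt(sum(c * c for c in to_cart(delta)))
                if dist <= cutoff:
                    candidates.append((dist, dst, image))
        # Keep only the closest max_neighbors candidates for this atom.
        candidates.sort()
        for _, dst, image in candidates[:max_neighbors]:
            edges.setdefault((src, dst), set()).add(image)
    return edges


# Simple cubic cell (a = 3) with one atom: its 6 nearest neighbors are its own images.
cubic = [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]]
knn = knn_edges_sketch([[0.0, 0.0, 0.0]], cubic, cutoff=3.5, max_neighbors=6)
```

The returned dictionary has the same shape as the edges argument documented for build_undirected_edgedata: keys are (src_id, dst_id) pairs and values are sets of destination images.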

jarvis.core.graphs.build_undirected_edgedata(atoms=None, edges={})[source]

Build undirected graph data from edge set.

edges: dictionary mapping (src_id, dst_id) to set of dst_image

r: cartesian displacement vector from src -> dst
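A minimal sketch of how an edge dictionary of this shape can be expanded into undirected graph data: each stored edge is emitted once in each direction, with the displacement vector negated for the reverse direction. The name undirected_edgedata_sketch and the plain-list inputs (fractional coordinates, 3x3 lattice with row vectors) are assumptions made for illustration; the library function takes an Atoms object instead.

```python
def undirected_edgedata_sketch(frac_coords, lattice, edges):
    """Hypothetical sketch: return (u, v, r) lists covering both edge directions."""

    def to_cart(frac):
        # Row-vector convention: cart_j = sum_k frac_k * lattice[k][j]
        return [sum(frac[k] * lattice[k][j] for k in range(3)) for j in range(3)]

    u, v, r = [], [], []
    for (src, dst), dst_images in edges.items():
        for image in sorted(dst_images):
            # Fractional offset from src to the periodic image of dst.
            delta = [frac_coords[dst][k] + image[k] - frac_coords[src][k] for k in range(3)]
            d = to_cart(delta)
            # Forward edge src -> dst with displacement d.
            u.append(src)
            v.append(dst)
            r.append(d)
            # Reverse edge dst -> src with displacement -d.
            u.append(dst)
            v.append(src)
            r.append([-c for c in d])
    return u, v, r


cell = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]
u, v, r = undirected_edgedata_sketch(
    [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], cell, {(0, 1): {(0, 0, 0)}}
)
```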

class jarvis.core.graphs.Graph(nodes=[], node_attributes=[], edges=[], edge_attributes=[], color_map=None, labels=None)[source]

Bases: object

Generate a graph object.

property num_nodes

Return number of nodes in the graph.

property num_edges

Return number of edges in the graph.

property adjacency_matrix

Provide adjacency_matrix of graph.

static atom_dgl_multigraph(atoms=None, neighbor_strategy='k-nearest', cutoff=8.0, max_neighbors=12, atom_features='cgcnn', max_attempts=3, id: str | None = None, compute_line_graph: bool = True, use_canonize: bool = False)[source]

Obtain a DGLGraph for Atoms object.

static from_atoms(atoms=None, get_prim=False, zero_diag=False, node_atomwise_angle_dist=False, node_atomwise_rdf=False, features='basic', enforce_c_size=10.0, max_n=100, max_cut=5.0, verbose=False, make_colormap=True)[source]

Get NetworkX graph. Requires NetworkX installation.


atoms: jarvis.core.Atoms object.

rcut: cut-off after which distance will be set to zero in the adjacency matrix.

features: Node features.

‘atomic_number’: graph with atomic numbers only.

‘cfid’: 438 chemical descriptors from CFID.

‘cgcnn’: 92 one-hot encoded features.

‘basic’: 10 features.

‘atomic_fraction’: graph with atomic fractions in 103 elements.

array: array with CFID chemical descriptor names. See: jarvis/core/specie.py

enforce_c_size: minimum size of the simulation cell in angstroms.


Get networkx representation.

classmethod from_dict(d={})[source]

Construct class from a dictionary.


Provide dictionary representation of the Graph object.


Provide representation during print statements.

class jarvis.core.graphs.Standardize(mean: torch.Tensor, std: torch.Tensor)[source]

Bases: torch.nn.Module

Standardize atom_features: subtract mean and divide by std.

forward(g: dgl.DGLGraph)[source]

Apply standardization to atom_features.
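The arithmetic behind the transform is per-feature standardization, (x - mean) / std. The sketch below shows it on plain Python lists so it runs without torch or dgl; the name standardize_sketch is hypothetical, and the real class operates on torch.Tensor features attached to a DGLGraph.

```python
def standardize_sketch(features, mean, std):
    """Hypothetical sketch: standardize each column of a list of feature rows."""
    # For every atom (row), subtract the per-feature mean and divide by the per-feature std.
    return [
        [(x - m) / s for x, m, s in zip(row, mean, std)]
        for row in features
    ]


atom_features = [[1.0, 10.0], [3.0, 30.0]]
standardized = standardize_sketch(atom_features, mean=[2.0, 20.0], std=[1.0, 10.0])
```

After the transform each feature column has zero mean and unit scale with respect to the supplied statistics, which keeps features with large raw magnitudes from dominating training.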

jarvis.core.graphs.prepare_dgl_batch(batch: Tuple[dgl.DGLGraph, torch.Tensor], device=None, non_blocking=False)[source]

Send batched dgl crystal graph to device.

jarvis.core.graphs.prepare_line_graph_batch(batch: Tuple[Tuple[dgl.DGLGraph, dgl.DGLGraph], torch.Tensor], device=None, non_blocking=False)[source]

Send line graph batch to device.

Note: the batch is a nested tuple, with the graph and line graph together.


Compute bond angle cosines from bond displacement vectors.
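For two bond displacement vectors r1 and r2, the bond angle cosine is the normalized dot product r1 · r2 / (|r1| |r2|). The sketch below is a plain-Python illustration under assumptions: the name bond_cosine_sketch is hypothetical, the vectors are taken as given (the sign convention used to orient them at the shared atom is not stated in this documentation), and the result is clamped to [-1, 1] to guard against floating-point overshoot.

```python
import math


def bond_cosine_sketch(r1, r2):
    """Hypothetical sketch: cosine of the angle between two displacement vectors."""
    dot = sum(a * b for a, b in zip(r1, r2))
    norm = math.sqrt(sum(a * a for a in r1)) * math.sqrt(sum(b * b for b in r2))
    # Clamp so rounding error cannot push the value outside the valid cosine range.
    return max(-1.0, min(1.0, dot / norm))


perpendicular = bond_cosine_sketch([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])
parallel = bond_cosine_sketch([1.0, 0.0, 0.0], [2.0, 0.0, 0.0])
opposite = bond_cosine_sketch([1.0, 0.0, 0.0], [-3.0, 0.0, 0.0])
```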

class jarvis.core.graphs.StructureDataset(df: pandas.DataFrame, graphs: Sequence[dgl.DGLGraph], target: str, atom_features='atomic_number', transform=None, line_graph=False, classification=False, id_tag='jid')[source]

Bases: torch.utils.data.Dataset

Dataset of crystal DGLGraphs.

static _get_attribute_lookup(atom_features: str = 'cgcnn')[source]

Build a lookup array indexed by atomic number.
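A lookup array indexed by atomic number lets node features be fetched by integer indexing rather than by a per-atom dictionary lookup. The sketch below illustrates the idea with nested lists; the name attribute_lookup_sketch and the features_by_z input (a dictionary from atomic number to feature vector) are assumptions for illustration and do not reflect how the library obtains its feature tables.

```python
def attribute_lookup_sketch(features_by_z):
    """Hypothetical sketch: build a table whose row z holds the features for atomic number z."""
    max_z = max(features_by_z)
    width = len(next(iter(features_by_z.values())))
    # Row 0 and any atomic number without features stay as zero vectors.
    table = [[0.0] * width for _ in range(max_z + 1)]
    for z, feats in features_by_z.items():
        table[z] = list(feats)
    return table


# Toy two-dimensional features for hydrogen (Z=1) and carbon (Z=6).
table = attribute_lookup_sketch({1: [1.0, 0.0], 6: [0.0, 1.0]})
carbon_features = table[6]
```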


Get length.


Get StructureDataset sample.


Atom-wise feature standardization transform.

static collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]])[source]

Dataloader helper to batch graphs across samples.

static collate_line_graph(samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor]])[source]

Dataloader helper to batch graphs across samples.