jarvis.core.graphs
Module to generate networkx and DGL graphs.
Classes
- Graph: Generate a graph object.
- Standardize: Standardize atom_features by subtracting the mean and dividing by the standard deviation.
- StructureDataset: Dataset of crystal DGLGraphs.
Functions
- canonize_edge: Compute canonical edge representation.
- nearest_neighbor_edges: Construct k-NN edge list.
- build_undirected_edgedata: Build undirected graph data from edge set.
- prepare_dgl_batch: Send batched DGL crystal graph to device.
- prepare_line_graph_batch: Send line graph batch to device.
- compute_bond_cosines: Compute bond angle cosines from bond displacement vectors.
Module Contents
- jarvis.core.graphs.canonize_edge(src_id, dst_id, src_image, dst_image)
Compute canonical edge representation.
Sort the vertex ids and shift the periodic images so that the first vertex is in the (0, 0, 0) image.
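A minimal sketch of the canonicalization described above, using plain integer tuples for the periodic images. The name canonize_edge_sketch is hypothetical, and reading "sort vertex ids" as "swap the endpoints (and their images) when dst_id < src_id" is an assumption about the library's rule.

```python
from typing import Tuple

Image = Tuple[int, int, int]


def canonize_edge_sketch(
    src_id: int, dst_id: int, src_image: Image, dst_image: Image
) -> Tuple[int, int, Image, Image]:
    """Hypothetical re-implementation for illustration; not the jarvis source."""
    # Order the endpoints so that src_id <= dst_id, carrying each image along.
    if dst_id < src_id:
        src_id, dst_id = dst_id, src_id
        src_image, dst_image = dst_image, src_image
    # Translate both images by -src_image so that src sits in the (0, 0, 0) image.
    shift = src_image
    src_image = tuple(s - t for s, t in zip(src_image, shift))
    dst_image = tuple(d - t for d, t in zip(dst_image, shift))
    return src_id, dst_id, src_image, dst_image


print(canonize_edge_sketch(0, 2, (1, 1, 0), (1, 2, 0)))  # (0, 2, (0, 0, 0), (0, 1, 0))
```

Both calls describe the same bond before and after canonicalization, so duplicate edges found from either endpoint collapse to one key.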
- jarvis.core.graphs.nearest_neighbor_edges(atoms=None, cutoff=8, max_neighbors=12, id=None, use_canonize=True)
Construct k-NN edge list.
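A simplified, hypothetical k-NN edge builder that returns edges in the shape build_undirected_edgedata documents: a dictionary mapping (src_id, dst_id) to a set of dst_image tuples. It assumes fractional coordinates, a lattice given as a 3x3 array of row vectors, and a cutoff in Å that is small compared with the cell, because only the 27 images in {-1, 0, 1}^3 are scanned. The library version may treat ties, canonization (use_canonize) and too-small cutoffs differently.

```python
import itertools
from collections import defaultdict

import numpy as np


def knn_edges_sketch(frac_coords, lattice, cutoff=8.0, max_neighbors=12):
    """Hypothetical k-NN edge builder; not the jarvis implementation."""
    frac = np.asarray(frac_coords, dtype=float)
    lat = np.asarray(lattice, dtype=float)  # rows are the lattice vectors
    cart = frac @ lat
    # Only neighboring images are scanned, so cutoff must be small vs. the cell.
    images = list(itertools.product((-1, 0, 1), repeat=3))
    edges = defaultdict(set)
    for i in range(len(frac)):
        candidates = []
        for j in range(len(frac)):
            for img in images:
                d = np.linalg.norm(cart[j] + np.asarray(img) @ lat - cart[i])
                if 0.0 < d <= cutoff:  # skip the atom itself (d == 0)
                    candidates.append((d, j, img))
        # Keep the max_neighbors closest candidates (ties broken by scan order).
        candidates.sort(key=lambda c: c[0])
        for _, j, img in candidates[:max_neighbors]:
            edges[(i, j)].add(img)
    return edges


# Simple cubic cell (a = 3 Å): the single atom bonds to its 6 face-sharing images.
print(dict(knn_edges_sketch([[0, 0, 0]], 3.0 * np.eye(3), cutoff=3.5)))
```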
- jarvis.core.graphs.build_undirected_edgedata(atoms=None, edges={})
Build undirected graph data from edge set.
edges: dictionary mapping (src_id, dst_id) to a set of dst_image.
r: Cartesian displacement vector from src to dst.
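Given such an edge dictionary, the undirected edge data can be sketched with NumPy: each stored edge is emitted in both directions, and the displacement r is negated for the reverse edge. The function name and the plain-array inputs (fractional coordinates plus a row-vector lattice instead of a jarvis Atoms object) are assumptions for illustration.

```python
import numpy as np


def undirected_edgedata_sketch(frac_coords, lattice, edges):
    """Hypothetical helper returning (u, v, r) arrays for both edge directions."""
    frac = np.asarray(frac_coords, dtype=float)
    lat = np.asarray(lattice, dtype=float)  # rows are the lattice vectors
    u, v, r = [], [], []
    for (src, dst), images in edges.items():
        for img in images:
            # Fractional position of dst in the given periodic image.
            dst_frac = frac[dst] + np.asarray(img, dtype=float)
            # Cartesian displacement pointing from src to dst.
            d = (dst_frac - frac[src]) @ lat
            for a, b, disp in ((src, dst, d), (dst, src, -d)):
                u.append(a)
                v.append(b)
                r.append(disp)
    return np.array(u), np.array(v), np.array(r)


u, v, r = undirected_edgedata_sketch(
    [[0, 0, 0], [0.5, 0, 0]], 2.0 * np.eye(3), {(0, 1): {(0, 0, 0)}}
)
print(u, v, r)  # [0 1] [1 0] with r rows [1, 0, 0] and [-1, 0, 0]
```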
- class jarvis.core.graphs.Graph(nodes=[], node_attributes=[], edges=[], edge_attributes=[], color_map=None, labels=None)
Bases: object
Generate a graph object.
- nodes = []
- node_attributes = []
- edges = []
- edge_attributes = []
- color_map = None
- labels = None
- static atom_dgl_multigraph(atoms=None, neighbor_strategy='k-nearest', cutoff=8.0, max_neighbors=12, atom_features='cgcnn', max_attempts=3, id: str | None = None, compute_line_graph: bool = True, use_canonize: bool = True)
Obtain a DGLGraph for an Atoms object.
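The compute_line_graph flag refers to a line graph, whose nodes are the bonds of the crystal graph. A minimal NumPy sketch of the idea, assuming the convention that a directed bond a→b links to every bond b→c (including the back-bond b→a); whether atom_dgl_multigraph uses exactly this convention is an assumption, and the helper name is hypothetical.

```python
import numpy as np


def line_graph_edges_sketch(u, v):
    """Hypothetical helper: bond i (u[i] -> v[i]) links to bond j when v[i] == u[j]."""
    u = np.asarray(u)
    v = np.asarray(v)
    pairs = []
    for i in range(len(u)):
        for j in np.flatnonzero(u == v[i]):
            if j != i:  # skip self-pairs that periodic self-bonds would create
                pairs.append((i, int(j)))
    return pairs


# Undirected bonds 0-1 and 1-2, stored as four directed edges.
print(line_graph_edges_sketch([0, 1, 1, 2], [1, 0, 2, 1]))
# [(0, 1), (0, 2), (1, 0), (2, 3), (3, 1), (3, 2)]
```

Each line-graph edge (i, j) corresponds to a pair of bonds sharing an atom, which is where a bond angle can be measured.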
- static from_atoms(atoms=None, get_prim=False, zero_diag=False, node_atomwise_angle_dist=False, node_atomwise_rdf=False, features='basic', enforce_c_size=10.0, max_n=100, max_cut=5.0, verbose=False, make_colormap=True)
Get a networkx graph. Requires networkx to be installed.
Args:
- atoms: jarvis.core.Atoms object.
- max_cut: cut-off beyond which distances are set to zero in the adjacency matrix.
- features: node features, one of:
  - 'atomic_number': graph with atomic numbers only.
  - 'cfid': 438 chemical descriptors from CFID.
  - 'cgcnn': one-hot encoded 92 features.
  - 'basic': 10 features.
  - 'atomic_fraction': graph with atomic fractions in 103 elements.
  - array: array with CFID chemical descriptor names. See jarvis/core/specie.py.
- enforce_c_size: minimum size of the simulation cell in Å.
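The max_cut rule can be sketched as a distance-weighted adjacency matrix whose entries at or beyond the cut-off are zeroed. This hypothetical helper ignores periodicity: it works on a plain list of Cartesian positions, whereas enforce_c_size suggests the library first enlarges the cell so that a non-periodic distance matrix is meaningful.

```python
import numpy as np


def cutoff_adjacency_sketch(cart_coords, max_cut=5.0):
    """Hypothetical helper: pairwise-distance adjacency with a cut-off."""
    xyz = np.asarray(cart_coords, dtype=float)
    # Pairwise Euclidean distances between all atoms (no periodic images).
    adj = np.linalg.norm(xyz[:, None, :] - xyz[None, :, :], axis=-1)
    # Zero out pairs at or beyond the cut-off; the diagonal is already zero.
    adj[adj >= max_cut] = 0.0
    return adj


print(cutoff_adjacency_sketch([[0, 0, 0], [1, 0, 0], [7, 0, 0]]))
```

If networkx is installed, a matrix like this can be turned into a graph with networkx.from_numpy_array.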
- to_networkx()
Get networkx representation.
- property num_nodes
Return number of nodes in the graph.
- property num_edges
Return number of edges in the graph.
- classmethod from_dict(d={})
Construct the class from a dictionary.
- to_dict()
Provide dictionary representation of the Graph object.
- __repr__()
Provide a string representation for print statements.
- property adjacency_matrix
Provide the adjacency matrix of the graph.
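One plausible way to assemble such a matrix from the Graph fields, assuming edges holds (i, j) node-index pairs and edge_attributes holds one scalar weight per edge. Both storage conventions and the helper name are assumptions, not the library's documented behavior.

```python
import numpy as np


def adjacency_from_edges_sketch(num_nodes, edges, edge_attributes=None):
    """Hypothetical helper: dense adjacency matrix from index pairs and weights."""
    A = np.zeros((num_nodes, num_nodes))
    if edge_attributes is None:
        edge_attributes = [1.0] * len(edges)  # unweighted graph
    for (i, j), w in zip(edges, edge_attributes):
        A[i, j] = w  # directed entry; add A[j, i] = w for a symmetric matrix
    return A


print(adjacency_from_edges_sketch(3, [(0, 1), (1, 2)], [2.5, 1.0]))
```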
- class jarvis.core.graphs.Standardize(mean: torch.Tensor, std: torch.Tensor)
Bases: torch.nn.Module
Standardize atom_features: subtract the mean and divide by the standard deviation.
- mean
- std
- forward(g: dgl.DGLGraph)
Apply standardization to atom_features.
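The transform itself is the per-feature z-score x' = (x - mean) / std, applied row-wise to the node-feature matrix. A NumPy sketch of that arithmetic; the real module operates on torch tensors stored in the DGLGraph, and the function name here is hypothetical.

```python
import numpy as np


def standardize_sketch(atom_features, mean, std):
    """Hypothetical helper: column-wise (x - mean) / std via broadcasting."""
    x = np.asarray(atom_features, dtype=float)  # shape (num_atoms, num_features)
    return (x - np.asarray(mean, dtype=float)) / np.asarray(std, dtype=float)


print(standardize_sketch([[1.0, 2.0], [3.0, 4.0]], mean=[2.0, 3.0], std=[1.0, 1.0]))
# [[-1. -1.]
#  [ 1.  1.]]
```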
- jarvis.core.graphs.prepare_dgl_batch(batch: Tuple[dgl.DGLGraph, torch.Tensor], device=None, non_blocking=False)
Send a batched DGL crystal graph to a device.
- jarvis.core.graphs.prepare_line_graph_batch(batch: Tuple[Tuple[dgl.DGLGraph, dgl.DGLGraph], torch.Tensor], device=None, non_blocking=False)
Send a line graph batch to a device.
Note: the batch is a nested tuple, with the graph and line graph grouped together in the inner tuple.
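The two batch layouts from the type hints can be sketched as below. To keep the snippet runnable without torch or dgl, Movable is a hypothetical stand-in exposing a to(device, non_blocking=...) method in the style of torch.Tensor.to; passing non_blocking through to the real DGLGraph.to is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Movable:
    """Hypothetical stand-in for a torch.Tensor or dgl.DGLGraph."""

    name: str
    device: Optional[str] = None

    def to(self, device: Optional[str], non_blocking: bool = False) -> "Movable":
        # Return a copy tagged with the target device (non_blocking is ignored here).
        return Movable(self.name, device)


def prepare_dgl_batch_sketch(batch, device=None, non_blocking=False):
    # Flat layout: (graph, targets).
    g, t = batch
    return g.to(device, non_blocking=non_blocking), t.to(device, non_blocking=non_blocking)


def prepare_line_graph_batch_sketch(batch, device=None, non_blocking=False):
    # Nested layout: ((graph, line_graph), targets); the nesting is preserved.
    (g, lg), t = batch
    return (
        (g.to(device, non_blocking=non_blocking), lg.to(device, non_blocking=non_blocking)),
        t.to(device, non_blocking=non_blocking),
    )


print(prepare_line_graph_batch_sketch(((Movable("g"), Movable("lg")), Movable("t")), "cpu"))
```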
- jarvis.core.graphs.compute_bond_cosines(edges)
Compute bond angle cosines from bond displacement vectors.
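For a line-graph edge joining bond a→b (displacement r_ab) to bond b→c (displacement r_bc), the angle at b lies between the vectors b→a = -r_ab and b→c = r_bc, so cos θ = (-r_ab · r_bc) / (|r_ab| |r_bc|). A NumPy sketch of that formula with clipping to [-1, 1] against round-off; the array-based interface is an assumption, since the documented function receives DGL edge batches.

```python
import numpy as np


def bond_cosines_sketch(r_ab, r_bc):
    """Hypothetical helper: cosine of the a-b-c angle for each row pair."""
    r1 = -np.asarray(r_ab, dtype=float)  # vector from b back to a
    r2 = np.asarray(r_bc, dtype=float)   # vector from b on to c
    cos = np.sum(r1 * r2, axis=1) / (
        np.linalg.norm(r1, axis=1) * np.linalg.norm(r2, axis=1)
    )
    return np.clip(cos, -1.0, 1.0)  # guard arccos against round-off


# A straight a-b-c chain (180°) and a right angle (90°).
print(bond_cosines_sketch([[1, 0, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]))  # [-1.  0.]
```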
- class jarvis.core.graphs.StructureDataset(df: pandas.DataFrame, graphs: Sequence[dgl.DGLGraph], target: str, atom_features='atomic_number', transform=None, line_graph=False, classification=False, id_tag='jid')
Bases: torch.utils.data.Dataset
Dataset of crystal DGLGraphs.
- df
- graphs
- target
- line_graph = False
- labels
- ids
- transform = None
- prepare_batch
- static _get_attribute_lookup(atom_features: str = 'cgcnn')
Build a lookup array indexed by atomic number.
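A lookup array indexed by atomic number turns per-node feature assignment into a single fancy-indexing step, table[Z]. In this sketch TOY_FEATURES is a made-up two-feature table standing in for the library's element data; row 0 and any element without data stay zero.

```python
import numpy as np

# Hypothetical per-element features keyed by atomic number (H, C, O).
TOY_FEATURES = {1: [0.1, 1.0], 6: [0.6, 4.0], 8: [0.8, 6.0]}


def build_lookup_sketch(features_by_z):
    """Hypothetical helper: row z of the table holds the features of element z."""
    max_z = max(features_by_z)
    width = len(next(iter(features_by_z.values())))
    table = np.zeros((max_z + 1, width))
    for z, x in features_by_z.items():
        table[z] = x
    return table


table = build_lookup_sketch(TOY_FEATURES)
atomic_numbers = np.array([8, 1, 1])  # a water molecule
print(table[atomic_numbers])  # one feature row per atom
```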
- __len__()
Get the number of samples.
- __getitem__(idx)
Get a StructureDataset sample.
- setup_standardizer(ids)
Set up the atom-wise feature standardization transform from the samples listed in ids.
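Fitting the standardizer amounts to pooling the node features of the selected samples (typically the training split, so no statistics leak from validation data) and taking column-wise statistics. A NumPy sketch under that assumption; ddof=1 mirrors the unbiased default of torch.Tensor.std, and the helper name is hypothetical.

```python
import numpy as np


def fit_standardizer_sketch(node_features_per_graph, ids):
    """Hypothetical helper: mean/std over the atoms of the graphs listed in ids."""
    ids = set(ids)
    pooled = np.concatenate(
        [f for idx, f in enumerate(node_features_per_graph) if idx in ids], axis=0
    )
    mean = pooled.mean(axis=0)
    std = pooled.std(axis=0, ddof=1)  # unbiased, like torch's default
    return mean, std


graphs = [np.array([[0.0], [2.0]]), np.array([[4.0]]), np.array([[100.0]])]
print(fit_standardizer_sketch(graphs, ids=[0, 1]))  # graph 2 is excluded
```

The resulting mean and std would then be the arguments to Standardize.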
- static collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]])
DataLoader helper to batch graphs across samples.
- static collate_line_graph(samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor]])
DataLoader helper to batch graphs and line graphs across samples.
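Batching graphs across samples means merging them into one disjoint-union graph, shifting each sample's node ids by the number of nodes that precede it, and stacking the targets. dgl.batch performs this merge for DGLGraphs; the NumPy sketch below spells out the offset bookkeeping on hypothetical (u, v, num_nodes, target) tuples and is not the library's collate.

```python
import numpy as np


def collate_sketch(samples):
    """Hypothetical helper: samples is a list of (u, v, num_nodes, target)."""
    us, vs, targets = [], [], []
    offset = 0
    for u, v, n, y in samples:
        # Shift node ids so that each graph occupies its own id range.
        us.append(np.asarray(u, dtype=int) + offset)
        vs.append(np.asarray(v, dtype=int) + offset)
        targets.append(y)
        offset += n
    return np.concatenate(us), np.concatenate(vs), offset, np.asarray(targets, dtype=float)


u, v, total_nodes, y = collate_sketch([([0], [1], 2, 1.5), ([0, 1], [1, 0], 3, -0.5)])
print(u, v, total_nodes, y)  # [0 2 3] [1 3 2] 5 [ 1.5 -0.5]
```

For collate_line_graph the same merge would be applied separately to the graphs and to the line graphs.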