"""Set of useful utility functions."""
from collections import OrderedDict
from collections import defaultdict
import random
import numpy as np
import math
import xmltodict
import re
import requests
[docs]def xml_to_dict(fname):
"""Parse XML file."""
with open(fname, "r") as f:
data = xmltodict.parse(f.read())
return data
[docs]def get_factors(x, start=2):
"""Get factors of a number."""
facts = []
for i in range(start, x + 1):
if x % i == 0:
facts.append(i)
return facts
[docs]def get_counts(array=["W", "W", "Mo", "Mo", "S", "S"]):
"""
Get number of unique elements and their counts.
Uses OrderedDict.
Args:
array of elements
Returns:
ordereddict, e.g.OrderedDict([('W', 2), ('Mo', 2), ('S', 2)])
"""
uniqe_els = []
for i in array:
if i not in uniqe_els:
uniqe_els.append(i)
info = OrderedDict()
for i in uniqe_els:
info.setdefault(i, 0)
for i in array:
info[i] += 1
return info
[docs]def gcd(a, b):
"""Calculate the Greatest Common Divisor of a and b.
Unless b==0, the result will have the same sign as b (so that when
b is divided by it, the result comes out positive).
"""
while b:
a, b = b, a % b
return a
[docs]def ext_gcd(a, b):
"""GCD module from ase."""
if b == 0:
return 1, 0
elif a % b == 0:
return 0, 1
else:
x, y = ext_gcd(b, a % b)
return y, x - y * (a // b)
[docs]def rand_select(x=[]):
"""Select randomly with index info."""
info = {}
for i, ii in enumerate(x):
info.setdefault(ii, []).append(i)
selected = {}
for i, j in info.items():
chosen = random.choice(j)
selected.setdefault(i, chosen)
return selected
[docs]def rec_dict():
"""Make a recursion dictionary."""
return defaultdict(rec_dict)
[docs]def random_colors(number_of_colors=110):
"""Generate random colors for atom coloring."""
colors = [
"#" + "".join([random.choice("0123456789ABCDEF") for j in range(6)])
for i in range(number_of_colors)
]
color_dict = {}
for i, ii in enumerate(colors):
color_dict[i] = ii
return color_dict
[docs]def get_angle(
a=np.array([1, 2, 3]), b=np.array([4, 5, 6]), c=np.array([7, 8, 9])
):
"""Get angle between three vectors."""
# theta = argcos(x.y/(|x||y|))
cos = np.dot((a - b), (c - b)) / (
np.linalg.norm((a - b)) * np.linalg.norm((c - b))
)
if cos <= -1.0:
cos = cos + 0.000001
if cos >= 1.0:
cos = cos - 0.000001
angle = math.degrees(math.acos(cos))
return angle
[docs]def recast_array_on_uniq_array_elements(
uniq=["Si", "Al", "O"],
arr=["Si", "Si", "Al", "Al", "Si", "O", "O", "O", "O"],
):
"""Recast array on uniq array elements."""
info = {}
for i, ii in enumerate(uniq):
for j, jj in enumerate(arr):
if ii == jj:
info.setdefault(ii, []).append(j)
return info
[docs]def lorentzian(x, y, x0, gamma):
"""Get Lorentzian of a function."""
return (y / math.pi) * (
(0.5 * gamma) / ((x - x0) ** 2 + (0.5 * gamma) ** 2)
)
# color_dict=random_colors()
[docs]def stringdict_to_xml(d={}, enforce_string=False):
"""Convert string dictionary to XML."""
line = ""
for i, j in d.items():
if enforce_string:
line += "<" + str(i) + ">'" + str(j) + "'</" + str(i) + ">"
else:
line += "<" + str(i) + ">" + str(j) + "</" + str(i) + ">"
return line
[docs]def array_to_string(arr=[]):
"""Convert 1D arry to string."""
return ",".join(map(str, arr))
[docs]def chunks(lst, n):
"""Split successive n-sized chunks from list."""
x = []
for i in range(0, len(lst), n):
x.append(lst[i : i + n])
return x
[docs]def check_match(a, b, tol=1e-8):
"""Check if a and b are the same, taking into account PBCs."""
if abs(a[0] - b[0]) < tol or abs(abs(a[0] - b[0]) - 1) < tol:
if abs(a[1] - b[1]) < tol or abs(abs(a[1] - b[1]) - 1) < tol:
if abs(a[2] - b[2]) < tol or abs(abs(a[2] - b[2]) - 1) < tol:
return True
return False
[docs]def update_dict(main={}, extra={}):
"""Return update dictionary."""
# Helper function for dict.update method
tmp = main.copy()
for i, j in extra.items():
tmp[i] = j
return tmp
[docs]def get_new_coord_for_xyz_sym(frac_coord=[], xyz_string=""):
"""Obtain new coord from xyz string."""
affine_matrix = parse_xyz_string(xyz_string)
coord = operate_affine(frac_coord, affine_matrix)
coord = np.array([i - math.floor(i) for i in coord])
return coord
[docs]def check_duplicate_coords(coords=[], coord=[]):
"""Check if a coordinate exists."""
positive = False
for i in coords:
tmp = check_match(i, coord)
if tmp:
positive = True
return positive
[docs]def parse_xyz_string(xyz_string):
"""
Convert xyz info to translation and rotation vectors.
Adapted from pymatgen.
Args:
xyz_string: string of the form 'x, y, z', '-x, -y, z',
'-2y+1/2, 3x+1/2, z-y+1/2', etc.
Returns:
translation and rotation vectors.
"""
rot_matrix = np.zeros((3, 3))
trans = np.zeros(3)
toks = xyz_string.strip().replace(" ", "").lower().split(",")
re_rot = re.compile(r"([+-]?)([\d\.]*)/?([\d\.]*)([x-z])")
re_trans = re.compile(r"([+-]?)([\d\.]+)/?([\d\.]*)(?![x-z])")
for i, tok in enumerate(toks):
# build the rotation matrix
for m in re_rot.finditer(tok):
factor = -1 if m.group(1) == "-" else 1
if m.group(2) != "":
factor *= (
float(m.group(2)) / float(m.group(3))
if m.group(3) != ""
else float(m.group(2))
)
j = ord(m.group(4)) - 120
rot_matrix[i, j] = factor
# build the translation vector
for m in re_trans.finditer(tok):
factor = -1 if m.group(1) == "-" else 1
num = (
float(m.group(2)) / float(m.group(3))
if m.group(3) != ""
else float(m.group(2))
)
trans[i] = num * factor
affine_matrix = np.eye(4)
affine_matrix[0:3][:, 0:3] = rot_matrix
affine_matrix[0:3][:, 3] = trans
return affine_matrix
[docs]def operate_affine(cart_coord=[], affine_matrix=[]):
"""Operate affine method."""
affine_point = np.array([cart_coord[0], cart_coord[1], cart_coord[2], 1])
return np.dot(np.array(affine_matrix), affine_point)[0:3]
[docs]def gaussian(x, sigma):
"""Get Gaussian profile."""
return np.exp(-(x ** 2) / (2 * sigma ** 2))
[docs]def lorentzian2(x, gamma):
"""Get Lorentziann profile."""
return (
gamma
/ 2
/ (np.pi * (x ** 2 + (gamma / 2) ** 2))
/ (2 / (np.pi * gamma))
)
[docs]def digitize_array(values=[], max_len=10):
"""Digitze an array."""
has_float = False in [float(i).is_integer() for i in values]
if has_float:
arr = np.array([float(i) for i in values])
max_val = max(arr)
min_val = min(arr)
bins = np.arange(1, max_len + 1) * (max_val - min_val) / 10
values = np.digitize(arr, bins)
return values
[docs]def bond_angle(
dist1, dist2, bondx1, bondx2, bondy1, bondy2, bondz1, bondz2,
):
"""Get an angle."""
nm = dist1 * dist2
rrx = bondx1 * bondx2
rry = bondy1 * bondy2
rrz = bondz1 * bondz2
cos = (rrx + rry + rrz) / (nm)
if cos <= -1.0:
cos = cos + 0.000001
if cos >= 1.0:
cos = cos - 0.000001
deg = math.degrees(math.acos(cos))
return deg
[docs]def check_url_exists(
url="https://www.ctcms.nist.gov/~knc6/static/JARVIS-DFT/JVASP-77580.xml",
):
"""Check if a url exists."""
request = requests.get(url)
if request.status_code == 200:
return True
else:
return False
[docs]def volumetric_grid_reshape(data=[], final_grid=[50, 50, 50]):
"""Reshape volumetric data."""
import torch
data = torch.tensor(data).unsqueeze(0).unsqueeze(0)
new_data = (
torch.nn.functional.interpolate(
data,
size=final_grid,
scale_factor=None,
mode="trilinear",
align_corners=True,
recompute_scale_factor=None,
)
.squeeze()
.squeeze()
)
return new_data.numpy()
# def is_xml_valid(xsd="jarvisdft.xsd", xml="JVASP-1002.xml"):
# """Check if XML is valid."""
# xml_file = etree.parse(xml)
# xml_validator = etree.XMLSchema(file=xsd)
# is_valid = xml_validator.validate(xml_file)
# return is_valid