"""Set of useful utility functions."""

from collections import OrderedDict
from collections import defaultdict
import random
import numpy as np
import math
import xmltodict
import re
import requests

def xml_to_dict(fname):
    """Parse an XML file into a nested dictionary using ``xmltodict``.

    Args:
        fname: path to the XML file.

    Returns:
        The mapping produced by ``xmltodict.parse`` for the whole document.
    """
    # Binary mode: hand the raw bytes to the XML parser so the document's
    # own encoding declaration is used instead of the platform default
    # text encoding.
    with open(fname, "rb") as f:
        data = xmltodict.parse(f.read())
    return data
def get_factors(x, start=2):
    """Return the divisors of ``x`` that are >= ``start``, in ascending order.

    With the default ``start=2`` the trivial divisor 1 is left out; ``x``
    itself is included whenever ``x >= start``.
    """
    return [divisor for divisor in range(start, x + 1) if x % divisor == 0]
def get_counts(array=["W", "W", "Mo", "Mo", "S", "S"]):
    """
    Get number of unique elements and their counts.

    Keys appear in the order each element is first seen in ``array``.
    The default list is only read, never modified.

    Args:
        array: iterable of hashable elements.

    Returns:
        OrderedDict mapping element -> count,
        e.g. OrderedDict([('W', 2), ('Mo', 2), ('S', 2)]).
    """
    # Single pass: the first assignment to a key fixes its position, so the
    # first-appearance order is preserved without a separate list of uniques
    # (which made the old version quadratic in the number of distinct items).
    info = OrderedDict()
    for element in array:
        info[element] = info.get(element, 0) + 1
    return info
def gcd(a, b):
    """Return the Greatest Common Divisor of ``a`` and ``b`` (Euclid's algorithm).

    Unless ``b == 0``, the result carries the sign of ``b``, so dividing
    ``b`` by it gives a positive quotient. ``gcd(a, 0)`` returns ``a``.
    """
    x, y = a, b
    while y:
        remainder = x % y
        x = y
        y = remainder
    return x
def ext_gcd(a, b):
    """Extended Euclidean algorithm (GCD module from ase).

    Returns a pair ``(x, y)`` such that ``a * x + b * y`` equals the GCD of
    ``a`` and ``b``; ``(1, 0)`` when ``b == 0`` and ``(0, 1)`` when ``b``
    divides ``a``.
    """
    if b == 0:
        return 1, 0
    if a % b == 0:
        return 0, 1
    quotient = a // b
    x_sub, y_sub = ext_gcd(b, a % b)
    # Back-substitute the coefficients found for (b, a mod b).
    return y_sub, x_sub - y_sub * quotient
def rand_select(x=[]):
    """Select one random index for every distinct value in ``x``.

    Returns a dict mapping each distinct value (in first-appearance order)
    to one of the positions at which it occurs, picked with
    ``random.choice``. The default list is only read.
    """
    positions = defaultdict(list)
    for idx, value in enumerate(x):
        positions[value].append(idx)
    return {value: random.choice(idxs) for value, idxs in positions.items()}
def rec_dict():
    """Return an arbitrarily nestable ``defaultdict``.

    Every missing key materialises another ``rec_dict()``, so
    ``d["a"]["b"]["c"] = 1`` works without creating intermediate levels.
    """
    nested = defaultdict(rec_dict)
    return nested
def random_colors(number_of_colors=110):
    """Generate random colors for atom coloring.

    Returns a dict mapping ``0 .. number_of_colors - 1`` to ``#RRGGBB``
    hex strings built from six ``random.choice`` draws each.
    """
    hex_digits = "0123456789ABCDEF"
    return {
        index: "#" + "".join(random.choice(hex_digits) for _ in range(6))
        for index in range(number_of_colors)
    }
def get_angle(
    a=np.array([1, 2, 3]), b=np.array([4, 5, 6]), c=np.array([7, 8, 9])
):
    """Get the angle (in degrees) at vertex ``b`` of the points a-b-c.

    This is the angle between the vectors ``a - b`` and ``c - b``:
    theta = arccos((a-b).(c-b) / (|a-b| |c-b|)).

    Args:
        a, b, c: 3-component points; numpy arrays or plain sequences.

    Returns:
        Angle in degrees as a float, in [0, 180]. If either vector has zero
        length the cosine is NaN and NaN is returned.
    """
    ba = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    bc = np.asarray(c, dtype=float) - np.asarray(b, dtype=float)
    cos = np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc))
    # Rounding can push the cosine slightly outside [-1, 1]; clamp it back
    # into acos's domain. (Nudging by 1e-6 instead would turn exactly
    # collinear vectors into ~179.92 deg rather than 180.) np.clip keeps NaN.
    cos = float(np.clip(cos, -1.0, 1.0))
    return math.degrees(math.acos(cos))
def recast_array_on_uniq_array_elements(
    uniq=["Si", "Al", "O"],
    arr=["Si", "Si", "Al", "Al", "Si", "O", "O", "O", "O"],
):
    """Group the positions of ``arr`` by the values listed in ``uniq``.

    Returns a dict mapping each value of ``uniq`` that occurs in ``arr`` to
    the list of indices where it occurs. Keys follow ``uniq`` order; values
    absent from ``arr`` get no key. Default lists are only read.
    """
    info = {}
    for target in uniq:
        hits = [pos for pos, item in enumerate(arr) if target == item]
        if hits:
            info.setdefault(target, []).extend(hits)
    return info
def lorentzian(x, y, x0, gamma):
    """Lorentzian line shape centred at ``x0`` with FWHM ``gamma`` and area ``y``.

    L(x) = (y / pi) * (gamma/2) / ((x - x0)^2 + (gamma/2)^2)
    """
    half_width = 0.5 * gamma
    shape = half_width / ((x - x0) ** 2 + half_width ** 2)
    return (y / math.pi) * shape
def stringdict_to_xml(d={}, enforce_string=False):
    """Convert a flat dictionary to a string of ``<key>value</key>`` elements.

    Keys and values are converted with ``str`` and inserted verbatim (no XML
    escaping). With ``enforce_string=True`` each value is wrapped in single
    quotes. The default dict is only read.
    """
    pieces = []
    for key, value in d.items():
        tag = str(key)
        text = str(value)
        if enforce_string:
            text = "'" + text + "'"
        pieces.append("<" + tag + ">" + text + "</" + tag + ">")
    return "".join(pieces)
def array_to_string(arr=[]):
    """Convert a 1D sequence to a comma-separated string (no spaces)."""
    return ",".join(str(item) for item in arr)
def chunks(lst, n):
    """Split ``lst`` into successive slices of length ``n``.

    The last slice holds the remainder and may be shorter. Returns a list of
    slices of the same sequence type as ``lst``.
    """
    return [lst[start : start + n] for start in range(0, len(lst), n)]
def check_match(a, b, tol=1e-8):
    """Check if a and b are the same, taking into account PBCs.

    ``a`` and ``b`` are treated as fractional coordinates: on each of the
    first three axes the difference must lie within ``tol`` of an integer,
    i.e. the points coincide up to a lattice translation. For differences
    inside (-1, 1) this is the same as accepting ``|d| < tol`` or
    ``||d| - 1| < tol``; larger integer shifts (e.g. 0 vs 2) now match too.

    Args:
        a, b: indexable coordinates with at least three components.
        tol: absolute tolerance per axis.

    Returns:
        True if all three axes match, else False (also False for NaN).
    """
    for axis in range(3):
        # Fractional part of the difference, in [0, 1]; its distance to the
        # nearest integer is min(frac, 1 - frac).
        frac = (a[axis] - b[axis]) % 1.0
        # Written as "not ... < tol" so a NaN difference counts as a mismatch.
        if not min(frac, 1.0 - frac) < tol:
            return False
    return True
def update_dict(main={}, extra={}):
    """Return a copy of ``main`` with every entry of ``extra`` applied on top.

    Neither argument is modified; the result comes from ``main.copy()``, so
    it has the same mapping type as ``main``.
    """
    merged = main.copy()
    for key, value in extra.items():
        merged[key] = value  # later (extra) values win on key clashes
    return merged
def get_new_coord_for_xyz_sym(frac_coord=[], xyz_string=""):
    """Apply the symmetry operation ``xyz_string`` to ``frac_coord``.

    The string (e.g. ``'-y, x, z+1/2'``) is turned into an affine matrix by
    ``parse_xyz_string`` and applied with ``operate_affine``; every resulting
    component ``v`` is then replaced by ``v - floor(v)`` to wrap it back into
    the unit cell. Returns a numpy array.
    """
    affine_matrix = parse_xyz_string(xyz_string)
    transformed = np.asarray(operate_affine(frac_coord, affine_matrix))
    return transformed - np.floor(transformed)
def check_duplicate_coords(coords=[], coord=[]):
    """Return True if ``coord`` matches any entry of ``coords``.

    Matching is delegated to ``check_match`` (periodic comparison); an empty
    ``coords`` gives False.
    """
    return any(check_match(existing, coord) for existing in coords)
def parse_xyz_string(xyz_string):
    """
    Convert xyz info to translation and rotation vectors.

    Adapted from pymatgen.

    Args:
        xyz_string: string of the form 'x, y, z', '-x, -y, z',
            '-2y+1/2, 3x+1/2, z-y+1/2', etc. Case and spaces are ignored.

    Returns:
        4x4 affine matrix: the upper-left 3x3 block holds the rotation
        coefficients, the first three entries of the last column hold the
        translation, and the last row is [0, 0, 0, 1].
    """

    def _fraction(numerator, denominator):
        # "1" / "2" -> 0.5; an empty denominator means a plain number.
        if denominator != "":
            return float(numerator) / float(denominator)
        return float(numerator)

    rot_matrix = np.zeros((3, 3))
    trans = np.zeros(3)
    toks = xyz_string.strip().replace(" ", "").lower().split(",")
    # Groups: sign, optional numerator, optional denominator, axis letter.
    re_rot = re.compile(r"([+-]?)([\d\.]*)/?([\d\.]*)([x-z])")
    # Groups: sign, numerator, optional denominator; not followed by an axis.
    re_trans = re.compile(r"([+-]?)([\d\.]+)/?([\d\.]*)(?![x-z])")
    for i, tok in enumerate(toks):
        # build the rotation matrix
        for m in re_rot.finditer(tok):
            factor = -1 if m.group(1) == "-" else 1
            if m.group(2) != "":
                factor *= _fraction(m.group(2), m.group(3))
            j = ord(m.group(4)) - ord("x")  # x -> 0, y -> 1, z -> 2
            rot_matrix[i, j] = factor
        # build the translation vector
        for m in re_trans.finditer(tok):
            factor = -1 if m.group(1) == "-" else 1
            trans[i] = _fraction(m.group(2), m.group(3)) * factor
    affine_matrix = np.eye(4)
    affine_matrix[:3, :3] = rot_matrix
    affine_matrix[:3, 3] = trans
    return affine_matrix
def operate_affine(cart_coord=[], affine_matrix=[]):
    """Apply a 4x4 affine matrix to a 3-component coordinate.

    The coordinate is extended to homogeneous form ``[x, y, z, 1]``,
    multiplied by ``affine_matrix``, and the first three components of the
    product are returned as a numpy array.
    """
    affine_point = np.array([cart_coord[0], cart_coord[1], cart_coord[2], 1])
    return np.dot(affine_matrix, affine_point)[0:3]
def gaussian(x, sigma):
    """Get an unnormalised Gaussian profile exp(-x^2 / (2 sigma^2)).

    The peak value at ``x = 0`` is 1; works elementwise on numpy arrays.
    """
    exponent = -(x ** 2) / (2 * sigma ** 2)
    return np.exp(exponent)
def lorentzian2(x, gamma):
    """Get a Lorentzian profile centred at 0 with FWHM ``gamma``, peak 1.

    The area-normalised Lorentzian (gamma/2) / (pi (x^2 + (gamma/2)^2)) is
    divided by its own peak value 2 / (pi gamma).
    """
    half_width = gamma / 2
    normalised = half_width / (np.pi * (x ** 2 + half_width ** 2))
    peak_height = 2 / (np.pi * gamma)
    return normalised / peak_height
def digitize_array(values=[], max_len=10):
    """Digitize an array into ``max_len`` bins if it holds non-integer values.

    If every value is integral (``float(v).is_integer()``), ``values`` is
    returned unchanged. Otherwise the values are binned with ``np.digitize``
    against ``max_len`` equally spaced edges ``k * (max - min) / max_len`` for
    ``k = 1 .. max_len``.

    Args:
        values: sequence of numbers (only read).
        max_len: number of bin edges.

    Returns:
        ``values`` itself, or a numpy integer array of bin indices.
    """
    has_float = not all(float(i).is_integer() for i in values)
    if has_float:
        arr = np.array([float(i) for i in values])
        max_val = max(arr)
        min_val = min(arr)
        # Bin width uses max_len (previously hard-coded to 10, which only
        # agreed with the default).
        # NOTE(review): the edges start from 0, not from min_val, so when
        # min_val != 0 they do not span the data range. Confirm whether that
        # is intended before changing it -- it would alter existing encodings.
        bins = np.arange(1, max_len + 1) * (max_val - min_val) / max_len
        values = np.digitize(arr, bins)
    return values
def bond_angle(
    dist1,
    dist2,
    bondx1,
    bondx2,
    bondy1,
    bondy2,
    bondz1,
    bondz2,
):
    """Get the angle in degrees between two bond vectors.

    The cosine is the dot product of bond 1 ``(bondx1, bondy1, bondz1)`` and
    bond 2 ``(bondx2, bondy2, bondz2)`` divided by ``dist1 * dist2`` (the
    caller-supplied bond lengths).

    Returns:
        Angle in degrees, in [0, 180].
    """
    nm = dist1 * dist2
    rrx = bondx1 * bondx2
    rry = bondy1 * bondy2
    rrz = bondz1 * bondz2
    cos = (rrx + rry + rrz) / (nm)
    # Clamp rounding overshoot into acos's domain instead of nudging by 1e-6,
    # which turned exactly antiparallel/parallel bonds into ~179.92/0.08 deg.
    cos = float(np.clip(cos, -1.0, 1.0))
    deg = math.degrees(math.acos(cos))
    return deg
def check_url_exists(url="", timeout=30):
    """Check if a url exists.

    Sends a GET request and reports whether the server answered with
    HTTP status 200. Network failures (DNS errors, refused connections,
    timeouts) propagate as ``requests`` exceptions.

    Args:
        url: address to query.
        timeout: seconds to wait when connecting/reading before
            ``requests`` raises a timeout error. Pass ``None`` to wait
            indefinitely (the previous behavior).

    Returns:
        True if the status code is 200, else False.
    """
    # stream=True: only the status line is needed, so do not download the
    # body; the context manager releases the connection afterwards.
    with requests.get(url, timeout=timeout, stream=True) as response:
        return response.status_code == 200
def volumetric_grid_reshape(data=[], final_grid=[50, 50, 50]):
    """Resample a 3D volumetric grid to a new shape by trilinear interpolation.

    Args:
        data: 3D array-like (numpy array or nested lists) of grid values.
            Non-floating input (e.g. integers) is converted to float, since
            trilinear interpolation needs a floating dtype.
        final_grid: target (nx, ny, nz) shape; only read.

    Returns:
        numpy array of shape ``tuple(final_grid)``. Grid corners are kept
        aligned (``align_corners=True``).
    """
    import torch

    tensor = torch.as_tensor(data)
    if not tensor.is_floating_point():
        tensor = tensor.float()
    # interpolate's trilinear mode expects (batch, channel, D, H, W).
    tensor = tensor.unsqueeze(0).unsqueeze(0)
    resampled = torch.nn.functional.interpolate(
        tensor,
        size=tuple(final_grid),
        mode="trilinear",
        align_corners=True,
    )
    # Remove only the two axes added above; a bare squeeze() would also drop
    # any target axis of length 1.
    return resampled.squeeze(0).squeeze(0).numpy()
def cos_formula(a, b, c):
    """Get the angle opposite side ``c`` of a triangle, in radians.

    Uses the law of cosines, cos(C) = (a^2 + b^2 - c^2) / (2ab), with the
    cosine clamped to [-1, 1] before ``np.arccos``.
    """
    cos_c = (a ** 2 + b ** 2 - c ** 2) / (2 * a * b)
    cos_c = min(max(cos_c, -1.0), 1.0)
    return np.arccos(cos_c)
# def is_xml_valid(xsd="jarvisdft.xsd", xml="JVASP-1002.xml"): # """Check if XML is valid.""" # xml_file = etree.parse(xml) # xml_validator = etree.XMLSchema(file=xsd) # is_valid = xml_validator.validate(xml_file) # return is_valid