# -*- coding: utf-8 -*-
# pylint: disable=relative-beyond-top-level
"""Parsing functions for the oxidation state mining project"""
import concurrent.futures
import re
from collections import defaultdict
from typing import Dict, List, Tuple
import numpy as np
from ccdc import io # pylint: disable=import-error
from numeral import roman2int
from tqdm import tqdm
from .utils import SymbolNameDict
__all__ = ["GetOxStatesCSD"]
[docs]class GetOxStatesCSD: # pylint:disable=too-many-instance-attributes
"""Main parsing class"""
[docs] def __init__(self, cds_ids: List[str]) -> None:
"""Parses CSD structures for oxidation states
Args:
cds_ids (List[str]): list of CSD database identifiers
Returns:
None
"""
# Set up dictionaries and regex
self.symbol_name_dict = SymbolNameDict().get_symbol_name_dict()
self.name_symbol_dict = {v: k for k, v in self.symbol_name_dict.items()}
symbol_regex = "|".join(list(self.symbol_name_dict.values()))
self.symbol_regex = re.compile(symbol_regex)
self.regex = re.compile(
"((?:{})\\([iv0]+\\))".format(symbol_regex), re.IGNORECASE
)
self.not_ox_regex = re.compile(
"((?:{})[^\\(]*$)".format(symbol_regex), re.IGNORECASE
)
self.negative_regex = re.compile(
"((?:{})\\(-[1234567890]+\\))".format(symbol_regex), re.IGNORECASE
)
self.csd_ids = cds_ids
self.csd_reader = io.EntryReader("CSD")
def _get_symbol_ox_number(self, parsed_string: str) -> Tuple[str, int]:
"""Splits a parser hit into symbol and ox nuber and returns
latter as a integer
Args:
parsed_string (str): regex match of the form metalname(romanoxidationstate)
Returns:
str: symbol
int: oxidation number
"""
name, roman = parsed_string.strip(")").split("(")
if roman != "0":
return self.name_symbol_dict[name.lower()], roman2int(roman)
return self.name_symbol_dict[name.lower()], int(0)
def _get_symbol_negative_ox_number(self, parsed_string: str) -> Tuple[str, int]:
"""Returns a tuple(symbol, int) for negative oxidation numbers"""
name, roman = parsed_string.strip(")").split("(")
return self.name_symbol_dict[name.lower()], int(roman)
def _get_symbol_nan(self, parsed_string: str) -> Tuple[str, int]:
"""Returns a tuple(symbol, np.nan)"""
name = self.symbol_regex.findall(parsed_string)[0]
return self.name_symbol_dict[name.lower()], np.nan
[docs] def parse_name(self, chemical_name_string: str) -> dict:
"""Takes the chemical name string from the CSD database and returns,
if it finds it, a dictionary with the oxidation states for the metals
Args:
chemical_name_string (str): full chemical name
Returns:
dict: dictionary of symbol: oxidation states (list)
"""
oxidation_state_dict = defaultdict(list)
matches = re.findall(self.regex, chemical_name_string)
for match in matches:
symbol, oxidation_int = self._get_symbol_ox_number(match)
oxidation_state_dict[symbol].append(oxidation_int)
no_ox_matches = re.findall(self.not_ox_regex, chemical_name_string)
for match in no_ox_matches:
symbol, oxidation_int = self._get_symbol_nan(match)
oxidation_state_dict[symbol].append(oxidation_int)
matches = re.findall(self.negative_regex, chemical_name_string)
for match in matches:
symbol, oxidation_int = self._get_symbol_negative_ox_number(match)
oxidation_state_dict[symbol].append(oxidation_int)
return dict(oxidation_state_dict)
[docs] def parse_csd_entry(self, database_id: str) -> dict:
"""Looks up a CSD id and runs the parsing
Args:
database_id (str): CSD database identifier
Returns:
dict: symbol - oxidation state dictionary
Exception:
returns empy dict upon exception
(if it cannot find the structure in the database)
"""
try:
entry_object = self.csd_reader.entry(
database_id
) # pylint:disable=no-member
name = entry_object.chemical_name
return self.parse_name(name)
except Exception: # pylint: disable=broad-except
return {}
[docs] def run_parsing(self, njobs: int = 4) -> Dict[str, dict]:
"""Runs (concurrent) parsing over the list of database identifiers.
Args:
njobs (int): maximum number of parallel workers
Returns:
Dict[str, dict]: nested dictionary
with {'id': {'symbol': [oxidation states]}}
"""
results_dict = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=njobs) as executor:
for database_id, result in tqdm(
zip(self.csd_ids, executor.map(self.parse_csd_entry, self.csd_ids)),
total=len(self.csd_ids),
):
results_dict[database_id] = result
return results_dict