diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 773b310..1f20498 100755
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -16,14 +16,14 @@ jobs:
max-parallel: 1
matrix:
- python-version: ['3.10']
+ python-version: ['3.11']
name: Python ${{ matrix.python-version }} Test Pop
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
- - uses: actions/setup-python@v2
+ - uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
@@ -31,19 +31,20 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install setuptools setuptools_scm wheel
- pip install pydocstyle black flake8 isort
+ pip install pydocstyle black flake8 ruff
pip install -e .[tests]
- - name: sort imports
+ - name: ruff
run: |
- isort . --atomic --profile black
+ ruff --version
+ ruff --fix shakenbreak
- name: check docstrings
run: |
pydocstyle --version
pydocstyle -e --count --convention=google --add-ignore=D400,D415,D212,D205,D417,D107 shakenbreak
- - name: black
+ - name: black # max line length 107 specified in pyproject.toml
run: |
black --version
black --color shakenbreak
diff --git a/.github/workflows/pip_install_test.yml b/.github/workflows/pip_install_test.yml
index d3607a8..795e827 100755
--- a/.github/workflows/pip_install_test.yml
+++ b/.github/workflows/pip_install_test.yml
@@ -17,15 +17,20 @@ jobs:
fail-fast: false
matrix:
+ os: [ ubuntu-latest, macos-14 ]
python-version: [ '3.9', '3.10', '3.11' ]
- os: [ ubuntu-latest,macos-latest ]
+ exclude:
+ - os: macos-14
+ python-version: '3.9' # Exclude Python 3.9 on macOS, not supported for macOS-14 tests
+ # macos-latest should be 14 according to link below, but currently doesn't?
+ # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories
runs-on: ${{matrix.os}}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
- - uses: actions/setup-python@v2
+ - uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
@@ -36,20 +41,27 @@ jobs:
pip install shakenbreak[tests] # install only from PyPI
- name: Test
- run: |
- pytest --mpl -vv tests # test everything
-
+ run: |
+ pytest -vv -m "not mpl_image_compare" tests # all non-plotting tests
+
+ - name: Plotting Tests
+ if: always() # run even if non-plotting tests fail
+ id: plotting_tests # Add an ID to this step for reference
+ run: |
+ pytest --mpl -m "mpl_image_compare" tests/test_plotting.py # plotting tests
+ pytest --mpl -m "mpl_image_compare" tests/test_shakenbreak.py # plotting tests
+
- name: Generate GH Actions test plots
- if: always() # always generate the plots, even if the tests fail
- run: |
+ if: failure() && steps.plotting_tests.outcome == 'failure' # Run only if plotting tests fail
+ run: |
# Generate the test plots in case there were any failures:
- pytest --mpl-generate-path=tests/remote_baseline tests/test_plotting.py
- pytest --mpl-generate-path=tests/remote_baseline tests/test_shakenbreak.py
+ pytest --mpl-generate-path=tests/remote_baseline -m "mpl_image_compare" tests/test_plotting.py
+ pytest --mpl-generate-path=tests/remote_baseline -m "mpl_image_compare" tests/test_shakenbreak.py
# Upload test plots
- name: Archive test plots
if: always() # always upload the plots, even if the tests fail
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: output-plots
path: tests/remote_baseline
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index a54a01b..25d6217 100755
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -18,9 +18,9 @@ jobs:
# only run when tests have passed (or manually triggered)
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
- - uses: actions/setup-python@v3
+ - uses: actions/setup-python@v5
with:
python-version: "3.10"
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 74b5de9..f8c64d4 100755
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,14 +1,9 @@
name: Tests
on:
- pull_request:
- branches:
- - main
- - develop
push:
branches:
- - main
- - develop
+ - '*' # all branches
workflow_dispatch:
@@ -18,14 +13,19 @@ jobs:
fail-fast: false
matrix:
- python-version: ['3.9', '3.10', '3.11']
- os: [ubuntu-latest,macos-latest]
+ os: [ ubuntu-latest, macos-14 ]
+ python-version: [ '3.9', '3.10', '3.11' ]
+ exclude:
+ - os: macos-14
+ python-version: '3.9' # Exclude Python 3.9 on macOS, not supported for macOS-14 tests
+ # macos-latest should be 14 according to link below, but currently doesn't?
+ # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories
runs-on: ${{matrix.os}}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
- - uses: actions/setup-python@v2
+ - uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
@@ -39,27 +39,30 @@ jobs:
pip show -V pymatgen-analysis-defects
pip show -V pymatgen
pip show -V pytest
+ pip show -V doped
- name: Test
run: |
- pytest --mpl -vv tests # test everything
+ pytest -vv -m "not mpl_image_compare" tests # all non-plotting tests
+
+ - name: Plotting Tests
+ if: always() # run even if non-plotting tests fail
+ id: plotting_tests # Add an ID to this step for reference
+ run: |
+ pytest --mpl -m "mpl_image_compare" tests/test_plotting.py # plotting tests
+ pytest --mpl -m "mpl_image_compare" tests/test_shakenbreak.py # plotting tests
- name: Generate GH Actions test plots
- if: always() # always generate the plots, even if the tests fail
- run: |
+ if: failure() && steps.plotting_tests.outcome == 'failure' # Run only if plotting tests fail
+ run: |
# Generate the test plots in case there were any failures:
- pytest --mpl-generate-path=tests/remote_baseline tests/test_plotting.py
- pytest --mpl-generate-path=tests/remote_baseline tests/test_shakenbreak.py
+ pytest --mpl-generate-path=tests/remote_baseline -m "mpl_image_compare" tests/test_plotting.py
+ pytest --mpl-generate-path=tests/remote_baseline -m "mpl_image_compare" tests/test_shakenbreak.py
# Upload test plots
- name: Archive test plots
if: always() # always upload the plots, even if the tests fail
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: output-plots
path: tests/remote_baseline
-
- # - name: Download a single artifact
- # uses: actions/download-artifact@v3
- # with:
- # name: output-plots
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6999080..19af917 100755
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,7 +2,12 @@
# See https://pre-commit.com/hooks.html for more hooks
exclude: ^(docs|tests|SnB_input_files|.github|shakenbreak/scripts|CITATION*|MANIFEST*)
repos:
-
+ # Lint and format, isort, docstrings...
+ - repo: https://github.com/charliermarsh/ruff-pre-commit
+ rev: v0.3.7
+ hooks:
+ - id: ruff
+ args: [--fix]
# Remove trailing whitespace, leave empty line at end of file
- repo: https://github.com/pre-commit/pre-commit-hooks
@@ -12,18 +17,6 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
-# - repo: https://github.com/pre-commit/mirrors-mypy
-# rev: v0.971
-# hooks:
-# - id: mypy
-
-# Sort/format imports
- - repo: https://github.com/PyCQA/isort
- rev: 5.12.0
- hooks:
- - id: isort
- args: [--profile, black]
-
# Check docstrings
- repo: https://github.com/pycqa/pydocstyle
rev: 4.0.0 # pick a git hash / tag to point to
@@ -32,8 +25,8 @@ repos:
args: [-e, --count, "--convention=google", "--add-ignore=D107,D202,D400,D415,D212,D205,D417,D413"]
# Code formatting
- - repo: https://github.com/psf/black
- rev: 22.6.0
+ - repo: https://github.com/psf/black # max line length 107 specified in pyproject.toml
+ rev: 24.2.0
hooks:
- id: black
args: [--color]
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index f5a3ce5..bd036a4 100755
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -6,9 +6,9 @@
version: 2
build:
- os: ubuntu-20.04
+ os: ubuntu-22.04
tools:
- python: "3.10"
+ python: "3.11"
# Build from the docs/ directory with Sphinx
sphinx:
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 1bf4e96..50890f4 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,21 @@
Change Log
==========
+v3.3.3
+----------
+- Add ``verbose`` option to more parsing/plotting functions for better control of output detail.
+- Improve effiency & robustness of oxidation state handling.
+- Miscellaneous efficiency (e.g. memory reduction) and robustness updates.
+- Improved GitHub Actions test efficiency.
+
+v3.3.2
+----------
+- Add ``verbose`` options to ``io.parse_energies()`` and ``snb-parse``, also used in ``snb-plot`` and
+ ``snb-analyse``, and set to ``False`` by default to reduce verbosity of certain SnB CLI commands.
+- Use ``doped`` functions to make oxi-state guessing (and thus defect initialisation) more efficient.
+- Miscellaneous efficiency and robustness updates.
+- Testing updates.
+
v3.3.1
----------
- ``distortion_metadata.json`` for each defect now saved to the individual defect folders (as well as the
diff --git a/README.md b/README.md
index 1ebc830..2e32f64 100644
--- a/README.md
+++ b/README.md
@@ -128,20 +128,22 @@ Automatic testing is run on the master and develop branches using Github Actions
## Studies using `ShakeNBreak`
+- B. E. Murdock et al. **_Li-Site Defects Induce Formation of Li-Rich Impurity Phases: Implications for Charge Distribution and Performance of LiNi0.5-xMxMn1.5O4 Cathodes (M = Fe and Mg; x = 0.05–0.2)_** [_Advanced Materials_](https://doi.org/10.1002/adma.202400343) 2024
+- A. G. Squires et al. **_Oxygen dimerization as a defect-driven process in bulk LiNiO22_** [_ChemRxiv_](https://doi.org/10.26434/chemrxiv-2024-lcmkj) 2024
- X. Wang et al. **_Upper efficiency limit of Sb2Se3 solar cells_** [_arXiv_](https://arxiv.org/abs/2402.04434) 2024
- I. Mosquera-Lois et al. **_Machine-learning structural reconstructions for accelerated point defect calculations_** [_arXiv_](https://doi.org/10.48550/arXiv.2401.12127) 2024
-- K. Li et al. **_Computational Prediction of an Antimony-based n-type Transparent Conducting Oxide: F-doped Sb2O5_** [_ChemRxiv_](https://chemrxiv.org/engage/chemrxiv/article-details/65846b8366c1381729bc5f23) 2023
+- K. Li et al. **_Computational Prediction of an Antimony-based n-type Transparent Conducting Oxide: F-doped Sb2O5_** [_Chemistry of Materials_](https://doi.org/10.1021/acs.chemmater.3c03257) 2024
- X. Wang et al. **_Four-electron negative-U vacancy defects in antimony selenide_** [_Physical Review B_](https://journals.aps.org/prb/abstract/10.1103/PhysRevB.108.134102) 2023
- Y. Kumagai et al. **_Alkali Mono-Pnictides: A New Class of Photovoltaic Materials by Element Mutation_** [_PRX Energy_](http://dx.doi.org/10.1103/PRXEnergy.2.043002) 2023
- A. T. J. Nicolson et al. **_Cu2SiSe3 as a promising solar absorber: harnessing cation dissimilarity to avoid killer antisites_** [_Journal of Materials Chemistry A_](https://doi.org/10.1039/D3TA02429F) 2023
- J. Willis, K. B. Spooner, D. O. Scanlon **_On the possibility of p-type doping in barium stannate_** [_Applied Physics Letters_](https://doi.org/10.1063/5.0170552) 2023
-
- J. Cen et al. **_Cation disorder dominates the defect chemistry of high-voltage LiMn1.5Ni0.5O4 (LMNO) spinel cathodes_** [_Journal of Materials Chemistry A_](https://doi.org/10.1039/D3TA00532A) 2023
-- J. Willis & R. Claes et al. **_Limits to Hole Mobility and Doping in Copper Iodide_** [_Chem Mater_](https://doi.org/10.1021/acs.chemmater.3c01628) 2023
+- J. Willis & R. Claes et al. **_Limits to Hole Mobility and Doping in Copper Iodide_** [_Chemistry of Materials_](https://doi.org/10.1021/acs.chemmater.3c01628) 2023
- I. Mosquera-Lois & S. R. Kavanagh, A. Walsh, D. O. Scanlon **_Identifying the ground state structures of point defects in solids_** [_npj Computational Materials_](https://www.nature.com/articles/s41524-023-00973-1) 2023
+- B. Peng et al. **_Advancing understanding of structural, electronic, and magnetic properties in 3d-transition-metal TM-doped α-Ga₂O₃ (TM = V, Cr, Mn, and Fe)_** [_Journal of Applied Physics_](https://doi.org/10.1063/5.0173544) 2023
- Y. T. Huang & S. R. Kavanagh et al. **_Strong absorption and ultrafast localisation in NaBiS2 nanocrystals with slow charge-carrier recombination_** [_Nature Communications_](https://www.nature.com/articles/s41467-022-32669-3) 2022
- S. R. Kavanagh, D. O. Scanlon, A. Walsh, C. Freysoldt **_Impact of metastable defect structures on carrier recombination in solar cells_** [_Faraday Discussions_](https://doi.org/10.1039/D2FD00043A) 2022
-- Y-S. Choi et al. **_Intrinsic Defects and Their Role in the Phase Transition of Na-Ion Anode Na2Ti3O7_** [_ACS Appl. Energy Mater._](https://doi.org/10.1021/acsaem.2c03466) 2022 (Early version)
+- Y-S. Choi et al. **_Intrinsic Defects and Their Role in the Phase Transition of Na-Ion Anode Na2Ti3O7_** [_ACS Applied Energy Materials_](https://doi.org/10.1021/acsaem.2c03466) 2022 (Early version)
- S. R. Kavanagh, D. O. Scanlon, A. Walsh **_Rapid Recombination by Cadmium Vacancies in CdTe_** [_ACS Energy Letters_](https://pubs.acs.org/doi/full/10.1021/acsenergylett.1c00380) 2021
- C. J. Krajewska et al. **_Enhanced visible light absorption in layered Cs3Bi2Br9 through mixed-valence Sn(II)/Sn(IV) doping_** [_Chemical Science_](https://doi.org/10.1039/D1SC03775G) 2021 (Early version)
- (News & Views): A. Mannodi-Kanakkithodi **_The devil is in the defects_** [_Nature Physics_](https://doi.org/10.1038/s41567-023-02049-9) 2023 ([Free-to-read link](https://t.co/EetpnRgjzh))
diff --git a/docs/Contributing.rst b/docs/Contributing.rst
index 6c6dca1..1c367bb 100644
--- a/docs/Contributing.rst
+++ b/docs/Contributing.rst
@@ -27,12 +27,12 @@ workflow to do so and follow the `PEP8 `_ sty
.. NOTE::
Alternatively, if you prefer not to use ``pre-commit hooks``, you can manually run the following in the **correct sequence**
- on your local machine. From the ``shakenbreak`` top directory, run `isort `_ to sort and format your imports, followed by
+ on your local machine. From the ``shakenbreak`` top directory, run `ruff --fix `_ to lint and format, followed by
`black `_, which will automatically reformat the code to ``PEP8`` conventions:
.. code:: bash
- $ isort . --profile black
+ $ ruff --fix
$ black --diff --color shakenbreak
Then run `pycodestyle `_ to check the docstrings,
diff --git a/docs/conf.py b/docs/conf.py
index c6c2ed0..80d6662 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -25,7 +25,7 @@
author = 'Irea Mosquera-Lois, Seán R. Kavanagh'
# The full version, including alpha/beta/rc tags
-release = '3.3.1'
+release = '3.3.3'
# -- General configuration ---------------------------------------------------
diff --git a/docs/index.rst b/docs/index.rst
index 78d9149..5677d87 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -211,19 +211,22 @@ run tests and add new tests for any new features whenever submitting pull reques
Studies using ``ShakeNBreak``
=============================
+- B\. E. Murdock et al. **Li-Site Defects Induce Formation of Li-Rich Impurity Phases: Implications for Charge Distribution and Performance of LiNi** :sub:`0.5-x` **M** :sub:`x` **Mn** :sub:`1.5` **O** :sub:`4` **Cathodes (M = Fe and Mg; x = 0.05–0.2)** `Advanced Materials `_ 2024
+- A\. G. Squires et al. **Oxygen dimerization as a defect-driven process in bulk LiNiO₂** `ChemRxiv `_ 2024
- X\. Wang et al. **Upper efficiency limit of Sb₂Se₃ solar cells** `arXiv `_ 2024
- I\. Mosquera-Lois et al. **Machine-learning structural reconstructions for accelerated point defect calculations** `arXiv `_ 2024
-- K\. Li et al. **Computational Prediction of an Antimony-based n-type Transparent Conducting Oxide: F-doped Sb₂O₅** `ChemRxiv `_ 2024
+- K\. Li et al. **Computational Prediction of an Antimony-based n-type Transparent Conducting Oxide: F-doped Sb₂O₅** `Chemistry of Materials `_ 2024
- X\. Wang et al. **Four-electron negative-U vacancy defects in antimony selenide** `Physical Review B `_ 2023
- Y\. Kumagai et al. **Alkali Mono-Pnictides: A New Class of Photovoltaic Materials by Element Mutation** `PRX Energy `__ 2023
- J\. Willis, K. B. Spooner, D. O. Scanlon. **On the possibility of p-type doping in barium stannate** `Applied Physics Letters `__ 2023
- A\. T. J. Nicolson et al. **Cu₂SiSe₃ as a promising solar absorber: harnessing cation dissimilarity to avoid killer antisites** `Journal of Materials Chemistry A `__ 2023
-- J\. Cen et al. **Cation disorder dominates the defect chemistry of high-voltage LiMn** :sub:`1.5`**Ni** :sub:`0.5`**O₄ (LMNO) spinel cathodes** `Journal of Materials Chemistry A`_ 2023
-- J\. Willis & R. Claes et al. **Limits to Hole Mobility and Doping in Copper Iodide** `Chem Mater `__ 2023
+- J\. Cen et al. **Cation disorder dominates the defect chemistry of high-voltage LiMn** :sub:`1.5` **Ni** :sub:`0.5` **O₄ (LMNO) spinel cathodes** `Journal of Materials Chemistry A`_ 2023
+- J\. Willis & R. Claes et al. **Limits to Hole Mobility and Doping in Copper Iodide** `Chemistry of Materials `__ 2023
- I\. Mosquera-Lois & S. R. Kavanagh, A. Walsh, D. O. Scanlon **Identifying the ground state structures of point defects in solids** `npj Computational Materials`_ 2023
+- B\. Peng et al. **Advancing understanding of structural, electronic, and magnetic properties in 3d-transition-metal TM-doped α-Ga₂O₃ (TM = V, Cr, Mn, and Fe)** `Journal of Applied Physics `__ 2023
- Y\. T. Huang & S. R. Kavanagh et al. **Strong absorption and ultrafast localisation in NaBiS₂ nanocrystals with slow charge-carrier recombination** `Nature Communications`_ 2022
- S\. R. Kavanagh, D. O. Scanlon, A. Walsh, C. Freysoldt **Impact of metastable defect structures on carrier recombination in solar cells** `Faraday Discussions`_ 2022
-- Y-S\. Choi et al. **Intrinsic Defects and Their Role in the Phase Transition of Na-Ion Anode Na₂Ti₃O₇** `ACS Appl. Energy Mater. `__ 2022 (Early version)
+- Y-S\. Choi et al. **Intrinsic Defects and Their Role in the Phase Transition of Na-Ion Anode Na₂Ti₃O₇** `ACS Applied Energy Materials `__ 2022 (Early version)
- S\. R. Kavanagh, D. O. Scanlon, A. Walsh **Rapid Recombination by Cadmium Vacancies in CdTe** `ACS Energy Letters `__ 2021
- C\. J. Krajewska et al. **Enhanced visible light absorption in layered Cs₃Bi₂Br₉ through mixed-valence Sn(II)/Sn(IV) doping** `Chemical Science`_ 2021 (Early version)
- (News & Views): A. Mannodi-Kanakkithodi **The devil is in the defects** `Nature Physics`_ 2023 (`Free-to-read link `__)
diff --git a/pyproject.toml b/pyproject.toml
index 374b58c..0846def 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,3 +4,53 @@ requires = [
"wheel"
]
build-backend = "setuptools.build_meta"
+
+[tool.black]
+line-length = 107
+
+[tool.ruff]
+line-length = 107
+lint.pydocstyle.convention = "google"
+lint.isort.split-on-trailing-comma = false
+lint.select = [ # from pymatgen
+ "B", # flake8-bugbear
+ "C4", # flake8-comprehensions
+ "D", # pydocstyle
+ "E", # pycodestyle error
+ "EXE", # flake8-executable
+ "F", # pyflakes
+ "FLY", # flynt
+ "I", # isort
+ "ICN", # flake8-import-conventions
+ "ISC", # flake8-implicit-str-concat
+ "PD", # pandas-vet
+ "PIE", # flake8-pie
+ "PL", # pylint
+ "PT", # flake8-pytest-style
+ "PYI", # flakes8-pyi
+ "Q", # flake8-quotes
+ "RET", # flake8-return
+ "RSE", # flake8-raise
+ "RUF", # Ruff-specific rules
+ "SIM", # flake8-simplify
+ "TCH", # flake8-type-checking
+ "TID", # tidy imports
+ "TID", # flake8-tidy-imports
+ "UP", # pyupgrade
+ "W", # pycodestyle warning
+ "YTT", # flake8-2020
+]
+lint.ignore = [
+ "B028", # No explicit stacklevel keyword argument found
+ "D101", # Missing docstring in public class (docstring in init instead)
+ "D200", # One-line docstring should fit on one line with quotes
+ "D205", # 1 blank line required between summary line and description
+ "D212", # Multi-line docstring summary should start at the first line
+ "PLR2004", # Magic number
+ "PLR", # pylint refactor
+ "W605", # Invalid escape sequence
+ "PT011", # too broad pytest.raises()
+]
+
+[tool.ruff.lint.per-file-ignores]
+"tests/*" = ["D102"]
diff --git a/setup.py b/setup.py
index 725866b..336fa1d 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,4 @@
-"""This is a setup.py script to install ShakeNBreak"""
+"""This is a setup.py script to install ShakeNBreak."""
import os
import warnings
@@ -32,9 +32,7 @@ def _install_custom_font():
# Copy the font file to matplotlib's True Type font directory
fonts_dir = f"{path_to_file}/shakenbreak/"
- ttf_fonts = [
- file_name for file_name in os.listdir(fonts_dir) if ".ttf" in file_name
- ]
+ ttf_fonts = [file_name for file_name in os.listdir(fonts_dir) if ".ttf" in file_name]
try:
for font in ttf_fonts: # must be in ttf format for matplotlib
old_path = os.path.join(fonts_dir, font)
@@ -104,7 +102,7 @@ def run(self):
class CustomEggInfoCommand(egg_info):
- """Post-installation"""
+ """Post-installation."""
def run(self):
"""
@@ -120,20 +118,20 @@ def run(self):
def package_files(directory):
"""Include package data."""
paths = []
- for path, directories, filenames in os.walk(directory):
+ for path, _dir, filenames in os.walk(directory):
paths.extend(os.path.join("..", path, filename) for filename in filenames)
return paths
input_files = package_files("SnB_input_files/")
-with open("README.md", "r", encoding="utf-8") as file:
+with open("README.md", encoding="utf-8") as file:
long_description = file.read()
setup(
name="shakenbreak",
- version="3.3.1",
+ version="3.3.3",
description="Package to generate and analyse distorted defect structures, in order to "
"identify ground-state and metastable defect configurations.",
long_description=long_description,
@@ -194,7 +192,7 @@ def package_files(directory):
},
# Specify any non-python files to be distributed with the package
package_data={
- "shakenbreak": ["shakenbreak/*"] + input_files,
+ "shakenbreak": ["shakenbreak/*", *input_files],
},
include_package_data=True,
# Specify the custom installation class
diff --git a/shakenbreak/SnB_run.sh b/shakenbreak/SnB_run.sh
index 0f6ad1c..64eb165 100755
--- a/shakenbreak/SnB_run.sh
+++ b/shakenbreak/SnB_run.sh
@@ -100,6 +100,7 @@ SnB_run_loop() {
echo "Positive energies or forces error encountered for ${i%/}. "
echo "This typically indicates the initial defect structure supplied to ShakeNBreak is highly unstable, often with bond lengths smaller than the ionic radii."
echo "Please check this defect structure and/or the relaxation output files."
+ builtin cd .. || return
continue
else
echo "Positive energies or forces error encountered for ${i%/}, ignoring and renaming to ${i%/}_High_Energy"
diff --git a/shakenbreak/analysis.py b/shakenbreak/analysis.py
index 0a4c3c1..5233964 100644
--- a/shakenbreak/analysis.py
+++ b/shakenbreak/analysis.py
@@ -1,11 +1,10 @@
"""
Module containing functions to analyse rattled and bond-distorted defect
-structure relaxations
+structure relaxations.
"""
import json
import os
-import sys
import warnings
from copy import deepcopy
from typing import Optional, Union
@@ -22,9 +21,7 @@
from shakenbreak import input, io
-crystalNN = CrystalNN(
- distance_cutoffs=None, x_diff_weight=0.0, porous_adjustment=False, search_cutoff=5.0
-)
+crystalNN = CrystalNN(distance_cutoffs=None, x_diff_weight=0.0, porous_adjustment=False, search_cutoff=5.0)
def _warning_on_one_line(
@@ -47,7 +44,7 @@ def _isipython():
# Using stackoverflow.com/questions/15411967/
# how-can-i-check-if-code-is-executed-in-the-ipython-notebook
try:
- get_ipython().__class__.__name__
+ _ = get_ipython().__class__.__name__
return True
except NameError:
return False # Probably standard Python interpreter
@@ -57,23 +54,10 @@ def _isipython():
from IPython.display import display
-class _HiddenPrints:
- """Block calls to print."""
-
- # https://stackoverflow.com/questions/8391411/how-to-block-calls-to-print
- def __enter__(self):
- self._original_stdout = sys.stdout
- sys.stdout = open(os.devnull, "w")
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- sys.stdout.close()
- sys.stdout = self._original_stdout
-
-
# Helper functions
def _read_distortion_metadata(output_path: str) -> dict:
"""
- Parse distortion_metadata.json file
+ Parse ``distortion_metadata.json`` file.
Args:
output_path (:obj:`str`):
@@ -88,17 +72,15 @@ def _read_distortion_metadata(output_path: str) -> dict:
try: # Read distortion parameters from distortion_metadata.json
with open(f"{output_path}/distortion_metadata.json") as json_file:
distortion_metadata = json.load(json_file)
- except FileNotFoundError:
- raise FileNotFoundError(
- f"No `distortion_metadata.json` file found in {output_path}."
- )
+ except FileNotFoundError as exc:
+ raise FileNotFoundError(f"No `distortion_metadata.json` file found in {output_path}.") from exc
return distortion_metadata
def _get_distortion_filename(distortion) -> str:
"""
- Format distortion names for file naming. (e.g. from 0.5 to
- 'Bond_Distortion_50.0%')
+ Format distortion names for file naming (e.g. from 0.5 to
+ 'Bond_Distortion_50.0%').
Args:
distortion (float or str):
@@ -116,9 +98,7 @@ def _get_distortion_filename(distortion) -> str:
else:
distortion_label = f"Bond_Distortion_{distortion:.1f}%"
elif isinstance(distortion, str):
- if "_from_" in distortion and (
- "Rattled" not in distortion and "Dimer" not in distortion
- ):
+ if "_from_" in distortion and ("Rattled" not in distortion and "Dimer" not in distortion):
distortion_label = f"Bond_Distortion_{distortion}"
# runs from other charge states
elif (
@@ -132,11 +112,7 @@ def _get_distortion_filename(distortion) -> str:
]
):
distortion_label = distortion
- elif (
- distortion == "Unperturbed"
- or distortion == "Rattled"
- or distortion == "Dimer"
- ):
+ elif distortion == "Unperturbed" or distortion == "Rattled" or distortion == "Dimer":
distortion_label = distortion # e.g. "Unperturbed"/"Rattled"/"Dimer"
else:
try: # try converting to float, in case user entered '0.5'
@@ -154,7 +130,7 @@ def _format_distortion_names(
) -> str:
"""
Formats the distortion filename to the names used internally and for
- analysis. (i.e. 'Bond_Distortion_-50.0%' -> -0.5)
+ analysis (i.e. 'Bond_Distortion_-50.0%' -> -0.5).
Args:
distortion_label (:obj:`str`):
@@ -167,38 +143,24 @@ def _format_distortion_names(
"""
distortion_label = distortion_label.strip() # remove any whitespace
if (
- "Unperturbed" in distortion_label
- or "Rattled" in distortion_label
- or "Dimer" in distortion_label
+ "Unperturbed" in distortion_label or "Rattled" in distortion_label or "Dimer" in distortion_label
) and "from" not in distortion_label:
return distortion_label
- elif distortion_label.startswith("Bond_Distortion") and distortion_label.endswith(
- "%"
- ):
+ if distortion_label.startswith("Bond_Distortion") and distortion_label.endswith("%"):
return float(distortion_label.split("Bond_Distortion_")[-1].split("%")[0]) / 100
# From other charge states
- elif distortion_label.startswith("Bond_Distortion") and (
- "_from_" in distortion_label
- ):
+ if distortion_label.startswith("Bond_Distortion") and ("_from_" in distortion_label):
# distortions from other charge state of the defect
return distortion_label.split("Bond_Distortion_")[-1]
- elif "Rattled" in distortion_label and "_from_" in distortion_label:
- return distortion_label
- elif "Dimer" in distortion_label and "_from_" in distortion_label:
- return distortion_label
- # detected as High_Energy - normally wouldn't be parsed, but for debugging purposes
- # TODO: remove this
- elif (
- distortion_label.startswith("Bond_Distortion")
- and "High_Energy" in distortion_label
+ if (
+ "Rattled" in distortion_label
+ and "_from_" in distortion_label
+ or "Dimer" in distortion_label
+ and "_from_" in distortion_label
):
- return float(distortion_label.split("Bond_Distortion_")[-1].split("%")[0]) / 100
- elif (
- "Dimer" in distortion_label or "Rattled" in distortion_label
- ) and "High_Energy" in distortion_label:
return distortion_label
- else:
- return "Label_not_recognized"
+
+ return "Label_not_recognized"
def get_gs_distortion(defect_energies_dict: dict) -> tuple:
@@ -218,20 +180,16 @@ def get_gs_distortion(defect_energies_dict: dict) -> tuple:
:obj:`tuple`:
(Energy difference, ground state bond distortion)
"""
- if not defect_energies_dict["distortions"]:
- if "Unperturbed" in defect_energies_dict:
- return 0, "Unperturbed"
+ if not defect_energies_dict["distortions"] and "Unperturbed" in defect_energies_dict:
+ return 0, "Unperturbed"
lowest_E_distortion = min(
defect_energies_dict["distortions"].values()
) # lowest energy obtained with bond distortions
if "Unperturbed" in defect_energies_dict:
- if list(defect_energies_dict["distortions"].keys()) == [
- "Rattled"
- ]: # If only Rattled
+ if list(defect_energies_dict["distortions"].keys()) == ["Rattled"]: # If only Rattled
energy_diff = (
- defect_energies_dict["distortions"]["Rattled"]
- - defect_energies_dict["Unperturbed"]
+ defect_energies_dict["distortions"]["Rattled"] - defect_energies_dict["Unperturbed"]
)
gs_distortion = "Rattled" if energy_diff < 0 else "Unperturbed"
else:
@@ -240,26 +198,20 @@ def get_gs_distortion(defect_energies_dict: dict) -> tuple:
lowest_E_distortion < defect_energies_dict["Unperturbed"]
): # if energy lower than Unperturbed
gs_distortion = list(defect_energies_dict["distortions"].keys())[
- list(defect_energies_dict["distortions"].values()).index(
- lowest_E_distortion
- )
+ list(defect_energies_dict["distortions"].values()).index(lowest_E_distortion)
] # bond distortion that led to ground-state
else:
gs_distortion = "Unperturbed"
else:
energy_diff = None
gs_distortion = list(defect_energies_dict["distortions"].keys())[
- list(defect_energies_dict["distortions"].values()).index(
- lowest_E_distortion
- )
+ list(defect_energies_dict["distortions"].values()).index(lowest_E_distortion)
]
return energy_diff, gs_distortion
-def _sort_data(
- energies_file: str, verbose: bool = True, min_e_diff: float = 0.05
-) -> tuple:
+def _sort_data(energies_file: str, verbose: bool = True, min_e_diff: float = 0.05) -> tuple:
"""
Organize bond distortion results in a dictionary, calculate energy
of ground-state defect structure relative to `Unperturbed` structure
@@ -302,10 +254,7 @@ def _sort_data(
if defect_energies_dict == {"distortions": {}}: # no parsed data
warnings.warn(f"No data parsed from {energies_file}, returning None")
return None, None, None
- if (
- len(defect_energies_dict["distortions"]) == 0
- and "Unperturbed" in defect_energies_dict
- ):
+ if len(defect_energies_dict["distortions"]) == 0 and "Unperturbed" in defect_energies_dict:
# no parsed distortion results but Unperturbed present
warnings.warn(f"No distortion results parsed from {energies_file}")
@@ -361,12 +310,8 @@ def analyse_defect_site(
if site_num:
isite = site_num - 1 # python/pymatgen indexing (starts counting from zero!)
elif vac_site:
- struct.append(
- "V", vac_site
- ) # Have to add a fake element for coordination analysis
- isite = (
- len(struct.sites) - 1
- ) # python/pymatgen indexing (starts counting from zero!)
+ struct.append("V", vac_site) # Have to add a fake element for coordination analysis
+ isite = len(struct.sites) - 1 # python/pymatgen indexing (starts counting from zero!)
else:
raise ValueError("Either site_num or vac_site must be specified")
@@ -379,10 +324,7 @@ def analyse_defect_site(
for coord, value in coordination.items():
coordination_dict = {"Coordination": coord, "Factor": round(value, 2)}
coord_list.append(coordination_dict)
- print(
- "Local order parameters (i.e. resemblance to given structural motif, "
- "via CrystalNN):"
- )
+ print("Local order parameters (i.e. resemblance to given structural motif, via CrystalNN):")
if _isipython():
display(pd.DataFrame(coord_list)) # display in Jupyter notebook
bond_lengths = [
@@ -399,8 +341,8 @@ def analyse_defect_site(
print() # spacing
if coordination is not None:
return pd.DataFrame(coord_list), bond_length_df
- else:
- return None, bond_length_df
+
+ return None, bond_length_df
def analyse_structure(
@@ -444,12 +386,8 @@ def analyse_structure(
"defect_site_index"
) # VASP indexing (starts counting from 1)
if defect_site is None: # for vacancies, get fractional coordinates
- defect_frac_coords = distortion_metadata["defects"][defect_name_without_charge][
- "unique_site"
- ]
- return analyse_defect_site(
- structure, name=defect_species, vac_site=defect_frac_coords
- )
+ defect_frac_coords = distortion_metadata["defects"][defect_name_without_charge]["unique_site"]
+ return analyse_defect_site(structure, name=defect_species, vac_site=defect_frac_coords)
return analyse_defect_site(structure, name=defect_species, site_num=defect_site)
@@ -492,27 +430,17 @@ def get_structures(
Dictionary of bond distortions and corresponding final structures.
"""
defect_structures_dict = {}
- if (
- not bond_distortions
- ): # if the user didn't specify any set of distortions, loop over subdirectories
- if not os.path.isdir(
- f"{output_path}/{defect_species}"
- ): # check if defect folder exists
- raise FileNotFoundError(
- f"Path f'{output_path}/{defect_species}' does not exist!"
- )
+ if not bond_distortions: # if the user didn't specify any set of distortions, loop over subdirectories
+ if not os.path.isdir(f"{output_path}/{defect_species}"): # check if defect folder exists
+ raise FileNotFoundError(f"Path f'{output_path}/{defect_species}' does not exist!")
distortion_subdirectories = [
i
for i in next(os.walk(f"{output_path}/{defect_species}"))[1]
- if ("Bond_Distortion" in i)
- or ("Unperturbed" in i)
- or ("Rattled" in i)
- or ("Dimer" in i)
+ if ("Bond_Distortion" in i) or ("Unperturbed" in i) or ("Rattled" in i) or ("Dimer" in i)
] # distortion subdirectories
if not distortion_subdirectories:
raise FileNotFoundError(
- f"No distortion subdirectories found in {output_path}/"
- f"{defect_species}"
+ f"No distortion subdirectories found in {output_path}/{defect_species}"
)
for distortion_subdirectory in distortion_subdirectories:
if "High_Energy" not in distortion_subdirectory:
@@ -520,9 +448,7 @@ def get_structures(
distortion_label=distortion_subdirectory
) # From subdirectory name, get the distortion label used for analysis
# e.g. from 'Bond_Distortion_-10.0% -> -0.1
- if (
- distortion != "Label_not_recognized"
- ): # If the subdirectory name is recognised
+ if distortion != "Label_not_recognized": # If the subdirectory name is recognised
try:
defect_structures_dict[distortion] = io.parse_structure(
code=code,
@@ -538,9 +464,7 @@ def get_structures(
defect_structures_dict[distortion] = "Not converged"
else:
if "Unperturbed" not in bond_distortions:
- bond_distortions.append(
- "Unperturbed"
- ) # always include unperturbed structure
+ bond_distortions.append("Unperturbed") # always include unperturbed structure
for distortion in bond_distortions:
if not (isinstance(distortion, str) and "High_Energy" in distortion):
distortion_label = _get_distortion_filename(distortion) # get filename
@@ -560,9 +484,7 @@ def get_structures(
)
defect_structures_dict[distortion] = "Not converged"
- return (
- defect_structures_dict # now contains the distortions from other charge states
- )
+ return defect_structures_dict # now contains the distortions from other charge states
def get_energies(
@@ -600,14 +522,10 @@ def get_energies(
energy_file_path = f"{output_path}/{defect_species}/{defect_species}.yaml"
if not os.path.isfile(energy_file_path):
raise FileNotFoundError(f"File {energy_file_path} not found!")
- defect_energies_dict, _e_diff, gs_distortion = _sort_data(
- energy_file_path, verbose=verbose
- )
+ defect_energies_dict, _e_diff, gs_distortion = _sort_data(energy_file_path, verbose=verbose)
if "Unperturbed" in defect_energies_dict:
for distortion, energy in defect_energies_dict["distortions"].items():
- defect_energies_dict["distortions"][distortion] = (
- energy - defect_energies_dict["Unperturbed"]
- )
+ defect_energies_dict["distortions"][distortion] = energy - defect_energies_dict["Unperturbed"]
defect_energies_dict["Unperturbed"] = 0.0
else:
warnings.warn(
@@ -616,9 +534,7 @@ def get_energies(
)
lowest_E_distortion = defect_energies_dict["distortions"][gs_distortion]
for distortion, energy in defect_energies_dict["distortions"].items():
- defect_energies_dict["distortions"][distortion] = (
- energy - lowest_E_distortion
- )
+ defect_energies_dict["distortions"][distortion] = energy - lowest_E_distortion
if units == "meV":
defect_energies_dict["distortions"] = {
k: v * 1000 for k, v in defect_energies_dict["distortions"].items()
@@ -651,20 +567,24 @@ def _calculate_atomic_disp(
output contains too many 'NaN' values, this likely needs to
be increased.
(Default: 0.5)
+ ltol (:obj:`float`):
+ Length tolerance used for structural comparison (via
+ `pymatgen`'s `StructureMatcher`).
+ (Default: 0.3)
+ angle_tol (:obj:`float`):
+ Angle tolerance used for structural comparison (via
+ `pymatgen`'s `StructureMatcher`).
+ (Default: 5)
Returns:
:obj:`tuple`:
Tuple of normalized root mean squared displacements and
normalized displacements between the two structures.
"""
- sm = StructureMatcher(
- ltol=ltol, stol=stol, angle_tol=angle_tol, primitive_cell=False, scale=True
- )
+ sm = StructureMatcher(ltol=ltol, stol=stol, angle_tol=angle_tol, primitive_cell=False, scale=True)
struct1, struct2 = sm._process_species([struct1, struct2])
struct1, struct2, fu, s1_supercell = sm._preprocess(struct1, struct2)
- match = sm._match(
- struct1, struct2, fu, s1_supercell, use_rms=True, break_on_match=False
- )
+ match = sm._match(struct1, struct2, fu, s1_supercell, use_rms=True, break_on_match=False)
return None if match is None else (match[0], match[1])
@@ -712,6 +632,14 @@ def calculate_struct_comparison(
output contains too many 'NaN' values, this likely needs to
be increased.
(Default: 0.5)
+ ltol (:obj:`float`):
+ Length tolerance used for structural comparison (via
+ `pymatgen`'s `StructureMatcher`).
+ (Default: 0.3)
+ angle_tol (:obj:`float`):
+ Angle tolerance used for structural comparison (via
+ `pymatgen`'s `StructureMatcher`).
+ (Default: 5)
min_dist (:obj:`float`):
Minimum atomic displacement threshold to include in atomic
displacements sum (in Å, default 0.1 Å).
@@ -734,8 +662,7 @@ def calculate_struct_comparison(
ref_structure = defect_structures_dict[ref_structure]
except KeyError as e:
raise KeyError(
- f"Reference structure key '{ref_structure}' not found in "
- f"defect_structures_dict."
+ f"Reference structure key '{ref_structure}' not found in defect_structures_dict."
) from e
if ref_structure == "Not converged":
raise ValueError(
@@ -769,23 +696,15 @@ def calculate_struct_comparison(
)
if metric == "disp":
disp_dict[distortion] = (
- np.sum(norm_dist[norm_dist > min_dist * normalization])
- / normalization
+ np.sum(norm_dist[norm_dist > min_dist * normalization]) / normalization
) # Only include displacements above min_dist threshold,
# and remove normalization
elif metric == "max_dist":
- disp_dict[distortion] = (
- max(norm_dist) / normalization
- ) # Remove normalization
+ disp_dict[distortion] = max(norm_dist) / normalization # Remove normalization
else:
- raise ValueError(
- f"Invalid metric '{metric}'. Must be one of 'disp' or "
- f"'max_dist'."
- )
+ raise ValueError(f"Invalid metric '{metric}'. Must be one of 'disp' or 'max_dist'.")
except TypeError:
- disp_dict[
- distortion
- ] = None # algorithm couldn't match lattices. Set comparison
+ disp_dict[distortion] = None # algorithm couldn't match lattices. Set comparison
# metric to None
# warnings.warn(
# f"pymatgen StructureMatcher could not match lattices between "
@@ -851,16 +770,8 @@ def compare_structures(
normalised atomic displacement and maximum distance between
matched atomic sites), and relative energies.
"""
- if all(
- [
- structure == "Not converged"
- for key, structure in defect_structures_dict.items()
- ]
- ):
- warnings.warn(
- "All structures in defect_structures_dict are not converged. "
- "Returning None."
- )
+ if all(structure == "Not converged" for key, structure in defect_structures_dict.items()):
+ warnings.warn("All structures in defect_structures_dict are not converged. Returning None.")
return None
df_list = []
disp_dict = calculate_struct_comparison(
@@ -871,37 +782,36 @@ def compare_structures(
min_dist=min_dist,
verbose=verbose,
)
- with _HiddenPrints(): # only print "Comparing to..." once
+ max_dist_dict = calculate_struct_comparison(
+ defect_structures_dict,
+ metric="max_dist",
+ ref_structure=ref_structure,
+ stol=stol,
+ verbose=False, # only print "Comparing to..." once
+ )
+ # Check if too many 'NaN' values in disp_dict, if so, try with higher stol
+ number_of_nan = len([value for value in disp_dict.values() if value is None])
+ if number_of_nan > len(disp_dict.values()) // 3:
+ warnings.warn(
+ f"The specified tolerance {stol} seems to be too tight as"
+ " too many lattices could not be matched. Will retry with"
+ f" larger tolerance ({stol+0.4})."
+ )
max_dist_dict = calculate_struct_comparison(
defect_structures_dict,
metric="max_dist",
ref_structure=ref_structure,
- stol=stol,
- verbose=verbose,
+ stol=stol + 0.4,
+ verbose=False,
+ )
+ disp_dict = calculate_struct_comparison(
+ defect_structures_dict,
+ metric="disp",
+ ref_structure=ref_structure,
+ stol=stol + 0.4,
+ min_dist=min_dist,
+ verbose=False,
)
- # Check if too many 'NaN' values in disp_dict, if so, try with higher stol
- number_of_nan = len([value for value in disp_dict.values() if value is None])
- if number_of_nan > len(disp_dict.values()) // 3:
- warnings.warn(
- f"The specified tolerance {stol} seems to be too tight as"
- " too many lattices could not be matched. Will retry with"
- f" larger tolerance ({stol+0.4})."
- )
- max_dist_dict = calculate_struct_comparison(
- defect_structures_dict,
- metric="max_dist",
- ref_structure=ref_structure,
- stol=stol + 0.4,
- verbose=verbose,
- )
- disp_dict = calculate_struct_comparison(
- defect_structures_dict,
- metric="disp",
- ref_structure=ref_structure,
- stol=stol + 0.4,
- min_dist=min_dist,
- verbose=verbose,
- )
for distortion in defect_energies_dict["distortions"]:
try:
@@ -913,12 +823,16 @@ def compare_structures(
df_list.append(
[
distortion,
- round(disp_dict[distortion], 3) + 0
- if isinstance(disp_dict[distortion], float)
- else None,
- round(max_dist_dict[distortion], 3) + 0
- if isinstance(max_dist_dict[distortion], float)
- else None,
+ (
+ round(disp_dict[distortion], 3) + 0
+ if isinstance(disp_dict[distortion], float)
+ else None
+ ),
+ (
+ round(max_dist_dict[distortion], 3) + 0
+ if isinstance(max_dist_dict[distortion], float)
+ else None
+ ),
round(rel_energy, 2) + 0,
]
)
@@ -934,12 +848,16 @@ def compare_structures(
df_list.append(
[
"Unperturbed",
- round(disp_dict["Unperturbed"], 3) + 0
- if isinstance(disp_dict["Unperturbed"], float)
- else None,
- round(max_dist_dict["Unperturbed"], 3) + 0
- if isinstance(max_dist_dict["Unperturbed"], float)
- else None,
+ (
+ round(disp_dict["Unperturbed"], 3) + 0
+ if isinstance(disp_dict["Unperturbed"], float)
+ else None
+ ),
+ (
+ round(max_dist_dict["Unperturbed"], 3) + 0
+ if isinstance(max_dist_dict["Unperturbed"], float)
+ else None
+ ),
round(defect_energies_dict["Unperturbed"], 2) + 0,
]
)
@@ -1036,7 +954,7 @@ def get_homoionic_bonds(
for neighbour in neighbours
]
if f"{site.species_string}({site_index})" not in [
- list(element.keys())[0] for element in homoionic_bonds.values()
+ next(iter(element.keys())) for element in homoionic_bonds.values()
]: # avoid duplicates
homoionic_neighbours = {
f"{neighbour[0]}({neighbour[1]})": f"{neighbour[2]} A"
@@ -1048,13 +966,10 @@ def get_homoionic_bonds(
homoionic_neighbours
)
else:
- homoionic_bonds[
- f"{site.species_string}({site_index})"
- ] = homoionic_neighbours
+ homoionic_bonds[f"{site.species_string}({site_index})"] = homoionic_neighbours
if verbose:
print(
- f"{site.species_string}({site_index}): "
- f"{homoionic_neighbours}",
+ f"{site.species_string}({site_index}): {homoionic_neighbours}",
"\n",
)
if not homoionic_bonds and verbose:
@@ -1098,19 +1013,13 @@ def _site_magnetizations(
mag_array = np.array(list(element.values()))
total_mag = np.sum(mag_array[np.abs(mag_array) > 0.01])
if np.abs(total_mag) > threshold:
- significant_magnetizations[
- f"{structure[index].species_string}({index})"
- ] = {
+ significant_magnetizations[f"{structure[index].species_string}({index})"] = {
"Site": f"{structure[index].species_string}({index})",
- "Frac coords": [
- round(coord, 3) for coord in structure[index].frac_coords
- ],
+ "Frac coords": [round(coord, 3) for coord in structure[index].frac_coords],
"Site mag": round(total_mag, 3),
}
if isinstance(defect_site, int):
- significant_magnetizations[
- f"{structure[index].species_string}({index})"
- ].update(
+ significant_magnetizations[f"{structure[index].species_string}({index})"].update(
{
"Dist. (\u212B)": round(
structure.get_distance(i=defect_site, j=index),
@@ -1119,15 +1028,12 @@ def _site_magnetizations(
}
)
if orbital_projections:
- significant_magnetizations[
- f"{structure[index].species_string}({index})"
- ].update(
+ significant_magnetizations[f"{structure[index].species_string}({index})"].update(
{k: round(v, 3) for k, v in element.items() if k != "tot"}
# include site magnetization of each orbital
# but dont include total site magnetization again
)
- df = pd.DataFrame.from_dict(significant_magnetizations, orient="index")
- return df
+ return pd.DataFrame.from_dict(significant_magnetizations, orient="index")
def get_site_magnetizations(
@@ -1182,25 +1088,21 @@ def get_site_magnetizations(
defect_site_coords = None
if isinstance(defect_site, (list, np.ndarray)):
defect_site_coords = defect_site
- elif not defect_site: # look for defect site, in order to include the distance
- # between sites with significant magnetization and the defect
- if os.path.exists(f"{output_path}/distortion_metadata.json"):
- with open(f"{output_path}/distortion_metadata.json", "r") as f:
- try:
- defect_species_without_charge = "_".join(
- defect_species.split("_")[:-1]
- )
- defect_site_coords = json.load(f)["defects"][
- defect_species_without_charge
- ]["unique_site"]
- except KeyError:
- warnings.warn(
- f"Could not find defect {defect_species} in "
- f"distortion_metadata.json file. Will not include "
- f"distance between defect and sites with significant "
- f"magnetization."
- )
- defect_site = None
+ elif not defect_site and os.path.exists(f"{output_path}/distortion_metadata.json"):
+ # look for defect site, in order to include the distance between sites with significant
+ # magnetization and the defect
+ with open(f"{output_path}/distortion_metadata.json") as f:
+ try:
+ defect_species_without_charge = "_".join(defect_species.split("_")[:-1])
+ defect_site_coords = json.load(f)["defects"][defect_species_without_charge]["unique_site"]
+ except KeyError:
+ warnings.warn(
+ f"Could not find defect {defect_species} in "
+ f"distortion_metadata.json file. Will not include "
+ f"distance between defect and sites with significant "
+ f"magnetization."
+ )
+ defect_site = None
for distortion in distortions:
dist_label = _get_distortion_filename(distortion) # get filename
@@ -1213,9 +1115,7 @@ def get_site_magnetizations(
"Unperturbed/Rattled name."
)
continue
- structure = io.read_vasp_structure(
- f"{output_path}/{defect_species}/{dist_label}/CONTCAR"
- )
+ structure = io.read_vasp_structure(f"{output_path}/{defect_species}/{dist_label}/CONTCAR")
if not isinstance(structure, Structure):
warnings.warn(
f"Structure for {defect_species} either not converged or not "
@@ -1224,9 +1124,7 @@ def get_site_magnetizations(
continue
if isinstance(defect_site_coords, (list, np.ndarray)):
# for vacancies, append fake atom
- structure.append(
- species="V", coords=defect_site_coords, coords_are_cartesian=False
- )
+ structure.append(species="V", coords=defect_site_coords, coords_are_cartesian=False)
defect_site = -1 # index of the added fake atom
try:
@@ -1245,22 +1143,16 @@ def get_site_magnetizations(
)
continue
if verbose:
- print(
- f"Analysing distortion {distortion}. "
- f"Total magnetization: {round(outcar.total_mag, 2)}"
- )
- df = _site_magnetizations(
+ print(f"Analysing distortion {distortion}. Total magnetization: {round(outcar.total_mag, 2)}")
+ mag_df = _site_magnetizations(
outcar=outcar,
structure=structure,
threshold=threshold,
defect_site=defect_site,
orbital_projections=orbital_projections,
)
- if not df.empty:
- magnetizations[distortion] = df
+ if not mag_df.empty:
+ magnetizations[distortion] = mag_df
elif verbose:
- print(
- f"No significant magnetizations found for distortion: "
- f"{distortion} \n"
- )
+ print(f"No significant magnetizations found for distortion: {distortion} \n")
return magnetizations
diff --git a/shakenbreak/cli.py b/shakenbreak/cli.py
index 89cba85..13baaaf 100644
--- a/shakenbreak/cli.py
+++ b/shakenbreak/cli.py
@@ -1,4 +1,4 @@
-"""ShakeNBreak command-line-interface (CLI)"""
+"""ShakeNBreak command-line-interface (CLI)."""
import contextlib
import fnmatch
@@ -9,6 +9,7 @@
from subprocess import call
import click
+from doped.core import _guess_and_set_oxi_states_with_timeout, _rough_oxi_state_cost_from_comp
from doped.generation import get_defect_name_from_entry
from doped.utils.parsing import get_outcar
from doped.utils.plotting import format_defect_name
@@ -51,25 +52,24 @@ def invoke(self, ctx):
config_file = ctx.params[config_file_param_name]
if config_file is not None:
config_data = loadfn(config_file)
- for param, value in ctx.params.items():
+ for param, _val in ctx.params.items():
if (
- ctx.get_parameter_source(param)
- == click.core.ParameterSource.DEFAULT
+ ctx.get_parameter_source(param) == click.core.ParameterSource.DEFAULT
and param in config_data
):
ctx.params[param] = config_data[param]
- return super(CustomCommandClass, self).invoke(ctx)
+ return super().invoke(ctx)
return CustomCommandClass
# CLI Commands:
-CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])
+CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]}
@click.group("snb", context_settings=CONTEXT_SETTINGS, no_args_is_help=True)
def snb():
- """ShakeNBreak: Defect structure-searching"""
+ """ShakeNBreak: Defect structure-searching."""
@snb.command(
@@ -111,7 +111,7 @@ def snb():
"--padding",
"-p",
help="If `--charge` or `--min-charge` & `--max-charge` are not set, "
- "defect charges will be set to the range: 0 – {Defect oxidation state}, "
+ "defect charges will be set to the range: 0 - {Defect oxidation state}, "
"with a `--padding` on either side of this range.",
default=1,
type=int,
@@ -143,8 +143,7 @@ def snb():
@click.option(
"--name",
"-n",
- help="Defect name for folder and metadata generation. Defaults to "
- "doped scheme (see tutorials).",
+ help="Defect name for folder and metadata generation. Defaults to doped scheme (see tutorials).",
default=None,
type=str,
)
@@ -271,9 +270,7 @@ def generate(
elif max_charge is not None or min_charge is not None:
if max_charge is None or min_charge is None:
- raise ValueError(
- "If using min/max defect charge, both options must be set!"
- )
+ raise ValueError("If using min/max defect charge, both options must be set!")
charge_lims = [min_charge, max_charge]
charges = list(
@@ -286,7 +283,7 @@ def generate(
if defect_object.user_charges:
warnings.warn(
"Defect charges were specified using the CLI option, but `charges` "
- "was also specified in the `--config` file – this will be ignored!"
+ "was also specified in the `--config` file -- this will be ignored!"
)
else:
defect_object.user_charges = charges # Update charge states
@@ -304,7 +301,7 @@ def generate(
# determined
if not defect_object.user_charges:
print(
- "Defect charge states will be set to the range: 0 – {Defect oxidation state}, "
+ "Defect charge states will be set to the range: 0 - {Defect oxidation state}, "
f"with a `padding = {padding}` on either side of this range."
)
Dist = input.Distortions(
@@ -390,8 +387,7 @@ def generate(
@click.option(
"--defects",
"-d",
- help="Path root directory with defect folders/files. "
- "Defaults to current directory ('./')",
+ help="Path root directory with defect folders/files. Defaults to current directory ('./')",
type=click.Path(exists=True, dir_okay=True),
default=".",
)
@@ -415,7 +411,7 @@ def generate(
"--padding",
"-p",
help="For any defects where `charge` is not set in the --config file, "
- "charges will be set to the range: 0 – {Defect oxidation state}, "
+ "charges will be set to the range: 0 - {Defect oxidation state}, "
"with a `--padding` on either side of this range.",
default=1,
type=int,
@@ -469,6 +465,22 @@ def generate_all(
for all defects in a given directory.
"""
bulk_struc = Structure.from_file(bulk)
+ # try parsing the bulk oxidation states first, for later assigning defect "oxi_state"s (i.e.
+ # fully ionised charge states):
+ # First check if the cost of guessing oxidation states is too high:
+ if _rough_oxi_state_cost_from_comp(bulk_struc.composition) > 1e6:
+ # If the cost is too high, avoid setting oxidation states as it will take too long
+ _bulk_oxi_states = False # will take very long to guess oxi_state
+ else:
+ # Otherwise, proceed with setting oxidation states using a separate process to allow timeouts
+ from multiprocessing import Queue # only import when necessary
+
+ queue = Queue()
+ _bulk_oxi_states = _guess_and_set_oxi_states_with_timeout(bulk_struc, queue=queue)
+ if _bulk_oxi_states: # Retrieve the oxidation states if successfully guessed and set
+ bulk_struc = queue.get() # oxi-state decorated structure
+ _bulk_oxi_states = {el.symbol: el.oxi_state for el in bulk_struc.composition.elements}
+
defects_dirs = os.listdir(defects)
if config is not None:
# In the config file, user can specify index/frac_coords and charges for each defect
@@ -528,7 +540,7 @@ def generate_all(
user_settings.pop(key)
def parse_defect_name(defect, defect_settings, structure_file="POSCAR"):
- """Parse defect name from file name"""
+ """Parse defect name from file name."""
defect_name = None
# if user included cif/POSCAR as part of the defect structure name, remove it
for substring in ("cif", "POSCAR", structure_file):
@@ -551,14 +563,10 @@ def parse_defect_name(defect, defect_settings, structure_file="POSCAR"):
# if user didn't specify defect names in config file,
# check if defect filename is recognised
try:
- defect_name = format_defect_name(
- defect, include_site_info_in_name=False
- )
+ defect_name = format_defect_name(defect, include_site_info_in_name=False)
except Exception:
with contextlib.suppress(Exception):
- defect_name = format_defect_name(
- f"{defect}_0", include_site_info_in_name=False
- )
+ defect_name = format_defect_name(f"{defect}_0", include_site_info_in_name=False)
if defect_name:
defect_name = defect
@@ -566,23 +574,21 @@ def parse_defect_name(defect, defect_settings, structure_file="POSCAR"):
def parse_defect_charges(defect_name, defect_settings):
charges = None
- if isinstance(defect_settings, dict):
- if defect_name in defect_settings:
- charges = defect_settings.get(defect_name).get("charges", None)
- if charges is None:
- charges = [
- defect_settings.get(defect_name).get("charge", None),
- ]
+ if isinstance(defect_settings, dict) and defect_name in defect_settings:
+ charges = defect_settings.get(defect_name).get("charges", None)
+ if charges is None:
+ charges = [
+ defect_settings.get(defect_name).get("charge", None),
+ ]
return charges # determing using padding if not set in config file
def parse_defect_position(defect_name, defect_settings):
- if defect_settings:
- if defect_name in defect_settings:
- defect_index = defect_settings.get(defect_name).get("defect_index")
- if defect_index:
- return int(defect_index), None
- defect_coords = defect_settings.get(defect_name).get("defect_coords")
- return None, defect_coords
+ if defect_settings and defect_name in defect_settings:
+ defect_index = defect_settings.get(defect_name).get("defect_index")
+ if defect_index:
+ return int(defect_index), None
+ defect_coords = defect_settings.get(defect_name).get("defect_coords")
+ return None, defect_coords
return None, None
defect_entries = []
@@ -590,9 +596,7 @@ def parse_defect_position(defect_name, defect_settings):
if os.path.isfile(f"{defects}/{defect}"):
try: # try to parse structure from it
defect_struc = Structure.from_file(f"{defects}/{defect}")
- defect_name = parse_defect_name(
- defect, defect_settings
- ) # None if not recognised
+ defect_name = parse_defect_name(defect, defect_settings) # None if not recognised
except Exception:
continue
@@ -620,23 +624,22 @@ def parse_defect_position(defect_name, defect_settings):
)
continue
if defect_file:
- defect_struc = Structure.from_file(
- os.path.join(defects, defect, defect_file)
- )
+ defect_struc = Structure.from_file(os.path.join(defects, defect, defect_file))
defect_name = parse_defect_name(defect, defect_settings)
else:
warnings.warn(f"Could not parse {defects}/{defect} as a defect, skipping.")
continue
# Check if indices are provided in config file
- defect_index, defect_coords = parse_defect_position(
- defect_name, defect_settings
- )
+ defect_index, defect_coords = parse_defect_position(defect_name, defect_settings)
defect_object = input.identify_defect(
defect_structure=defect_struc,
bulk_structure=bulk_struc,
defect_index=defect_index,
defect_coords=defect_coords,
+ oxi_state=(
+ None if _bulk_oxi_states else "Undetermined"
+ ), # guess if bulk_oxi, else "Undetermined"
)
if verbose:
site = defect_object.site
@@ -650,27 +653,20 @@ def parse_defect_position(defect_name, defect_settings):
)
# Update charges if specified in config file
- charges = parse_defect_charges(
- defect_name or defect_object.name, defect_settings
- )
+ charges = parse_defect_charges(defect_name or defect_object.name, defect_settings)
defect_object.user_charges = charges
# Add defect entry to full defects_dict
- # If charges were not specified by use, set them using padding
+ # If charges were not specified by user, set them using padding
for charge in defect_object.get_charge_states(padding=padding):
- defect_entries.append(
- input._get_defect_entry_from_defect(defect_object, charge)
- )
+ defect_entries.append(input._get_defect_entry_from_defect(defect_object, charge))
defects_dict = input._get_defects_dict_from_defects_entries(defect_entries)
# if user_charges not set for all defects, print info about how charge states will be
# determined
- if all(
- not defect_entry_list[0].defect.user_charges
- for defect_entry_list in defects_dict.values()
- ):
+ if all(not defect_entry_list[0].defect.user_charges for defect_entry_list in defects_dict.values()):
print(
- "Defect charge states will be set to the range: 0 – {Defect oxidation state}, "
+ "Defect charge states will be set to the range: 0 - {Defect oxidation state}, "
f"with a `padding = {padding}` on either side of this range."
)
# Apply distortions and write input files
@@ -795,6 +791,7 @@ def run(submit_command, job_script, job_name_option, all, verbose):
Loop through distortion subfolders for a defect, when run within a defect folder, or for all
defect folders in the current (top-level) directory if the --all (-a) flag is set, and submit
jobs to the HPC scheduler.
+
As well as submitting the initial geometry optimisations, can automatically continue and
resubmit calculations that have not yet converged (and handle those which have failed),
see: https://shakenbreak.readthedocs.io/en/latest/Generation.html#submitting-the-geometry-optimisations
@@ -857,16 +854,24 @@ def run(submit_command, job_script, job_name_option, all, verbose):
default="vasp",
show_default=True,
)
-def parse(defect, all, path, code):
+@click.option(
+ "--verbose",
+ "-v",
+ help="Print information about renamed/saved-over files.",
+ default=False,
+ is_flag=True,
+ show_default=True,
+)
+def parse(defect, all, path, code, verbose):
"""
Parse final energies of defect structures from relaxation output files.
Parsed energies are written to a `yaml` file in the corresponding defect directory.
"""
if defect:
- _ = io.parse_energies(defect, path, code)
+ _ = io.parse_energies(defect, path, code, verbose=verbose)
elif all:
defect_dirs = _parse_defect_dirs(path)
- _ = [io.parse_energies(defect, path, code) for defect in defect_dirs]
+ _ = [io.parse_energies(defect, path, code, verbose=verbose) for defect in defect_dirs]
else:
# assume current directory is the defect folder
try:
@@ -878,14 +883,14 @@ def parse(defect, all, path, code):
cwd = os.getcwd()
defect = cwd.split("/")[-1]
path = cwd.rsplit("/", 1)[0]
- _ = io.parse_energies(defect, path, code)
- except Exception:
+ _ = io.parse_energies(defect, path, code, verbose=verbose)
+ except Exception as exc:
raise Exception(
f"Could not parse defect '{defect}' in directory '{path}'. Please either specify "
f"a defect to parse (with option --defect), run from within a single defect "
f"directory (without setting --defect) or use the --all flag to parse all "
f"defects in the specified/current directory."
- )
+ ) from exc
@snb.command(
@@ -939,7 +944,7 @@ def parse(defect, all, path, code):
@click.option(
"--verbose",
"-v",
- help="Print information about identified energy lowering distortions.",
+ help="Print information about identified energy lowering distortions and renamed/saved-over files.",
default=False,
is_flag=True,
show_default=True,
@@ -956,11 +961,9 @@ def analyse_single_defect(defect, path, code, ref_struct, verbose):
defect = defect.replace("+", "") # try with old name format
if not os.path.exists(f"{path}/{defect}") or not os.path.exists(path):
- raise FileNotFoundError(
- f"Could not find {orig_defect_name} in the directory {path}."
- )
+ raise FileNotFoundError(f"Could not find {orig_defect_name} in the directory {path}.")
- _ = io.parse_energies(defect, path, code)
+ _ = io.parse_energies(defect, path, code, verbose=verbose)
defect_energies_dict = analysis.get_energies(
defect_species=defect, output_path=path, verbose=verbose
)
@@ -996,9 +999,7 @@ def analyse_single_defect(defect, path, code, ref_struct, verbose):
# Check if defect present in path:
if path == ".":
path = os.getcwd()
- if defect == os.path.basename(
- os.path.normpath(path)
- ): # remove defect from end of path if present:
+ if defect == os.path.basename(os.path.normpath(path)): # remove defect from end of path if present:
orig_path = path
path = os.path.dirname(path)
else:
@@ -1008,13 +1009,13 @@ def analyse_single_defect(defect, path, code, ref_struct, verbose):
except Exception:
try:
analyse_single_defect(defect, orig_path, code, ref_struct, verbose)
- except Exception:
+ except Exception as exc:
raise Exception(
f"Could not analyse defect '{defect}' in directory '{path}'. Please either specify a "
f"defect to analyse (with option --defect), run from within a single defect directory ("
f"without setting --defect) or use the --all flag to analyse all defects in the "
f"specified/current directory."
- )
+ ) from exc
@snb.command(
@@ -1123,7 +1124,7 @@ def analyse_single_defect(defect, path, code, ref_struct, verbose):
@click.option(
"--verbose",
"-v",
- help="Print information about identified energy lowering distortions.",
+ help="Print information about identified energy lowering distortions and renamed/saved-over files.",
default=False,
is_flag=True,
show_default=True,
@@ -1156,16 +1157,14 @@ def plot(
for defect in defect_dirs:
if verbose:
print(f"Parsing {defect}...")
- _ = io.parse_energies(defect, path, code)
+ _ = io.parse_energies(defect, path, code, verbose=verbose)
# Create defects_dict (matching defect name to charge states)
defects_wout_charge = [defect.rsplit("_", 1)[0] for defect in defect_dirs]
- defects_dict = {
- defect_wout_charge: [] for defect_wout_charge in defects_wout_charge
- }
+ defects_dict = {defect_wout_charge: [] for defect_wout_charge in defects_wout_charge}
for defect in defect_dirs:
defects_dict[defect.rsplit("_", 1)[0]].append(int(defect.rsplit("_", 1)[1]))
return plotting.plot_all_defects(
- defects_dict=defects_dict,
+ defect_charges_dict=defects_dict,
output_path=path,
add_colorbar=colorbar,
metric=metric,
@@ -1175,9 +1174,10 @@ def plot(
add_title=not no_title,
max_energy_above_unperturbed=max_energy,
verbose=verbose,
+ close_figures=True, # reduce memory usage with snb-plot with many defects at once
)
- elif defect is None:
+ if defect is None:
# assume current directory is the defect folder
if path != ".":
warnings.warn(
@@ -1192,18 +1192,14 @@ def plot(
# Check if defect present in path:
if path == ".":
path = os.getcwd()
- if defect == os.path.basename(
- os.path.normpath(path)
- ): # remove defect from end of path if present:
+ if defect == os.path.basename(os.path.normpath(path)): # remove defect from end of path if present:
orig_path = path
path = os.path.dirname(path)
else:
orig_path = None
try:
- energies_file = io.parse_energies(defect, path, code)
- defect_species = energies_file.rsplit("/", 1)[-1].replace(
- ".yaml", ""
- ) # in case '+' removed
+ energies_file = io.parse_energies(defect, path, code, verbose=verbose)
+ defect_species = energies_file.rsplit("/", 1)[-1].replace(".yaml", "") # in case '+' removed
defect_energies_dict = analysis.get_energies(
defect_species=defect_species,
output_path=path,
@@ -1223,10 +1219,8 @@ def plot(
)
except Exception:
try:
- energies_file = io.parse_energies(defect, orig_path, code)
- defect_species = energies_file.rsplit("/", 1)[-1].replace(
- ".yaml", ""
- ) # in case '+' removed
+ energies_file = io.parse_energies(defect, orig_path, code, verbose=verbose)
+ defect_species = energies_file.rsplit("/", 1)[-1].replace(".yaml", "") # in case '+' removed
defect_energies_dict = analysis.get_energies(
defect_species=defect_species,
output_path=orig_path,
@@ -1244,13 +1238,13 @@ def plot(
max_energy_above_unperturbed=max_energy,
verbose=verbose,
)
- except Exception:
+ except Exception as exc:
raise Exception(
f"Could not analyse & plot defect '{defect}' in directory '{path}'. Please either "
f"specify a defect to analyse (with option --defect), run from within a single "
f"defect directory (without setting --defect) or use the --all flag to analyse all "
f"defects in the specified/current directory."
- )
+ ) from exc
@snb.command(
@@ -1322,9 +1316,7 @@ def regenerate(path, code, filename, min_energy, metastable, verbose):
"""
if path == ".":
path = os.getcwd() # more verbose error if no defect folders found in path
- defect_charges_dict = energy_lowering_distortions.read_defects_directories(
- output_path=path
- )
+ defect_charges_dict = energy_lowering_distortions.read_defects_directories(output_path=path)
if not defect_charges_dict:
raise FileNotFoundError(
f"No defect folders found in directory '{path}'. Please check the "
@@ -1409,10 +1401,7 @@ def groundstate(
dir
for dir in os.listdir()
if os.path.isdir(dir)
- and any(
- substring in dir
- for substring in ["Bond_Distortion", "Rattled", "Unperturbed", "Dimer"]
- )
+ and any(substring in dir for substring in ["Bond_Distortion", "Rattled", "Unperturbed", "Dimer"])
): # distortion subfolders in cwd
# check if defect folders also in cwd
for dir in [dir for dir in os.listdir() if os.path.isdir(dir)]:
@@ -1421,12 +1410,8 @@ def groundstate(
defect_name = format_defect_name(dir, include_site_info_in_name=False)
except Exception:
with contextlib.suppress(Exception):
- defect_name = format_defect_name(
- f"{dir}_0", include_site_info_in_name=False
- )
- if (
- defect_name
- ): # recognised defect folder found in cwd, warn user and proceed
+ defect_name = format_defect_name(f"{dir}_0", include_site_info_in_name=False)
+ if defect_name: # recognised defect folder found in cwd, warn user and proceed
# assuming they want to just parse the distortion folders in cwd
warnings.warn(
f"Both distortion folders and defect folders (i.e. {dir}) were "
@@ -1504,10 +1489,8 @@ def mag(outcar, threshold, verbose):
abs_mag_values = [abs(m["tot"]) for m in outcar_obj.magnetization]
if (
- max(abs_mag_values)
- < threshold # no one atomic moment greater than threshold
- and sum(abs_mag_values)
- < threshold * 10 # total moment less than 10x threshold
+ max(abs_mag_values) < threshold # no one atomic moment greater than threshold
+ and sum(abs_mag_values) < threshold * 10 # total moment less than 10x threshold
):
if verbose:
print(f"Magnetisation is below threshold (<{threshold} μB/atom)")
diff --git a/shakenbreak/distortions.py b/shakenbreak/distortions.py
index d21022c..c3aface 100644
--- a/shakenbreak/distortions.py
+++ b/shakenbreak/distortions.py
@@ -1,4 +1,5 @@
-"""Module containing functions for applying distortions to defect structures"""
+"""Module containing functions for applying distortions to defect structures."""
+
import os
import sys
import warnings
@@ -6,17 +7,14 @@
import numpy as np
from ase.neighborlist import NeighborList
-from hiphive.structure_generation.rattle import (
- _probability_mc_rattle,
- generate_mc_rattled_structures,
-)
+from hiphive.structure_generation.rattle import _probability_mc_rattle, generate_mc_rattled_structures
from pymatgen.analysis.local_env import CrystalNN, MinimumDistanceNN
from pymatgen.core.structure import Structure
from pymatgen.io.ase import AseAtomsAdaptor
def _warning_on_one_line(message, category, filename, lineno, file=None, line=None):
- """Format warnings output"""
+ """Format warnings output."""
return f"{os.path.split(filename)[-1]}:{lineno}: {category.__name__}: {message}\n"
@@ -36,14 +34,14 @@ def distort(
"""
Applies bond distortions to `num_nearest_neighbours` of the defect (specified
by `site_index` (for substitutions or interstitials) or `frac_coords`
- (for vacancies))
+ (for vacancies)).
Args:
structure (:obj:`~pymatgen.core.structure.Structure`):
Defect structure as a pymatgen object
num_nearest_neighbours (:obj:`int`):
Number of defect nearest neighbours to apply bond distortions to
- distortion factor (:obj:`float`):
+ distortion_factor (:obj:`float`):
The distortion factor to apply to the bond distance between the
defect and nearest neighbours. Typical choice is between 0.4 (-60%)
and 1.6 (+60%).
@@ -73,9 +71,7 @@ def distort(
atom_number = site_index - 1 # Align atom number with python 0-indexing
elif isinstance(frac_coords, np.ndarray): # Only for vacancies!
input_structure_ase.append("V") # fake "V" at vacancy
- input_structure_ase.positions[-1] = np.dot(
- frac_coords, input_structure_ase.cell
- )
+ input_structure_ase.positions[-1] = np.dot(frac_coords, input_structure_ase.cell)
atom_number = len(input_structure_ase) - 1
else:
raise ValueError(
@@ -83,15 +79,11 @@ def distort(
" or `frac_coords` provided."
)
- neighbours = (
- num_nearest_neighbours + 1
- ) # Prevent self-counting of the defect atom itself
+ neighbours = num_nearest_neighbours + 1 # Prevent self-counting of the defect atom itself
if distorted_atoms and len(distorted_atoms) >= num_nearest_neighbours:
nearest = [
(
- round(
- input_structure_ase.get_distance(atom_number, index, mic=True), 4
- ),
+ round(input_structure_ase.get_distance(atom_number, index, mic=True), 4),
index + 1,
input_structure_ase.get_chemical_symbols()[index],
)
@@ -108,9 +100,7 @@ def distort(
)
distances = [ # Get all distances between the selected atom and all other atoms
(
- round(
- input_structure_ase.get_distance(atom_number, index, mic=True), 4
- ),
+ round(input_structure_ase.get_distance(atom_number, index, mic=True), 4),
index + 1, # Indices start from 1
symbol,
)
@@ -119,34 +109,20 @@ def distort(
input_structure_ase.get_chemical_symbols(),
)
]
- distances = sorted( # Sort the distances shortest->longest
- distances, key=lambda tup: tup[0]
- )
+ distances = sorted(distances, key=lambda tup: tup[0]) # Sort the distances shortest->longest
- if (
- distorted_element
- ): # filter the neighbours that match the element criteria and are
+ if distorted_element: # filter the neighbours that match the element criteria and are
# closer than 4.5 Angstroms
nearest = [] # list of nearest neighbours
- for dist, index, element in distances[
- 1:
- ]: # starting from 1 to exclude defect atom
- if (
- element == distorted_element
- and dist < 4.5
- and len(nearest) < num_nearest_neighbours
- ):
+ for dist, index, element in distances[1:]: # starting from 1 to exclude defect atom
+ if element == distorted_element and dist < 4.5 and len(nearest) < num_nearest_neighbours:
nearest.append((dist, index, element))
# if the number of nearest neighbours not reached, add other neighbouring
# elements
if len(nearest) < num_nearest_neighbours:
for i in distances[1:]:
- if (
- len(nearest) < num_nearest_neighbours
- and i not in nearest
- and i[0] < 4.5
- ):
+ if len(nearest) < num_nearest_neighbours and i not in nearest and i[0] < 4.5:
nearest.append(i)
warnings.warn(
f"{distorted_element} was specified as the nearest neighbour "
@@ -159,9 +135,7 @@ def distort(
sys.stderr.flush() # ensure warning message printed before distortion info
verbose = True
else:
- nearest = distances[
- 1:neighbours
- ] # Extract the nearest neighbours according to distance
+ nearest = distances[1:neighbours] # Extract the nearest neighbours according to distance
distorted = [
(i[0] * distortion_factor, i[1], i[2]) for i in nearest
@@ -236,9 +210,7 @@ def apply_dimer_distortion(
atom_number = site_index - 1 # Align atom number with python 0-indexing
elif type(frac_coords) in [list, tuple, np.ndarray]: # Only for vacancies!
input_structure_ase.append("V") # fake "V" at vacancy
- input_structure_ase.positions[-1] = np.dot(
- frac_coords, input_structure_ase.cell
- )
+ input_structure_ase.positions[-1] = np.dot(frac_coords, input_structure_ase.cell)
atom_number = len(input_structure_ase) - 1
else:
raise ValueError(
@@ -257,8 +229,8 @@ def apply_dimer_distortion(
for other_site in sites[i + 1 :]:
distances[(site.index, other_site.index)] = site.distance(other_site)
# Get defect NN with smallest distance and lowest indices:
- site_indexes = min(
- distances, key=lambda k: (round(distances.get(k, 10), 3), k[0], k[1])
+ site_indexes = tuple(
+ sorted(min(distances, key=lambda k: (round(distances.get(k, 10), 3), k[0], k[1])))
)
# Set their distance to 2 A
input_structure_ase.set_distance(
@@ -422,8 +394,8 @@ def rattle(
except Exception as ex:
if "attempts" in str(ex):
continue
- else:
- raise ex
+
+ raise ex
if verbose:
warnings.warn(
@@ -435,9 +407,7 @@ def rattle(
else:
raise ex
- rattled_structure = aaa.get_structure(rattled_ase_struct)
-
- return rattled_structure
+ return aaa.get_structure(rattled_ase_struct)
def _local_mc_rattle_displacements(
@@ -461,7 +431,7 @@ def _local_mc_rattle_displacements(
Args:
atoms (:obj:`ase.Atoms`):
prototype structure
- site (:obj:`int`):
+ site_index (:obj:`int`):
index of defect, starting from 0
rattle_std (:obj:`float`):
rattle amplitude (standard deviation in normal distribution)
@@ -538,7 +508,7 @@ def scale_stdev(disp, r_min, r):
# Distance between defect and site i
dist_defect_to_i = atoms.get_distance(site_index, i, mic=True)
- for n in range(max_attempts):
+ for _ in range(max_attempts):
# generate displacement
delta_disp = rs.normal(
0.0,
@@ -558,21 +528,19 @@ def scale_stdev(disp, r_min, r):
if len(i_nbrs) == 0:
min_distance = np.inf
else:
- min_distance = np.min(
- atoms_rattle.get_distances(i, i_nbrs, mic=True)
- )
+ min_distance = np.min(atoms_rattle.get_distances(i, i_nbrs, mic=True))
# accept or reject delta_disp
if _probability_mc_rattle(min_distance, d_min, width) > rs.rand():
# accept delta_disp
break
- else:
- # revert delta_disp
- atoms_rattle[i].position -= delta_disp
+
+ # revert delta_disp
+ atoms_rattle[i].position -= delta_disp
else:
- raise Exception(f"Maxmium attempts ({n}) for atom {i}")
- displacements = atoms_rattle.positions - reference_positions
- return displacements
+ raise Exception(f"Maximum attempts ({max_attempts}) for atom {i}")
+
+ return atoms_rattle.positions - reference_positions
def _generate_local_mc_rattled_structures(
@@ -609,24 +577,26 @@ def _generate_local_mc_rattled_structures(
Args:
atoms (:obj:`ase.Atoms`):
- prototype structure
+ Prototype structure
site_index (:obj:`int`):
Index of defect site in structure (for substitutions or
interstitials), counting from 1.
- n_structures (:obj:`int`):
- number of structures to generate
+ n_configs (:obj:`int`):
+ Number of structures to generate
rattle_std (:obj:`float`):
- rattle amplitude (standard deviation in normal distribution);
+ Rattle amplitude (standard deviation in normal distribution);
note this value is not connected to the final
average displacement for the structures
d_min (:obj:`float`):
- interatomic distance used for computing the probability for each rattle
+ Interatomic distance used for computing the probability for each rattle
move
seed (:obj:`int`):
Seed for NumPy random state from which random rattle displacements
are generated. (Default: 42)
n_iter (:obj:`int`):
- number of Monte Carlo cycles
+ Number of Monte Carlo cycles
+ **kwargs:
+ Additional keyword arguments to be passed to `mc_rattle`
Returns:
:obj:`list`:
@@ -735,8 +705,7 @@ def local_mc_rattle(
atom_number = len(ase_struct) - 1
else:
raise ValueError(
- "Insufficient information to apply local rattle, no `site_index`"
- " or `frac_coords` provided."
+ "Insufficient information to apply local rattle, no `site_index` or `frac_coords` provided."
)
if stdev is None:
@@ -807,6 +776,4 @@ def local_mc_rattle(
if isinstance(frac_coords, np.ndarray):
local_rattled_ase_struct.pop(-1) # remove fake V from vacancy structure
- local_rattled_structure = aaa.get_structure(local_rattled_ase_struct)
-
- return local_rattled_structure
+ return aaa.get_structure(local_rattled_ase_struct)
diff --git a/shakenbreak/energy_lowering_distortions.py b/shakenbreak/energy_lowering_distortions.py
index 82e35ad..17e073c 100644
--- a/shakenbreak/energy_lowering_distortions.py
+++ b/shakenbreak/energy_lowering_distortions.py
@@ -35,20 +35,18 @@ def _format_distortion_directory_name(
# if a string but not Unperturbed or Rattled, add "Bond_Distortion_" to the start
distorted_distortion = f"Bond_Distortion_{distorted_distortion}"
- formatted_distorted_charge = (
- f"{'+' if distorted_charge > 0 else ''}{distorted_charge}"
- )
+ formatted_distorted_charge = f"{'+' if distorted_charge > 0 else ''}{distorted_charge}"
if isinstance(distorted_distortion, str) and "_from_" not in distorted_distortion:
return f"{output_path}/{defect_species}/{distorted_distortion}_from_{formatted_distorted_charge}"
# don't add "Bond_Distortion_" for "Unperturbed" or "Rattled"
- elif isinstance(distorted_distortion, str):
+ if isinstance(distorted_distortion, str):
return f"{output_path}/{defect_species}/{distorted_distortion}"
- else:
- return (
- f"{output_path}/{defect_species}/Bond_Distortion_"
- f"{round(distorted_distortion * 100, 1) + 0}%_from_{formatted_distorted_charge}"
- )
+
+ return (
+ f"{output_path}/{defect_species}/Bond_Distortion_"
+ f"{round(distorted_distortion * 100, 1) + 0}%_from_{formatted_distorted_charge}"
+ )
def read_defects_directories(output_path: str = "./") -> dict:
@@ -65,18 +63,11 @@ def read_defects_directories(output_path: str = "./") -> dict:
:obj:`dict`:
Dictionary mapping defect names to a list of its charge states.
"""
- list_subdirectories = list(
- next(os.walk(output_path))[1]
- ) # Only subdirectories in current directory
- for i in list(
- list_subdirectories
- ): # need to make copy of list when iterating over and
- # removing elements
+ list_subdirectories = list(next(os.walk(output_path))[1]) # Only subdirectories in current directory
+ for i in list_subdirectories.copy(): # make copy of list for iterating over and removing elements
try:
formatted_name = format_defect_name(i, include_site_info_in_name=False)
- if (
- formatted_name is None
- ): # defect folder name not recognised, remove from list
+ if formatted_name is None: # defect folder name not recognised, remove from list
list_subdirectories.remove(i)
except ValueError: # defect folder name not recognised, remove from list
list_subdirectories.remove(i)
@@ -144,6 +135,7 @@ def _compare_distortion(
order to consider them not matching (in Å, default = 0.2 Å).
verbose (:obj:`bool`):
Whether to print information message about structures being compared.
+ (Default: False)
Returns:
:obj:`dict`
@@ -174,7 +166,7 @@ def _compare_distortion(
}
if matching_distortion_dict: # if it matches _any_ other distortion
- index = list(matching_distortion_dict.keys())[0] # should only be one
+ index = next(iter(matching_distortion_dict.keys())) # should only be one
if charge not in low_energy_defects[defect][index]["charges"]:
# only print message if charge state not already stored (this can happen when using
# the --metastable option with small noise in the energies)
@@ -226,7 +218,7 @@ def _prune_dict_across_charges(
Screen through defects to check if any lower-energy distorted structures
were already found with/without bond distortions for other charge states
(i.e. found but higher energy, found but also with unperturbed, found
- but with energy lowering less than min_e_diff etc)
+ but with energy lowering less than min_e_diff etc).
Args:
low_energy_defects (dict):
@@ -253,36 +245,34 @@ def _prune_dict_across_charges(
min_dist (:obj:`float`):
Minimum atomic displacement threshold between structures, in
order to consider them not matching (in Å, default = 0.2 Å).
+ verbose (:obj:`bool`):
+ Whether to print verbose information about parsed defect
+ structures for energy-lowering distortions, if found.
+ (Default: False)
Returns:
:obj:`dict`
"""
for defect, distortion_list in low_energy_defects.items():
for distortion_dict in distortion_list:
- for charge in list(
- set(defect_pruning_dict[defect]) - set(distortion_dict["charges"])
- ):
+ charges_set = set(distortion_dict["charges"])
+ for charge in set(defect_pruning_dict[defect]) - charges_set:
imported_groundstates = [
gs_distortion
for gs_distortion in distortion_dict["bond_distortions"]
if isinstance(gs_distortion, str) and "_from_" in gs_distortion
]
orig_charges_from_imported_groundstates = [
- int(gs_distortion.split("_from_")[-1])
- for gs_distortion in imported_groundstates
+ int(gs_distortion.split("_from_")[-1]) for gs_distortion in imported_groundstates
]
# skip if groundstate is from an imported distortion, from this charge state
if charge in orig_charges_from_imported_groundstates:
continue
# charges in defect_pruning_dict that aren't already in this distortion entry
- defect_species_snb_name = (
- f"{defect}_{'+' if charge > 0 else ''}{charge}"
- )
+ defect_species_snb_name = f"{defect}_{'+' if charge > 0 else ''}{charge}"
for i in ["+", "", "+"]: # back to SnB name with "+" if all fail
- defect_species = defect_species_snb_name.replace(
- "+", i
- ) # try with and without '+'
+ defect_species = defect_species_snb_name.replace("+", i) # try with and without '+'
comparison_results = compare_struct_to_distortions(
distortion_dict["structures"][0],
defect_species,
@@ -381,7 +371,7 @@ def get_energy_lowering_distortions(
verbose (:obj:`bool`):
Whether to print verbose information about parsed defect
structures for energy-lowering distortions, if found.
- (Default: True)
+ (Default: False)
write_input_files (:obj:`bool`):
Whether to write input files for the identified distortions
(Default: False)
@@ -420,9 +410,7 @@ def get_energy_lowering_distortions(
if not defect_charges_dict:
# defect_charges_dict maps defect_name to list of charge states
defect_charges_dict = read_defects_directories(output_path=output_path)
- defect_pruning_dict = copy.deepcopy(
- defect_charges_dict
- ) # defects and charge states to analyse
+ defect_pruning_dict = copy.deepcopy(defect_charges_dict) # defects and charge states to analyse
# later comparison and pruning against other charge states
for defect in defect_charges_dict:
@@ -436,13 +424,11 @@ def get_energy_lowering_distortions(
energies_file = f"{output_path}/{defect_species}/{defect_species}.yaml"
with warnings.catch_warnings():
- if os.path.exists(energies_file) or os.path.exists(
- energies_file.replace("+", "")
- ):
+ if os.path.exists(energies_file) or os.path.exists(energies_file.replace("+", "")):
# ignore parsing warnings in case energies already parsed and output files deleted,
# _only_ if energies file already exists
warnings.simplefilter("ignore", category=UserWarning)
- energies_file = io.parse_energies(defect_species, output_path, code)
+ energies_file = io.parse_energies(defect_species, output_path, code, verbose=verbose)
defect_species = energies_file.rsplit("/", 1)[-1].replace(
".yaml", ""
) # in case '+' removed
@@ -666,7 +652,7 @@ def compare_struct_to_distortions(
Defect name including charge (e.g. 'vac_1_Cd_0')
output_path (:obj:`str`):
Path to directory with your distorted defect calculations (to
- calculate structure comparisons – needs code output/structure
+ calculate structure comparisons -- needs code output/structure
files to parse the structures).
(Default is current directory = "./")
code (:obj:`str`, optional):
@@ -686,6 +672,7 @@ def compare_struct_to_distortions(
orderto consider them not matching (in Å, default = 0.2 Å).
verbose (:obj:`bool`):
Whether to print information message about structures being compared.
+ (Default: False)
Returns:
:obj:`tuple`:
@@ -729,8 +716,7 @@ def compare_struct_to_distortions(
if not matching_sub_df.empty: # if there are any matches
unperturbed_df = matching_sub_df[
- matching_sub_df["Bond Distortion"]
- == "Unperturbed" # if present, otherwise empty
+ matching_sub_df["Bond Distortion"] == "Unperturbed" # if present, otherwise empty
]
rattled_df = matching_sub_df[
matching_sub_df["Bond Distortion"].apply(lambda x: "Rattled" in str(x))
@@ -750,9 +736,7 @@ def compare_struct_to_distortions(
matching_sub_df["Bond Distortion"].apply(lambda x: isinstance(x, str))
]
imported_sorted_distorted_df = string_vals_sorted_distorted_df[
- string_vals_sorted_distorted_df["Bond Distortion"].apply(
- lambda x: "_from_" in x
- )
+ string_vals_sorted_distorted_df["Bond Distortion"].apply(lambda x: "_from_" in x)
]
imported_sorted_distorted_float_df = imported_sorted_distorted_df.copy()
@@ -761,15 +745,11 @@ def compare_struct_to_distortions(
# needs to be done this way because 'key' in pd.sort_values()
# needs to be vectorised...
# if '%' in key then convert to float, else convert to 0 (for Rattled or Unperturbed)
- imported_sorted_distorted_float_df[
+ imported_sorted_distorted_float_df["Bond Distortion"] = imported_sorted_distorted_df[
"Bond Distortion"
- ] = imported_sorted_distorted_df["Bond Distortion"].apply(
- lambda x: float(x.split("%")[0]) / 100 if "%" in x else 0.0
- )
- imported_sorted_distorted_float_df = (
- imported_sorted_distorted_float_df.sort_values(
- by="Bond Distortion", key=abs
- )
+ ].apply(lambda x: float(x.split("%")[0]) / 100 if "%" in x else 0.0)
+ imported_sorted_distorted_float_df = imported_sorted_distorted_float_df.sort_values(
+ by="Bond Distortion", key=abs
)
# first unperturbed, then rattled, then dimer, then distortions sorted by
@@ -784,9 +764,7 @@ def compare_struct_to_distortions(
]
)
- struc_key = sorted_matching_df["Bond Distortion"].iloc[
- 0
- ] # first matching structure
+ struc_key = sorted_matching_df["Bond Distortion"].iloc[0] # first matching structure
if struc_key == "Unperturbed":
return ( # T/F, matching structure, energy_diff, distortion factor
@@ -795,23 +773,12 @@ def compare_struct_to_distortions(
defect_energies_dict[struc_key],
struc_key,
)
- else:
- # check if struc_key is in defect_structures_dict (corresponding to match in
- # unperturbed_df, rattled_df or sorted_distorted_df but not
- # imported_sorted_distorted_df
- # as keys have been reformatted to floats rather than strings for this)
- if struc_key in defect_structures_dict:
- return ( # T/F, matching structure, energy_diff, distortion factor
- True,
- defect_structures_dict[struc_key],
- defect_energies_dict["distortions"][struc_key],
- struc_key,
- )
- # else struc_key corresponds to reformatted float-from-string from imported distortion
- struc_key = imported_sorted_distorted_df["Bond Distortion"].iloc[
- 0
- ] # first matching structure
+ # check if struc_key is in defect_structures_dict (corresponding to match in
+ # unperturbed_df, rattled_df or sorted_distorted_df but not
+ # imported_sorted_distorted_df
+ # as keys have been reformatted to floats rather than strings for this)
+ if struc_key in defect_structures_dict:
return ( # T/F, matching structure, energy_diff, distortion factor
True,
defect_structures_dict[struc_key],
@@ -819,20 +786,29 @@ def compare_struct_to_distortions(
struc_key,
)
- else: # no matches
- return (
- False,
- None,
- None,
- None,
- ) # T/F, matching structure, energy_diff, distortion factor
+ # else struc_key corresponds to reformatted float-from-string from imported distortion
+ struc_key = imported_sorted_distorted_df["Bond Distortion"].iloc[0] # first matching structure
+ return ( # T/F, matching structure, energy_diff, distortion factor
+ True,
+ defect_structures_dict[struc_key],
+ defect_energies_dict["distortions"][struc_key],
+ struc_key,
+ )
+
+ # no matches
+ return (
+ False,
+ None,
+ None,
+ None,
+ ) # T/F, matching structure, energy_diff, distortion factor
def write_retest_inputs(
low_energy_defects: dict,
output_path: str = ".",
code: str = "vasp",
- input_filename: str = None,
+ input_filename: Optional[str] = None,
) -> None:
"""
Create folders with relaxation input files for testing the low-energy
@@ -908,9 +884,7 @@ def write_retest_inputs(
print(f"Writing low-energy distorted structure to {distorted_dir}")
if not os.path.exists(f"{output_path}/{defect_species}"):
- print(
- f"Directory {output_path}/{defect_species} not found, creating..."
- )
+ print(f"Directory {output_path}/{defect_species} not found, creating...")
os.mkdir(f"{output_path}/{defect_species}")
os.mkdir(distorted_dir)
@@ -984,9 +958,7 @@ def _copy_vasp_files(
for subfolder in os.listdir(f"{output_path}/{defect_species}"):
for filename in ["INCAR", "KPOINTS", "POTCAR"]:
if (
- os.path.exists(
- f"{output_path}/{defect_species}/{subfolder}/{filename}"
- )
+ os.path.exists(f"{output_path}/{defect_species}/{subfolder}/{filename}")
and not file_dict[filename]
):
shutil.copyfile(
@@ -1041,15 +1013,12 @@ def _copy_espresso_files(
else:
subfolders_with_input_files = []
for subfolder in os.listdir(f"{output_path}/{defect_species}"):
- if os.path.exists(
- f"{output_path}/{defect_species}/{subfolder}/{input_filename}"
- ):
+ if os.path.exists(f"{output_path}/{defect_species}/{subfolder}/{input_filename}"):
subfolders_with_input_files.append(subfolder)
break
if len(subfolders_with_input_files) > 0:
with open(
- f"{output_path}/{defect_species}/{subfolders_with_input_files[0]}/"
- f"{input_filename}"
+ f"{output_path}/{defect_species}/{subfolders_with_input_files[0]}/{input_filename}"
) as f:
params = f.read() # Read input parameters
# Write distorted structure in QE format, to then update input file
@@ -1104,15 +1073,12 @@ def _copy_cp2k_files(
else: # Check of input file present in the other distortion subfolders
subfolders_with_input_files = []
for subfolder in os.listdir(f"{output_path}/{defect_species}"):
- if os.path.exists(
- f"{output_path}/{defect_species}/{subfolder}/{input_filename}"
- ):
+ if os.path.exists(f"{output_path}/{defect_species}/{subfolder}/{input_filename}"):
subfolders_with_input_files.append(subfolder)
break
if len(subfolders_with_input_files) > 0:
shutil.copyfile(
- f"{output_path}/{defect_species}/{subfolders_with_input_files[0]}"
- f"/{input_filename}",
+ f"{output_path}/{defect_species}/{subfolders_with_input_files[0]}/{input_filename}",
f"{distorted_dir}/{input_filename}",
)
@@ -1150,15 +1116,12 @@ def _copy_castep_files(
else: # Check of input file present in the other distortion subfolders
subfolders_with_input_files = []
for subfolder in os.listdir(f"{output_path}/{defect_species}"):
- if os.path.exists(
- f"{output_path}/{defect_species}/{subfolder}/{input_filename}"
- ):
+ if os.path.exists(f"{output_path}/{defect_species}/{subfolder}/{input_filename}"):
subfolders_with_input_files.append(subfolder)
break
if len(subfolders_with_input_files) > 0:
shutil.copyfile(
- f"{output_path}/{defect_species}/{subfolders_with_input_files[0]}"
- f"/{input_filename}",
+ f"{output_path}/{defect_species}/{subfolders_with_input_files[0]}/{input_filename}",
f"{distorted_dir}/{input_filename}",
)
@@ -1198,15 +1161,12 @@ def _copy_fhi_aims_files(
else: # Check of input file present in the other distortion subfolders
subfolders_with_input_files = []
for subfolder in os.listdir(f"{output_path}/{defect_species}"):
- if os.path.exists(
- f"{output_path}/{defect_species}/{subfolder}/{input_filename}"
- ):
+ if os.path.exists(f"{output_path}/{defect_species}/{subfolder}/{input_filename}"):
subfolders_with_input_files.append(subfolder)
break
if len(subfolders_with_input_files) > 0:
shutil.copyfile(
- f"{output_path}/{defect_species}/{subfolders_with_input_files[0]}"
- f"/{input_filename}",
+ f"{output_path}/{defect_species}/{subfolders_with_input_files[0]}/{input_filename}",
f"{distorted_dir}/{input_filename}",
)
@@ -1222,7 +1182,7 @@ def _copy_fhi_aims_files(
def write_groundstate_structure(
all: bool = True,
output_path: str = ".",
- groundstate_folder: str = None,
+ groundstate_folder: Optional[str] = None,
groundstate_filename: str = "groundstate_POSCAR",
structure_filename: str = "CONTCAR",
verbose: bool = False,
@@ -1256,6 +1216,7 @@ def write_groundstate_structure(
(Default: "CONTCAR")
verbose (:obj:`bool`):
Whether to print additional information about the generated folders.
+ (Default: False)
Returns:
None
@@ -1274,16 +1235,12 @@ def _write_single_groundstate(
energies_file = f"{output_path}/{defect_species}/{defect_species}.yaml"
with warnings.catch_warnings():
- if os.path.exists(energies_file) or os.path.exists(
- energies_file.replace("+", "")
- ):
+ if os.path.exists(energies_file) or os.path.exists(energies_file.replace("+", "")):
# ignore parsing warnings in case energies already parsed and output files deleted,
# _only_ if energies file already exists
warnings.simplefilter("ignore", category=UserWarning)
- energies_file = io.parse_energies(defect_species, output_path)
- defect_species = energies_file.rsplit("/", 1)[-1].replace(
- ".yaml", ""
- ) # in case '+' removed
+ energies_file = io.parse_energies(defect_species, output_path, verbose=verbose)
+ defect_species = energies_file.rsplit("/", 1)[-1].replace(".yaml", "") # in case '+' removed
if energies_file is None:
warnings.warn(
@@ -1293,16 +1250,12 @@ def _write_single_groundstate(
return
# Get ground state distortion
- _, _, gs_distortion = analysis._sort_data(
- energies_file=energies_file, verbose=False
- )
+ _, _, gs_distortion = analysis._sort_data(energies_file=energies_file, verbose=False)
bond_distortion = analysis._get_distortion_filename(gs_distortion)
# Origin path
- origin_path = (
- f"{output_path}/{defect_species}/{bond_distortion}/{structure_filename}"
- )
+ origin_path = f"{output_path}/{defect_species}/{bond_distortion}/{structure_filename}"
if not os.path.exists(origin_path):
raise FileNotFoundError(
f"The structure file {structure_filename} is not present in the directory "
@@ -1311,9 +1264,7 @@ def _write_single_groundstate(
# Destination path
if groundstate_folder:
- if not os.path.exists(
- f"{output_path}/{defect_species}/{groundstate_folder}"
- ):
+ if not os.path.exists(f"{output_path}/{defect_species}/{groundstate_folder}"):
os.mkdir(f"{output_path}/{defect_species}/{groundstate_folder}")
destination_path = os.path.join(
f"{output_path}/{defect_species}/",
diff --git a/shakenbreak/input.py b/shakenbreak/input.py
index 8565b7d..bde74f1 100644
--- a/shakenbreak/input.py
+++ b/shakenbreak/input.py
@@ -12,7 +12,8 @@
import shutil
import warnings
from importlib.metadata import version
-from typing import Optional, Tuple, Type, Union
+from multiprocessing import Queue
+from typing import Optional, Tuple, Union
import ase
import numpy as np
@@ -20,7 +21,12 @@
from ase.calculators.castep import Castep
from ase.calculators.espresso import Espresso
from doped import _ignore_pmg_warnings
-from doped.core import Defect, DefectEntry
+from doped.core import (
+ Defect,
+ DefectEntry,
+ _guess_and_set_oxi_states_with_timeout,
+ _rough_oxi_state_cost_from_comp,
+)
from doped.generation import DefectsGenerator, name_defect_entries
from doped.utils.parsing import (
get_defect_site_idxs_and_unrelaxed_structure,
@@ -29,7 +35,7 @@
from doped.vasp import DefectDictSet
from monty.json import MontyDecoder
from monty.serialization import dumpfn, loadfn
-from pymatgen.analysis.defects import core, thermo
+from pymatgen.analysis.defects import thermo
from pymatgen.analysis.defects.supercells import get_sc_fromstruct
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.core.structure import Composition, Element, PeriodicSite, Structure
@@ -47,17 +53,13 @@
MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
default_potcar_dict = loadfn(f"{MODULE_DIR}/../SnB_input_files/default_POTCARs.yaml")
# Load default INCAR settings for the ShakeNBreak geometry relaxations
-default_incar_settings = loadfn(
- os.path.join(MODULE_DIR, "../SnB_input_files/incar.yaml")
-)
+default_incar_settings = loadfn(os.path.join(MODULE_DIR, "../SnB_input_files/incar.yaml"))
_ignore_pmg_warnings() # Ignore pymatgen POTCAR warnings
-def _warning_on_one_line(
- message, category, filename, lineno, file=None, line=None
-) -> str:
+def _warning_on_one_line(message, category, filename, lineno, file=None, line=None) -> str:
"""Output warning messages on one line."""
# To set this as warnings.formatwarning, we need to be able to take in `file`
# and `line`, but don't want to print them, so unused arguments here
@@ -115,16 +117,12 @@ def _write_distortion_metadata(
if os.path.exists(filepath):
try:
old_metadata = loadfn(os.path.join(output_path, "distortion_metadata.json"))
- if (
- old_metadata
- ): # convert charge keys back to integers (converted to strings when saved to /
+ if old_metadata: # convert charge keys back to integers (converted to strings when saved to /
# loaded from JSON)
for defect in list(old_metadata["defects"].keys()):
charges_dict = old_metadata["defects"][defect]["charges"]
old_metadata["defects"][defect] = {
- k: v
- for k, v in old_metadata["defects"][defect].items()
- if k != "charges"
+ k: v for k, v in old_metadata["defects"][defect].items() if k != "charges"
}
old_metadata["defects"][defect]["charges"] = {
int(k): v for k, v in charges_dict.items()
@@ -190,56 +188,38 @@ def _write_distortion_metadata(
)
os.rename(
filepath,
- os.path.join(
- output_path, f"distortion_metadata_{current_datetime}.json"
- ),
+ os.path.join(output_path, f"distortion_metadata_{current_datetime}.json"),
)
print(f"Combining old and new metadata in {filename}.")
# Combine old and new metadata dictionaries
for defect in old_metadata["defects"]:
- if (
- defect in new_metadata["defects"]
- ): # if defect in both metadata files
+ if defect in new_metadata["defects"]: # if defect in both metadata files
for charge in old_metadata["defects"][defect]["charges"]:
if (
charge in new_metadata["defects"][defect]["charges"]
): # if charge state in both files,
# then we update the mesh of distortions if this is the only differing
# quantity (i.e. [-0.3, 0.3] + [-0.4, -0.2, 0.2, 0.4])
- new_metadata_charge_dict_wout_distortions_list = (
- copy.deepcopy(
- new_metadata["defects"][defect]["charges"][
- charge
- ]
- )
+ new_metadata_charge_dict_wout_distortions_list = copy.deepcopy(
+ new_metadata["defects"][defect]["charges"][charge]
)
- new_metadata_charge_dict_wout_distortions_list[
- "distortion_parameters"
- ] = {
+ new_metadata_charge_dict_wout_distortions_list["distortion_parameters"] = {
k: v
for k, v in new_metadata_charge_dict_wout_distortions_list[
"distortion_parameters"
].items()
- if k
- not in ["bond_distortions", "distortion_increment"]
+ if k not in ["bond_distortions", "distortion_increment"]
}
- old_metadata_charge_dict_wout_distortions_list = (
- copy.deepcopy(
- old_metadata["defects"][defect]["charges"][
- charge
- ]
- )
+ old_metadata_charge_dict_wout_distortions_list = copy.deepcopy(
+ old_metadata["defects"][defect]["charges"][charge]
)
- old_metadata_charge_dict_wout_distortions_list[
- "distortion_parameters"
- ] = {
+ old_metadata_charge_dict_wout_distortions_list["distortion_parameters"] = {
k: v
for k, v in old_metadata_charge_dict_wout_distortions_list[
"distortion_parameters"
].items()
- if k
- not in ["bond_distortions", "distortion_increment"]
+ if k not in ["bond_distortions", "distortion_increment"]
}
if (
@@ -247,41 +227,33 @@ def _write_distortion_metadata(
== old_metadata_charge_dict_wout_distortions_list
):
if (
- new_metadata["defects"][defect]["charges"][
- charge
- ]["distortion_parameters"]
- != old_metadata["defects"][defect]["charges"][
- charge
- ]["distortion_parameters"]
+ new_metadata["defects"][defect]["charges"][charge][
+ "distortion_parameters"
+ ]
+ != old_metadata["defects"][defect]["charges"][charge][
+ "distortion_parameters"
+ ]
):
# combine bond distortions lists:
- old_bond_distortions = old_metadata["defects"][
- defect
- ]["charges"][charge]["distortion_parameters"][
- "bond_distortions"
- ]
+ old_bond_distortions = old_metadata["defects"][defect]["charges"][
+ charge
+ ]["distortion_parameters"]["bond_distortions"]
bond_distortions = old_bond_distortions + [
distortion
- for distortion in new_metadata["defects"][
- defect
- ]["charges"][charge][
- "distortion_parameters"
- ][
- "bond_distortions"
- ]
+ for distortion in new_metadata["defects"][defect]["charges"][
+ charge
+ ]["distortion_parameters"]["bond_distortions"]
if distortion not in old_bond_distortions
]
- new_metadata["defects"][defect]["charges"][
- charge
- ]["distortion_parameters"] = {
+ new_metadata["defects"][defect]["charges"][charge][
+ "distortion_parameters"
+ ] = {
"bond_distortions": bond_distortions,
**{
k: v
- for k, v in new_metadata["defects"][
- defect
- ]["charges"][charge][
- "distortion_parameters"
- ].items()
+ for k, v in new_metadata["defects"][defect]["charges"][
+ charge
+ ]["distortion_parameters"].items()
if k
not in [
"bond_distortions",
@@ -298,9 +270,9 @@ def _write_distortion_metadata(
)
continue
else: # if charge state only in old metadata, add it to file
- new_metadata["defects"][defect]["charges"][
- charge
- ] = old_metadata["defects"][defect]["charges"][charge]
+ new_metadata["defects"][defect]["charges"][charge] = old_metadata[
+ "defects"
+ ][defect]["charges"][charge]
else:
new_metadata["defects"][defect] = old_metadata["defects"][
defect
@@ -308,9 +280,7 @@ def _write_distortion_metadata(
except KeyError:
os.rename( # ensure previous file saved over, even if subset
filepath,
- os.path.join(
- output_path, f"distortion_metadata_{current_datetime}.json"
- ),
+ os.path.join(output_path, f"distortion_metadata_{current_datetime}.json"),
)
warnings.warn(
f"There was a problem when combining old and new metadata files! Will only write "
@@ -360,9 +330,7 @@ def _create_vasp_input(
None
"""
# create folder for defect
- defect_name_wout_charge, charge_state = defect_name.rsplit(
- "_", 1
- ) # `defect_name` includes charge
+ defect_name_wout_charge, charge_state = defect_name.rsplit("_", 1) # `defect_name` includes charge
charge_state = int(charge_state)
test_letters = [
"h",
@@ -381,8 +349,7 @@ def _create_vasp_input(
dir
for letter in test_letters
for dir in os.listdir(output_path)
- if dir
- == f"{defect_name_wout_charge}{letter}_{'+' if charge_state > 0 else ''}{charge_state}"
+ if dir == f"{defect_name_wout_charge}{letter}_{'+' if charge_state > 0 else ''}{charge_state}"
and os.path.isdir(
f"{output_path}/{defect_name_wout_charge}{letter}_{'+' if charge_state > 0 else ''}"
f"{charge_state}"
@@ -404,17 +371,11 @@ def _create_vasp_input(
)
for dir in matching_dirs
)
- if (
- not match_found
- ): # SnB folders in matching_dirs, so check if Unperturbed structures match
+ if not match_found: # SnB folders in matching_dirs, so check if Unperturbed structures match
for dir in matching_dirs:
- with contextlib.suppress(
- Exception
- ): # if Unperturbed structure could not be parsed /
+ with contextlib.suppress(Exception): # if Unperturbed structure could not be parsed /
# compared to distorted_defect_dict, then pass
- prev_unperturbed_struc = Structure.from_file(
- f"{output_path}/{dir}/Unperturbed/POSCAR"
- )
+ prev_unperturbed_struc = Structure.from_file(f"{output_path}/{dir}/Unperturbed/POSCAR")
current_unperturbed_struc = distorted_defect_dict["Unperturbed"][
"Defect Structure"
].copy()
@@ -433,25 +394,23 @@ def _create_vasp_input(
break
if not match_found: # no matching structure found, assume inequivalent defects
- last_letter = [
+ last_letter = next(
letter
for letter in test_letters
for dir in matching_dirs
if dir
== f"{defect_name_wout_charge}{letter}_{'+' if charge_state > 0 else ''}{charge_state}"
- ][0]
+ )
prev_dir_name = (
f"{defect_name_wout_charge}{last_letter}_{'+' if charge_state > 0 else ''}"
f"{charge_state}"
)
if last_letter == "": # rename prev defect folder
new_prev_dir_name = (
- f"{defect_name_wout_charge}a_{'+' if charge_state > 0 else ''}"
- f"{charge_state}"
+ f"{defect_name_wout_charge}a_{'+' if charge_state > 0 else ''}{charge_state}"
)
new_current_dir_name = (
- f"{defect_name_wout_charge}b_{'+' if charge_state > 0 else ''}"
- f"{charge_state}"
+ f"{defect_name_wout_charge}b_{'+' if charge_state > 0 else ''}{charge_state}"
)
warnings.warn(
f"A previously-generated defect distortions folder {prev_dir_name} exists in "
@@ -489,7 +448,7 @@ def _create_vasp_input(
potcar_settings.update(user_potcar_settings or {})
incar_settings = copy.deepcopy(default_incar_settings)
incar_settings.update(user_incar_settings or {})
- single_defect_dict = list(distorted_defect_dict.values())[0]
+ single_defect_dict = next(iter(distorted_defect_dict.values()))
num_elements = len(single_defect_dict["Defect Structure"].composition.elements)
incar_settings.update({"ROPT": ("1e-3 " * num_elements).rstrip()})
@@ -513,9 +472,7 @@ def _create_vasp_input(
distortion,
single_defect_dict,
) in distorted_defect_dict.items(): # for each distortion, create subfolder
- dds._structure = single_defect_dict[
- "Defect Structure"
- ].get_sorted_structure() # ensure sorted
+ dds._structure = single_defect_dict["Defect Structure"].get_sorted_structure() # ensure sorted
dds.poscar_comment = single_defect_dict.get("POSCAR Comment", None)
dds.write_input(
@@ -641,32 +598,24 @@ def _most_common_oxi(element) -> int:
comp_obj = Composition("O")
comp_obj.add_charges_from_oxi_state_guesses()
element_obj = Element(element)
- oxi_probabilities = [
- (k, v) for k, v in comp_obj.oxi_prob.items() if k.element == element_obj
- ]
+ oxi_probabilities = [(k, v) for k, v in comp_obj.oxi_prob.items() if k.element == element_obj]
if oxi_probabilities: # not empty
- most_common = max(oxi_probabilities, key=lambda x: x[1])[
- 0
- ] # breaks if icsd oxi states is empty
+ most_common = max(oxi_probabilities, key=lambda x: x[1])[0] # breaks if icsd oxi states is empty
return most_common.oxi_state
- else:
- if element_obj.common_oxidation_states:
- return element_obj.common_oxidation_states[
- 0
- ] # known common oxidation state
- else: # no known common oxidation state, make guess and warn user
- if element_obj.oxidation_states:
- guess_oxi = element_obj.oxidation_states[0]
- else:
- guess_oxi = 0
- warnings.warn(
- f"No known common oxidation states in pymatgen/ICSD dataset for element "
- f"{element_obj.name}, guessing as {guess_oxi:+}. You should set this in the "
- f"`oxidation_states` input parameter for `Distortions` if this is unreasonable!"
- )
+ if element_obj.common_oxidation_states:
+ return element_obj.common_oxidation_states[0] # known common oxidation state
+
+ # no known common oxidation state, make guess and warn user
+ guess_oxi = element_obj.oxidation_states[0] if element_obj.oxidation_states else 0
+
+ warnings.warn(
+ f"No known common oxidation states in pymatgen/ICSD dataset for element "
+ f"{element_obj.name}, guessing as {guess_oxi:+}. You should set this in the "
+ f"`oxidation_states` input parameter for `Distortions` if this is unreasonable!"
+ )
- return guess_oxi
+ return guess_oxi
def _calc_number_electrons(
@@ -680,8 +629,8 @@ def _calc_number_electrons(
defect species (in `defect_object`) based on `oxidation_states`.
Args:
- defect_object (:obj:`DefectEntry`):
- doped.core.DefectEntry object.
+ defect_entry (:obj:`DefectEntry`):
+ ``doped.core.DefectEntry`` object.
defect_name (:obj:`str`):
Name of the defect species.
oxidation_states (:obj:`dict`):
@@ -727,9 +676,7 @@ def _calc_number_electrons(
else:
raise ValueError(f"`defect_entry` has an invalid `defect_type`: {defect_type}")
- num_electrons = (
- oxidation_states[substitution_specie] - oxidation_states[site_specie]
- )
+ num_electrons = oxidation_states[substitution_specie] - oxidation_states[site_specie]
if verbose:
print(
@@ -756,11 +703,7 @@ def _calc_number_neighbours(num_electrons: int) -> int:
:obj:`int`:
Number of neighbours to distort
"""
- if abs(num_electrons) > 4:
- num_neighbours = abs(8 - abs(num_electrons))
- else:
- num_neighbours = abs(num_electrons)
- return abs(num_neighbours)
+ return abs(8 - abs(num_electrons)) if abs(num_electrons) > 4 else abs(num_electrons)
def _get_voronoi_nodes(
@@ -801,15 +744,14 @@ def _get_voronoi_nodes(
for vertex in voro.vertices:
frac_coords = prim_structure.lattice.get_fractional_coords(vertex)
vnode = PeriodicSite("V-", frac_coords, prim_structure.lattice)
- if np.all([-tol <= coord < 1 + tol for coord in frac_coords]):
- if all(p.distance(vnode) >= tol for p in vnodes):
- vnodes.append(vnode)
+ if np.all([-tol <= coord < 1 + tol for coord in frac_coords]) and all(
+ p.distance(vnode) >= tol for p in vnodes
+ ):
+ vnodes.append(vnode)
# cluster nodes that are within a certain distance of each other
voronoi_coords = [v.frac_coords for v in vnodes]
- dist_matrix = np.array(
- prim_structure.lattice.get_all_distances(voronoi_coords, voronoi_coords)
- )
+ dist_matrix = np.array(prim_structure.lattice.get_all_distances(voronoi_coords, voronoi_coords))
dist_matrix = (dist_matrix + dist_matrix.T) / 2
condensed_m = squareform(dist_matrix)
z = linkage(condensed_m)
@@ -822,13 +764,9 @@ def _get_voronoi_nodes(
frac_coords.append(vnodes[j].frac_coords)
else:
fcoords = vnodes[j].frac_coords
- d, image = prim_structure.lattice.get_distance_and_image(
- frac_coords[0], fcoords
- )
+ d, image = prim_structure.lattice.get_distance_and_image(frac_coords[0], fcoords)
frac_coords.append(fcoords + image)
- merged_vnodes.append(
- PeriodicSite("V-", np.average(frac_coords, axis=0), prim_structure.lattice)
- )
+ merged_vnodes.append(PeriodicSite("V-", np.average(frac_coords, axis=0), prim_structure.lattice))
vnodes = merged_vnodes
# remove nodes less than 0.5 Å from sites in the structure
@@ -841,9 +779,7 @@ def _get_voronoi_nodes(
# map back to the supercell
sm = StructureMatcher(primitive_cell=False, attempt_supercell=True)
mapping = sm.get_supercell_matrix(structure, prim_structure)
- voronoi_struc = Structure.from_sites(
- vnodes
- ) # Structure object with Voronoi nodes as sites
+ voronoi_struc = Structure.from_sites(vnodes) # Structure object with Voronoi nodes as sites
voronoi_struc.make_supercell(mapping) # Map back to the supercell
# check if there was an origin shift between primitive and supercell
@@ -851,17 +787,13 @@ def _get_voronoi_nodes(
regenerated_supercell.make_supercell(mapping)
fractional_shift = sm.get_transformation(structure, regenerated_supercell)[1]
if not np.allclose(fractional_shift, 0):
- voronoi_struc.translate_sites(
- range(len(voronoi_struc)), fractional_shift, frac_coords=True
- )
-
- vnodes = voronoi_struc.sites
+ voronoi_struc.translate_sites(range(len(voronoi_struc)), fractional_shift, frac_coords=True)
- return vnodes
+ return voronoi_struc.sites
def _get_voronoi_multiplicity(site, structure):
- """Get the multiplicity of a Voronoi site in structure"""
+ """Get the multiplicity of a Voronoi site in structure."""
vnodes = _get_voronoi_nodes(structure)
distances_and_species_list = []
@@ -879,11 +811,11 @@ def _get_voronoi_multiplicity(site, structure):
]
sorted_site_distances_and_species = sorted(site_distances_and_species)
- multiplicity = 0
- for distances_and_species in distances_and_species_list:
- if distances_and_species == sorted_site_distances_and_species:
- multiplicity += 1
-
+ multiplicity = sum(
+ 1
+ for distances_and_species in distances_and_species_list
+ if distances_and_species == sorted_site_distances_and_species
+ )
if multiplicity == 0:
warnings.warn(
f"Multiplicity of interstitial at site "
@@ -896,7 +828,11 @@ def _get_voronoi_multiplicity(site, structure):
def identify_defect(
- defect_structure, bulk_structure, defect_coords=None, defect_index=None
+ defect_structure,
+ bulk_structure,
+ defect_coords=None,
+ defect_index=None,
+ oxi_state=None,
) -> Defect:
"""
By comparing the defect and bulk structures, identify the defect present and its site in
@@ -914,6 +850,9 @@ def identify_defect(
Index of the defect site in the supercell. For vacancies, this
should be the site index in the bulk structure, while for substitutions
and interstitials it should be the site index in the defect structure.
+ oxi_state (:obj:`int`, :obj:`float`, :obj:`str`):
+ Oxidation state of the defect site. If not provided, will be
+ automatically determined from the defect structure.
Returns: :obj:`Defect`
"""
@@ -921,9 +860,7 @@ def identify_defect(
# doped if we wanted, but works fine as is.
# identify defect site, structural information, and create defect object:
try:
- defect_type, comp_diff = get_defect_type_and_composition_diff(
- bulk_structure, defect_structure
- )
+ defect_type, comp_diff = get_defect_type_and_composition_diff(bulk_structure, defect_structure)
except RuntimeError as exc:
raise ValueError(
"Could not identify defect type from number of sites in structure: "
@@ -931,26 +868,36 @@ def identify_defect(
) from exc
# remove oxidation states before site-matching
- defect_struc = (
- defect_structure.copy()
- ) # copy to prevent overwriting original structures
+ defect_struc = defect_structure.copy() # copy to prevent overwriting original structures
bulk_struc = bulk_structure.copy()
defect_struc.remove_oxidation_states()
+
+ _bulk_oxi_states = False
+ if oxi_state is None:
+ if all(hasattr(site.specie, "oxi_state") for site in bulk_struc.sites) and all(
+ isinstance(site.specie.oxi_state, (int, float)) for site in bulk_struc.sites
+ ):
+ _bulk_oxi_states = {el.symbol: el.oxi_state for el in bulk_struc.composition.elements}
+ else: # try guessing bulk oxi states now, before Defect initialisation:
+ if _rough_oxi_state_cost_from_comp(bulk_struc.composition) < 1e6:
+ # otherwise will take very long to guess oxi_state
+ queue = Queue()
+ _bulk_oxi_states = _guess_and_set_oxi_states_with_timeout(bulk_struc, queue=queue)
+ if _bulk_oxi_states:
+ bulk_struc = queue.get() # oxi-state decorated structure
+ _bulk_oxi_states = {el.symbol: el.oxi_state for el in bulk_struc.composition.elements}
+
bulk_struc.remove_oxidation_states()
bulk_site_index = None
defect_site_index = None
- if (
- defect_type == "vacancy" and defect_index
- ): # defect_index should correspond to bulk struc
+ if defect_type == "vacancy" and defect_index: # defect_index should correspond to bulk struc
bulk_site_index = defect_index
elif defect_index: # defect_index should correspond to defect struc
if defect_type == "interstitial":
defect_site_index = defect_index
- if (
- defect_type == "substitution"
- ): # also want bulk site index for substitutions,
+ if defect_type == "substitution": # also want bulk site index for substitutions,
# so use defect index coordinates
defect_coords = defect_struc[defect_index].frac_coords
@@ -1012,10 +959,7 @@ def _remove_matching_sites(bulk_site_list, defect_site_list):
max_possible_defect_sites_in_defect_struc,
)
- if (
- len(non_matching_bulk_sites) == 0
- and len(non_matching_defect_sites) == 0
- ):
+ if len(non_matching_bulk_sites) == 0 and len(non_matching_defect_sites) == 0:
warnings.warn(
f"Coordinates {defect_coords} were specified for (auto-determined) "
f"{defect_type} defect, but there are no extra/missing/different species "
@@ -1122,10 +1066,7 @@ def _remove_matching_sites(bulk_site_list, defect_site_list):
if (
defect_site_index is None
and bulk_site_index is None
- and (
- auto_matching_defect_site_index is not None
- or auto_matching_bulk_site_index is not None
- )
+ and (auto_matching_defect_site_index is not None or auto_matching_bulk_site_index is not None)
):
# user didn't specify coordinates or index, but auto site-matching found a defect site
if auto_matching_bulk_site_index is not None:
@@ -1146,13 +1087,10 @@ def _remove_matching_sites(bulk_site_list, defect_site_list):
defect_site = defect_struc[defect_site_index]
if (defect_index is not None or defect_coords is not None) and (
- auto_matching_defect_site_index is not None
- or auto_matching_bulk_site_index is not None
+ auto_matching_defect_site_index is not None or auto_matching_bulk_site_index is not None
):
# user specified site, check if it matched the auto site-matching
- user_index = (
- defect_site_index if defect_site_index is not None else bulk_site_index
- )
+ user_index = defect_site_index if defect_site_index is not None else bulk_site_index
auto_index = (
auto_matching_defect_site_index
if auto_matching_defect_site_index is not None
@@ -1187,11 +1125,15 @@ def _site_info(site):
f"Will use user-specified site: {_site_info(defect_site)}."
)
+ if _bulk_oxi_states:
+ bulk_structure.add_oxidation_state_by_element(_bulk_oxi_states)
+
for_monty_defect = {
"@module": "doped.core",
"@class": defect_type.capitalize(),
"structure": bulk_structure,
"site": defect_site,
+ "oxi_state": oxi_state if _bulk_oxi_states else "Undetermined",
}
try:
defect = MontyDecoder().process_decoded(for_monty_defect)
@@ -1205,14 +1147,14 @@ def _site_info(site):
f"You have the version {v_ana_def} of the package `pymatgen-analysis-defects`,"
" which is incompatible. Please update this package (with `pip install "
"shakenbreak`) and try again."
- )
+ ) from exc
if v_pmg < "2022.7.25":
raise TypeError(
f"You have the version {v_pmg} of the package `pymatgen`, which is incompatible. "
f"Please update this package (with `pip install shakenbreak`) and try again."
- )
- else:
- raise exc
+ ) from exc
+
+ raise exc
return defect
@@ -1243,9 +1185,7 @@ def generate_defect_object(
print(f"Creating defect object for {single_defect_dict['name']}")
defect_type = single_defect_dict["defect_type"]
if defect_type == "antisite":
- defect_type = (
- "substitution" # antisites are represented with Substitution class
- )
+ defect_type = "substitution" # antisites are represented with Substitution class
# Get bulk structure
bulk_structure = bulk_dict["supercell"]["structure"]
# Get defect site
@@ -1269,14 +1209,14 @@ def generate_defect_object(
f"You have the version {v_ana_def} of the package `pymatgen-analysis-defects`,"
" which is incompatible. Please update this package (with `pip install "
"shakenbreak`) and try again."
- )
+ ) from exc
if v_pmg < "2022.7.25":
raise TypeError(
f"You have the version {v_pmg} of the package `pymatgen`, which is incompatible. "
f"Please update this package (with `pip install shakenbreak`) and try again."
- )
- else:
- raise exc
+ ) from exc
+
+ raise exc
# Specify defect charge states
if isinstance(charges, list): # Priority to charges argument
@@ -1302,8 +1242,7 @@ def _get_defects_dict_from_defects_entries(defect_entries):
defect_entry.name.rsplit("_", 1)[0]: [ # defect names without charge
def_entry
for def_entry in defect_entries
- if def_entry.name.rsplit("_", 1)[0]
- == defect_entry.name.rsplit("_", 1)[0]
+ if def_entry.name.rsplit("_", 1)[0] == defect_entry.name.rsplit("_", 1)[0]
]
for defect_entry in defect_entries
}
@@ -1326,22 +1265,17 @@ def _get_defects_dict_from_defects_entries(defect_entries):
defect_entry_list = []
for defect_entry in defect_entries:
if not any(
- sm.fit(
- defect_entry.defect.defect_structure, entry.defect.defect_structure
- )
+ sm.fit(defect_entry.defect.defect_structure, entry.defect.defect_structure)
for entry in defect_entry_list
):
# ensure sc_defect_frac_coords defined:
_find_sc_defect_coords(defect_entry)
defect_entry_list.append(defect_entry)
- defect_entries_dict = name_defect_entries(
- defect_entry_list
- ) # DefectsGenerator.defect_entries
+ defect_entries_dict = name_defect_entries(defect_entry_list) # DefectsGenerator.defect_entries
# format: {"defect_species": DefectEntry} -> convert:
snb_defects_dict = {
- defect_entry_name_wout_charge: []
- for defect_entry_name_wout_charge in defect_entries_dict
+ defect_entry_name_wout_charge: [] for defect_entry_name_wout_charge in defect_entries_dict
}
for name_wout_charge, defect_entry in defect_entries_dict.items():
@@ -1356,8 +1290,9 @@ def _get_defects_dict_from_defects_entries(defect_entries):
def _find_sc_defect_coords(defect_entry):
"""
Find defect fractional coordinates in defect supercell.
+
Targets cases where user generated DefectEntry manually and
- didn't set the `sc_defect_frac_coords` attribute
+ didn't set the `sc_defect_frac_coords` attribute.
Args:
defect_entry (DefectEntry): DefectEntry object
@@ -1446,6 +1381,9 @@ def _apply_rattle_bond_distortions(
distorted_atoms (:obj:`list`, optional):
List of the atomic indices which should undergo bond distortions.
(Default: None)
+ oxidation_states (:obj:`dict`):
+ Dictionary with oxidation states of the atoms in the material (e.g.
+ {"Cd": +2, "Te": -2}).
verbose (:obj:`bool`):
Whether to print distortion information.
(Default: False)
@@ -1526,20 +1464,14 @@ def _apply_rattle_bond_distortions(
)
# Apply rattle to the bond distorted structure
if active_atoms is None:
- distorted_atom_indices = [
- i[0] for i in bond_distorted_defect["distorted_atoms"]
- ] + [
- bond_distorted_defect.get(
- "defect_site_index"
- ) # only adds defect site if not vacancy
+ distorted_atom_indices = [i[0] for i in bond_distorted_defect["distorted_atoms"]] + [
+ bond_distorted_defect.get("defect_site_index") # only adds defect site if not vacancy
] # Note this is VASP indexing here
distorted_atom_indices = [
i - 1 for i in distorted_atom_indices if i is not None
] # remove 'None' if defect is vacancy, and convert to python indexing
rattling_atom_indices = np.arange(0, len(defect_structure))
- idx = np.in1d(
- rattling_atom_indices, distorted_atom_indices
- ) # returns True for matching indices
+ idx = np.in1d(rattling_atom_indices, distorted_atom_indices) # returns True for matching indices
active_atoms = rattling_atom_indices[~idx] # remove matching indices
if local_rattle:
@@ -1614,6 +1546,9 @@ def apply_snb_distortions(
List of the atomic indices which should undergo bond distortions.
If None, the closest neighbours to the defect will be chosen.
(Default: None)
+ oxidation_states (:obj:`dict`):
+ Dictionary with oxidation states of the atoms in the material (e.g.
+ {"Cd": +2, "Te": -2}).
verbose (:obj:`bool`):
Whether to print distortion information.
(Default: False)
@@ -1657,17 +1592,13 @@ def apply_snb_distortions(
seed = mc_rattle_kwargs.pop("seed", None)
if num_nearest_neighbours != 0:
- for distortion in bond_distortions:
- if isinstance(distortion, float):
- distortion = (
- round(distortion, ndigits=3) + 0
- ) # ensure positive zero (not "-0.0%")
+ for raw_distortion in bond_distortions:
+ if isinstance(raw_distortion, float):
+ distortion = round(raw_distortion, ndigits=3) + 0 # ensure positive zero (not "-0.0%")
if verbose:
print(f"--Distortion {distortion:.1%}")
distortion_factor = 1 + distortion
- if (
- not seed
- ): # by default, set seed equal to distortion factor * 100 (e.g. 0.5 -> 50)
+ if not seed: # by default, set seed equal to distortion factor * 100 (e.g. 0.5 -> 50)
# to avoid cases where a particular supercell rattle gets stuck in a local minimum
seed = int(distortion_factor * 100)
@@ -1685,27 +1616,25 @@ def apply_snb_distortions(
oxidation_states=oxidation_states,
**mc_rattle_kwargs,
)
- distorted_defect_dict["distortions"][
- analysis._get_distortion_filename(distortion)
- ] = bond_distorted_defect["distorted_structure"]
+ distorted_defect_dict["distortions"][analysis._get_distortion_filename(distortion)] = (
+ bond_distorted_defect["distorted_structure"]
+ )
distorted_defect_dict["distortion_parameters"] = {
"unique_site": bulk_supercell_site.frac_coords,
"num_distorted_neighbours": num_nearest_neighbours,
"distorted_atoms": bond_distorted_defect["distorted_atoms"],
}
- if bond_distorted_defect.get(
- "defect_site_index"
- ): # only add site index if not vacancy
- distorted_defect_dict["distortion_parameters"][
- "defect_site_index"
- ] = bond_distorted_defect["defect_site_index"]
+ if bond_distorted_defect.get("defect_site_index"): # only add site index if not vacancy
+ distorted_defect_dict["distortion_parameters"]["defect_site_index"] = (
+ bond_distorted_defect["defect_site_index"]
+ )
- elif isinstance(distortion, str) and distortion.lower() == "dimer":
+ elif isinstance(raw_distortion, str) and raw_distortion.lower() == "dimer":
# Apply dimer distortion, with rattling
bond_distorted_defect = _apply_rattle_bond_distortions(
defect_entry=defect_entry,
num_nearest_neighbours=2,
- distortion_factor=distortion,
+ distortion_factor=raw_distortion,
local_rattle=local_rattle,
stdev=stdev,
d_min=d_min,
@@ -1723,15 +1652,11 @@ def apply_snb_distortions(
{
"unique_site": bulk_supercell_site.frac_coords,
"num_distorted_neighbours_in_dimer": 2, # Dimer distortion only affects 2 atoms
- "distorted_atoms_in_dimer": bond_distorted_defect[
- "distorted_atoms"
- ],
+ "distorted_atoms_in_dimer": bond_distorted_defect["distorted_atoms"],
}
)
if defect_site_index: # only add site index if not vacancy
- distorted_defect_dict["distortion_parameters"][
- "defect_site_index"
- ] = defect_site_index
+ distorted_defect_dict["distortion_parameters"]["defect_site_index"] = defect_site_index
else: # when no extra/missing electrons, just rattle the structure
# Likely to be a shallow defect.
@@ -1742,9 +1667,7 @@ def apply_snb_distortions(
frac_coords = None # only for vacancies!
defect_site_index = defect_object.defect_site_index
- if (
- not seed
- ): # by default, set seed equal to distortion factor * 100 (e.g. 0.5 -> 50)
+ if not seed: # by default, set seed equal to distortion factor * 100 (e.g. 0.5 -> 50)
# to avoid cases where a particular supercell rattle gets stuck in a local minimum
seed = 100 # distortion_factor = 1 when no bond distortion, just rattling
@@ -1773,9 +1696,7 @@ def apply_snb_distortions(
"distorted_atoms": None,
}
if defect_site_index: # only add site index if vacancy
- distorted_defect_dict["distortion_parameters"][
- "defect_site_index"
- ] = defect_site_index
+ distorted_defect_dict["distortion_parameters"]["defect_site_index"] = defect_site_index
if "Dimer" in bond_distortions:
# Apply dimer distortion, without rattling
@@ -1784,16 +1705,12 @@ def apply_snb_distortions(
site_index=defect_site_index,
frac_coords=frac_coords,
)
- distorted_defect_dict["distortions"]["Dimer"] = bond_distorted_defect[
- "distorted_structure"
- ]
+ distorted_defect_dict["distortions"]["Dimer"] = bond_distorted_defect["distorted_structure"]
distorted_defect_dict["distortion_parameters"].update(
{
"unique_site": bulk_supercell_site.frac_coords,
"num_distorted_neighbours_in_dimer": 2, # Dimer distortion only affects 2 atoms
- "distorted_atoms_in_dimer": bond_distorted_defect[
- "distorted_atoms"
- ],
+ "distorted_atoms_in_dimer": bond_distorted_defect["distorted_atoms"],
}
)
return distorted_defect_dict
@@ -1925,10 +1842,7 @@ def __init__(
]
# To account for this, here we refactor the list into a dict
if isinstance(defect_entries, list):
- if not all(
- isinstance(defect, (DefectEntry, thermo.DefectEntry))
- for defect in defect_entries
- ):
+ if not all(isinstance(defect, (DefectEntry, thermo.DefectEntry)) for defect in defect_entries):
raise TypeError(
"Some entries in `defect_entries` list are not DefectEntry objects (required "
"format, see docstring). Distortions can also be initialised from "
@@ -1951,9 +1865,7 @@ def __init__(
]
): # doped/PyCDT defect dict
# Check bulk entry in doped/PyCDT defect_dict
- if (
- "bulk" not in defect_entries
- ): # No bulk entry - ask user to provide it
+ if "bulk" not in defect_entries: # No bulk entry - ask user to provide it
raise ValueError(
"Input `defect_entries` dict matches `doped`/`PyCDT` format, but no 'bulk' "
"entry present. Please try again providing a `bulk` entry in `defect_entries`."
@@ -1971,9 +1883,7 @@ def __init__(
)
# Generate a DefectEntry for each charge state
self.defects_dict[defect_dict["name"]] = [
- _get_defect_entry_from_defect(
- defect=defect, charge_state=charge
- )
+ _get_defect_entry_from_defect(defect=defect, charge_state=charge)
for charge in defect_dict["charges"]
]
@@ -1987,8 +1897,7 @@ def __init__(
name.rsplit("_", 1)[0]: [ # name without charge
defect_entry
for defect_entry in defect_entries.values()
- if defect_entry.name.rsplit("_", 1)[0]
- == name.rsplit("_", 1)[0]
+ if defect_entry.name.rsplit("_", 1)[0] == name.rsplit("_", 1)[0]
]
for name in defect_entries
}
@@ -2006,9 +1915,7 @@ def __init__(
"Structures using `Distortions.from_structures()`"
)
- self.defects_dict = (
- defect_entries # {"defect name": [DefectEntry, ...]}
- )
+ self.defects_dict = defect_entries # {"defect name": [DefectEntry, ...]}
elif isinstance(defect_entries, DefectsGenerator):
self.defects_dict = {
@@ -2028,7 +1935,7 @@ def __init__(
f"instead. See `Distortions()` docstring!"
)
- list_of_defect_entries = list(self.defects_dict.values())[0]
+ list_of_defect_entries = next(iter(self.defects_dict.values()))
defect_object = list_of_defect_entries[0].defect
bulk_comp = defect_object.structure.composition
if "stdev" in mc_rattle_kwargs:
@@ -2065,16 +1972,13 @@ def __init__(
def guess_oxidation_states(bulk_comp):
for max_sites in (-1, None):
try:
- guessed_oxidation_states = bulk_comp.oxi_state_guesses(
- max_sites=max_sites
- )[0]
+ guessed_oxidation_states = bulk_comp.oxi_state_guesses(max_sites=max_sites)[0]
if guessed_oxidation_states:
return guessed_oxidation_states
except IndexError:
continue
- else:
- # pmg oxi state guessing can fail for single-element systems, intermetallics etc
- return {elt.symbol: 0 for elt in bulk_comp.elements}
+ # pmg oxi state guessing can fail for single-element systems, intermetallics etc
+ return {elt.symbol: 0 for elt in bulk_comp.elements}
guessed_oxidation_states = guess_oxidation_states(bulk_comp)
@@ -2117,23 +2021,14 @@ def guess_oxidation_states(bulk_comp):
self.bond_distortions.append("Dimer")
bond_distortions.remove("Dimer")
- self.bond_distortions.extend(
- list(np.around(bond_distortions, 3))
- ) # round to 3 decimal places
+ self.bond_distortions.extend(list(np.around(bond_distortions, 3))) # round to 3 decimal places
else:
# If the user does not specify bond_distortions, use
# distortion_increment:
self.distortion_increment = distortion_increment
self.bond_distortions = list(
- np.flip(
- np.around(
- np.arange(0, 0.601, self.distortion_increment), decimals=3
- )
- )
- * -1
- )[:-1] + list(
- np.around(np.arange(0, 0.601, self.distortion_increment), decimals=3)
- )
+ np.flip(np.around(np.arange(0, 0.601, self.distortion_increment), decimals=3)) * -1
+ )[:-1] + list(np.around(np.arange(0, 0.601, self.distortion_increment), decimals=3))
self._mc_rattle_kwargs = mc_rattle_kwargs
@@ -2234,20 +2129,13 @@ def _parse_number_electrons(
if dict_number_electrons_user:
number_electrons = dict_number_electrons_user[defect_name]
else:
- number_electrons = _calc_number_electrons(
- defect_entry, defect_name, oxidation_states
- )
+ number_electrons = _calc_number_electrons(defect_entry, defect_name, oxidation_states)
_bold_print(f"\nDefect: {defect_name}")
if number_electrons < 0:
- _bold_print(
- "Number of extra electrons in neutral state: "
- + f"{abs(number_electrons)}"
- )
+ _bold_print(f"Number of extra electrons in neutral state: {abs(number_electrons)}")
else:
- _bold_print(
- f"Number of missing electrons in neutral state: {number_electrons}"
- )
+ _bold_print(f"Number of missing electrons in neutral state: {number_electrons}")
return number_electrons
def _get_number_distorted_neighbours(
@@ -2278,9 +2166,7 @@ def _print_distortion_info(
stdev: float,
) -> None:
"""Print applied bond distortions and rattle standard deviation."""
- rounded_distortions = [
- f"{round(i,3)+0}" if isinstance(i, float) else i for i in bond_distortions
- ]
+ rounded_distortions = [f"{round(i,3)+0}" if isinstance(i, float) else i for i in bond_distortions]
print(
"Applying ShakeNBreak...",
"Will apply the following bond distortions:",
@@ -2336,13 +2222,9 @@ def _generate_structure_comment(
the CONTCAR file.
"""
frac_coords = self.distortion_metadata["defects"][defect_name]["unique_site"]
- approx_coords = (
- f"~[{frac_coords[0]:.1f},{frac_coords[1]:.1f},{frac_coords[2]:.1f}]"
- )
+ approx_coords = f"~[{frac_coords[0]:.1f},{frac_coords[1]:.1f},{frac_coords[2]:.1f}]"
return (
- str(
- key_distortion.split("_")[-1]
- ) # Get distortion factor (-60.%) or 'Rattled'
+ str(key_distortion.split("_")[-1]) # Get distortion factor (-60.%) or 'Rattled'
+ " N(Distort)="
+ str(
self.distortion_metadata["defects"][defect_name]["charges"][charge][
@@ -2394,9 +2276,7 @@ def _setup_distorted_defect_dict(
} # General info about (neutral) defect
if defect_type == "substitution": # substitutions and antisites
sub_site_in_bulk = defect.defect_site
- distorted_defect_dict[
- "substitution_specie"
- ] = sub_site_in_bulk.specie.symbol
+ distorted_defect_dict["substitution_specie"] = sub_site_in_bulk.specie.symbol
distorted_defect_dict["substitution_specie"] = defect.site.specie.symbol
return distorted_defect_dict
@@ -2437,23 +2317,17 @@ def write_distortion_metadata(
"""
if defect is not None:
distortion_metadata = {
- "distortion_parameters": {
- **self.distortion_metadata["distortion_parameters"]
- },
+ "distortion_parameters": {**self.distortion_metadata["distortion_parameters"]},
"defects": {defect: self.distortion_metadata["defects"][defect]},
}
else:
distortion_metadata = self.distortion_metadata
if charge is not None:
- distortion_metadata = copy.deepcopy(
- distortion_metadata
- ) # don't overwrite original
+ distortion_metadata = copy.deepcopy(distortion_metadata) # don't overwrite original
for defect_name in list(distortion_metadata["defects"].keys()):
distortion_metadata["defects"][defect_name]["charges"] = {
- charge: distortion_metadata["defects"][defect_name]["charges"][
- charge
- ]
+ charge: distortion_metadata["defects"][defect_name]["charges"][charge]
}
_write_distortion_metadata(
@@ -2492,9 +2366,7 @@ def apply_distortions(
}
and dictionary with distortion parameters for each defect.
"""
- self._print_distortion_info(
- bond_distortions=self.bond_distortions, stdev=self.stdev
- )
+ self._print_distortion_info(bond_distortions=self.bond_distortions, stdev=self.stdev)
distorted_defects_dict = {} # Store distorted & undistorted structures
@@ -2560,9 +2432,7 @@ def apply_distortions(
# Remove distortions with inter-atomic distances less than 1 Angstrom if Hydrogen
# not present
- for dist, struct in list(
- defect_distorted_structures["distortions"].items()
- ):
+ for dist, struct in list(defect_distorted_structures["distortions"].items()):
sorted_distances = np.sort(struct.distance_matrix.flatten())
shortest_interatomic_distance = sorted_distances[len(struct)]
if shortest_interatomic_distance < 1.0 and all(
@@ -2579,21 +2449,17 @@ def apply_distortions(
# Add distorted structures to dictionary
distorted_defects_dict[defect_name]["charges"][charge]["structures"] = {
- "Unperturbed": defect_distorted_structures[
- "Unperturbed"
- ].sc_entry.structure,
- "distortions": dict(
- defect_distorted_structures["distortions"].items()
- ),
+ "Unperturbed": defect_distorted_structures["Unperturbed"].sc_entry.structure,
+ "distortions": dict(defect_distorted_structures["distortions"].items()),
}
# Store distortion parameters/info in self.distortion_metadata
- defect_site_index = defect_distorted_structures[
- "distortion_parameters"
- ].get("defect_site_index")
- distorted_atoms = defect_distorted_structures[
- "distortion_parameters"
- ].get("distorted_atoms", None)
+ defect_site_index = defect_distorted_structures["distortion_parameters"].get(
+ "defect_site_index"
+ )
+ distorted_atoms = defect_distorted_structures["distortion_parameters"].get(
+ "distorted_atoms", None
+ )
self.distortion_metadata = self._update_distortion_metadata(
distortion_metadata=self.distortion_metadata,
defect_name=defect_name,
@@ -2605,6 +2471,55 @@ def apply_distortions(
return distorted_defects_dict, self.distortion_metadata
+ def _prepare_distorted_defect_inputs(
+ self,
+ distorted_defects_dict,
+ output_path,
+ include_charge_state=False,
+ ):
+ """
+ Loop through the distorted defect species in ``distorted_defects_dict``,
+ determine their folder names, create the folders and return the corresponding
+ structures and folder names; for usage in file-writing functions.
+
+ Args:
+ distorted_defects_dict (:obj:`dict`):
+ Dictionary with the distorted and undistorted structures
+ for each charge state of each defect.
+ output_path (:obj:`str`):
+ Path to directory where the folders will be written.
+ include_charge_state (:obj:`bool`):
+ If ``True``, also includes the charge states
+ in the dictionary values. (Default: False)
+
+ Returns:
+ :obj:`dict`:
+ Dictionary with the folder paths as keys and the corresponding
+ structures as values, or the structures and charge state if
+ ``include_charge_state`` is ``True``.
+ """
+ struct_folder_dict = {}
+ # loop for each defect in dict
+ for defect_name, defect_dict in distorted_defects_dict.items():
+ # loop for each charge state of defect
+ for charge in defect_dict["charges"]:
+ for dist, struct in zip(
+ [
+ "Unperturbed",
+ *list(defect_dict["charges"][charge]["structures"]["distortions"].keys()),
+ ],
+ [
+ defect_dict["charges"][charge]["structures"]["Unperturbed"],
+ *list(defect_dict["charges"][charge]["structures"]["distortions"].values()),
+ ],
+ ):
+ sign = "+" if charge > 0 else ""
+ folder_path = f"{output_path}/{defect_name}_{sign}{charge}/{dist}"
+ _create_folder(folder_path)
+ struct_folder_dict[folder_path] = (struct, charge) if include_charge_state else struct
+
+ return struct_folder_dict
+
def write_vasp_files(
self,
user_incar_settings: Optional[dict] = None,
@@ -2673,18 +2588,12 @@ def write_vasp_files(
for key_distortion, struct in zip(
[
"Unperturbed",
- ]
- + list(
- defect_dict["charges"][charge_state]["structures"][
- "distortions"
- ].keys()
- ),
- [defect_dict["charges"][charge_state]["structures"]["Unperturbed"]]
- + list(
- defect_dict["charges"][charge_state]["structures"][
- "distortions"
- ].values()
- ),
+ *list(defect_dict["charges"][charge_state]["structures"]["distortions"].keys()),
+ ],
+ [
+ defect_dict["charges"][charge_state]["structures"]["Unperturbed"],
+ *list(defect_dict["charges"][charge_state]["structures"]["distortions"].values()),
+ ],
):
poscar_comment = self._generate_structure_comment(
defect_name=defect_name,
@@ -2698,9 +2607,7 @@ def write_vasp_files(
"Charge State": charge_state,
}
- defect_species = (
- f"{defect_name}_{'+' if charge_state > 0 else ''}{charge_state}"
- )
+ defect_species = f"{defect_name}_{'+' if charge_state > 0 else ''}{charge_state}"
_create_vasp_input(
defect_name=defect_species,
distorted_defect_dict=charged_defect_dict,
@@ -2770,18 +2677,14 @@ def write_espresso_files(
# Update default parameters with user defined values
if pseudopotentials and not write_structures_only:
- default_input_parameters = loadfn(
- os.path.join(MODULE_DIR, "../SnB_input_files/qe_input.yaml")
- )
+ default_input_parameters = loadfn(os.path.join(MODULE_DIR, "../SnB_input_files/qe_input.yaml"))
if input_file and not input_parameters:
input_parameters = io.parse_qe_input(input_file)
if input_parameters:
for section in input_parameters:
for key in input_parameters[section]:
if section in default_input_parameters:
- default_input_parameters[section][key] = input_parameters[
- section
- ][key]
+ default_input_parameters[section][key] = input_parameters[section][key]
else:
default_input_parameters.update(
{section: {key: input_parameters[section][key]}}
@@ -2790,60 +2693,41 @@ def write_espresso_files(
aaa = AseAtomsAdaptor()
# loop for each defect in dict
- for defect_name, defect_dict in distorted_defects_dict.items():
- for charge in defect_dict["charges"]: # loop for each charge state
- for dist, struct in zip(
- [
- "Unperturbed",
- ]
- + list(
- defect_dict["charges"][charge]["structures"][
- "distortions"
- ].keys()
- ),
- [defect_dict["charges"][charge]["structures"]["Unperturbed"]]
- + list(
- defect_dict["charges"][charge]["structures"][
- "distortions"
- ].values()
- ),
- ):
- atoms = aaa.get_atoms(struct)
- _create_folder(
- f"{output_path}/{defect_name}_{'+' if charge > 0 else ''}{charge}/{dist}"
- )
-
- if not pseudopotentials or write_structures_only:
- # only write structures
- warnings.warn(
- "Since `pseudopotentials` have not been specified, "
- "will only write input structures."
- )
- ase.io.write(
- filename=f"{output_path}/"
- + f"{defect_name}_{'+' if charge > 0 else ''}{charge}/{dist}/espresso.pwi",
- images=atoms,
- format="espresso-in",
- )
- else:
- # write complete input file
- default_input_parameters["SYSTEM"][
- "tot_charge"
- ] = charge # Update defect charge
-
- calc = Espresso(
- pseudopotentials=pseudopotentials,
- tstress=False,
- tprnfor=True,
- kpts=(1, 1, 1),
- input_data=default_input_parameters,
- )
- calc.write_input(atoms)
- os.replace(
- "./espresso.pwi",
- f"{output_path}/"
- + f"{defect_name}_{'+' if charge > 0 else ''}{charge}/{dist}/espresso.pwi",
- )
+ for folder_path, (
+ struct,
+ charge,
+ ) in self._prepare_distorted_defect_inputs(
+ distorted_defects_dict, output_path, include_charge_state=True
+ ).items():
+ atoms = aaa.get_atoms(struct)
+
+ if not pseudopotentials or write_structures_only:
+ # only write structures
+ warnings.warn(
+ "Since `pseudopotentials` have not been specified, "
+ "will only write input structures."
+ )
+ ase.io.write(
+ filename=f"{folder_path}/espresso.pwi",
+ images=atoms,
+ format="espresso-in",
+ )
+ else:
+ # write complete input file
+ default_input_parameters["SYSTEM"]["tot_charge"] = charge # Update defect charge
+
+ calc = Espresso(
+ pseudopotentials=pseudopotentials,
+ tstress=False,
+ tprnfor=True,
+ kpts=(1, 1, 1),
+ input_data=default_input_parameters,
+ )
+ calc.write_input(atoms)
+ os.replace(
+ "./espresso.pwi",
+ f"{folder_path}/espresso.pwi",
+ )
return distorted_defects_dict, self.distortion_metadata
def write_cp2k_files(
@@ -2881,59 +2765,32 @@ def write_cp2k_files(
if os.path.exists(input_file) and not write_structures_only:
cp2k_input = Cp2kInput.from_file(input_file)
elif (
- os.path.exists(f"{MODULE_DIR}/../SnB_input_files/cp2k_input.inp")
- and not write_structures_only
+ os.path.exists(f"{MODULE_DIR}/../SnB_input_files/cp2k_input.inp") and not write_structures_only
):
warnings.warn(
f"Specified input file {input_file} does not exist! Using"
" default CP2K input file "
"(see shakenbreak/shakenbreak/cp2k_input.inp)"
)
- cp2k_input = Cp2kInput.from_file(
- f"{MODULE_DIR}/../SnB_input_files/cp2k_input.inp"
- )
+ cp2k_input = Cp2kInput.from_file(f"{MODULE_DIR}/../SnB_input_files/cp2k_input.inp")
distorted_defects_dict, self.distortion_metadata = self.apply_distortions(
verbose=verbose,
)
# loop for each defect in dict
- for defect_name, defect_dict in distorted_defects_dict.items():
- # loop for each charge state of defect
- for charge in defect_dict["charges"]:
- if not write_structures_only and cp2k_input:
- cp2k_input.update({"FORCE_EVAL": {"DFT": {"CHARGE": charge}}})
-
- for dist, struct in zip(
- [
- "Unperturbed",
- ]
- + list(
- defect_dict["charges"][charge]["structures"][
- "distortions"
- ].keys()
- ),
- [defect_dict["charges"][charge]["structures"]["Unperturbed"]]
- + list(
- defect_dict["charges"][charge]["structures"][
- "distortions"
- ].values()
- ),
- ):
- _create_folder(
- f"{output_path}/{defect_name}_{'+' if charge > 0 else ''}{charge}/{dist}"
- )
- struct.to(
- fmt="cif",
- filename=f"{output_path}/{defect_name}_{'+' if charge > 0 else ''}{charge}/"
- + f"{dist}/structure.cif",
- )
- if not write_structures_only and cp2k_input:
- cp2k_input.write_file(
- input_filename="cp2k_input.inp",
- output_dir=f"{output_path}/{defect_name}_{'+' if charge > 0 else ''}{charge}/"
- + f"{dist}",
- )
+ for folder_path, struct in self._prepare_distorted_defect_inputs(
+ distorted_defects_dict, output_path
+ ).items():
+ struct.to(
+ fmt="cif",
+ filename=f"{folder_path}/structure.cif",
+ )
+ if not write_structures_only and cp2k_input:
+ cp2k_input.write_file(
+ input_filename="cp2k_input.inp",
+ output_dir=f"{folder_path}",
+ )
return distorted_defects_dict, self.distortion_metadata
@@ -2974,69 +2831,42 @@ def write_castep_files(
verbose=verbose,
)
aaa = AseAtomsAdaptor()
- warnings.filterwarnings(
- "ignore", ".*Could not determine the version of your CASTEP binary.*"
- )
- warnings.filterwarnings(
- "ignore", ".*Generating CASTEP keywords JSON file... hang on.*"
- )
+ warnings.filterwarnings("ignore", ".*Could not determine the version of your CASTEP binary.*")
+ warnings.filterwarnings("ignore", ".*Generating CASTEP keywords JSON file... hang on.*")
# loop for each defect in dict
- for defect_name, defect_dict in distorted_defects_dict.items():
- # loop for each charge state of defect
- for charge in defect_dict["charges"]:
- for dist, struct in zip(
- [
- "Unperturbed",
- ]
- + list(
- defect_dict["charges"][charge]["structures"][
- "distortions"
- ].keys()
- ),
- [defect_dict["charges"][charge]["structures"]["Unperturbed"]]
- + list(
- defect_dict["charges"][charge]["structures"][
- "distortions"
- ].values()
- ),
- ):
- atoms = aaa.get_atoms(struct)
- _create_folder(
- f"{output_path}/{defect_name}_{'+' if charge > 0 else ''}{charge}/{dist}"
+ for folder_path, (
+ struct,
+ charge,
+ ) in self._prepare_distorted_defect_inputs(
+ distorted_defects_dict, output_path, include_charge_state=True
+ ).items():
+ atoms = aaa.get_atoms(struct)
+
+ if write_structures_only:
+ ase.io.write(
+ filename=f"{folder_path}/castep.cell",
+ images=atoms,
+ format="castep-cell",
+ )
+ else:
+ try:
+ calc = Castep(directory=f"{folder_path}")
+ calc.set_kpts({"size": (1, 1, 1), "gamma": True})
+ calc.merge_param(input_file)
+ calc.param.charge = charge # Defect charge state
+ calc.set_atoms(atoms)
+ calc.initialize() # this writes the .param file
+ except Exception:
+ warnings.warn(
+ "Problem setting up the CASTEP `.param` file. "
+ "Only structures will be written "
+ "as `castep.cell` files."
+ )
+ ase.io.write(
+ filename=(f"{folder_path}/castep.cell"),
+ images=atoms,
+ format="castep-cell",
)
-
- if write_structures_only:
- ase.io.write(
- filename=f"{output_path}/{defect_name}_{'+' if charge > 0 else ''}{charge}/"
- + f"{dist}/castep.cell",
- images=atoms,
- format="castep-cell",
- )
- else:
- try:
- calc = Castep(
- directory=f"{output_path}/"
- + f"{defect_name}_{'+' if charge > 0 else ''}{charge}/{dist}"
- )
- calc.set_kpts({"size": (1, 1, 1), "gamma": True})
- calc.merge_param(input_file)
- calc.param.charge = charge # Defect charge state
- calc.set_atoms(atoms)
- calc.initialize() # this writes the .param file
- except Exception:
- warnings.warn(
- "Problem setting up the CASTEP `.param` file. "
- "Only structures will be written "
- "as `castep.cell` files."
- )
- ase.io.write(
- filename=(
- f"{output_path}/{defect_name}_{'+' if charge > 0 else ''}{charge}/"
- f"{dist}/castep.cell"
- ),
- images=atoms,
- format="castep-cell",
- )
return distorted_defects_dict, self.distortion_metadata
def write_fhi_aims_files(
@@ -3104,57 +2934,39 @@ def write_fhi_aims_files(
# By default symmetry is not preserved
)
# loop for each defect in dict
- for defect_name, defect_dict in distorted_defects_dict.items():
- # loop for each charge state of defect
- for charge in defect_dict["charges"]:
- if isinstance(ase_calculator, Aims) and not write_structures_only:
- ase_calculator.set(charge=charge) # Defect charge state
-
- # Total number of electrons for net spin initialization
- # Must set initial spin moments (otherwise FHI-aims will
- # lead to 0 final spin)
- struct = defect_dict["charges"][charge]["structures"]["Unperturbed"]
- if struct.composition.total_electrons % 2 == 0:
- # Even number of electrons -> net spin is 0
- ase_calculator.set(default_initial_moment=0)
- else:
- ase_calculator.set(default_initial_moment=1)
+ for folder_path, (
+ struct,
+ charge,
+ ) in self._prepare_distorted_defect_inputs(
+ distorted_defects_dict, output_path, include_charge_state=True
+ ).items():
+ if isinstance(ase_calculator, Aims) and not write_structures_only:
+ ase_calculator.set(charge=charge) # Defect charge state
+
+ # Total number of electrons for net spin initialization
+ # Must set initial spin moments (otherwise FHI-aims will
+ # lead to 0 final spin)
+ if struct.composition.total_electrons % 2 == 0:
+ # Even number of electrons -> net spin is 0
+ ase_calculator.set(default_initial_moment=0)
+ else:
+ ase_calculator.set(default_initial_moment=1)
- for dist, struct in zip(
- [
- "Unperturbed",
- ]
- + list(
- defect_dict["charges"][charge]["structures"][
- "distortions"
- ].keys()
- ),
- [defect_dict["charges"][charge]["structures"]["Unperturbed"]]
- + list(
- defect_dict["charges"][charge]["structures"][
- "distortions"
- ].values()
- ),
- ):
- atoms = aaa.get_atoms(struct)
- _create_folder(
- f"{output_path}/{defect_name}_{'+' if charge > 0 else ''}{charge}/{dist}"
- )
+ atoms = aaa.get_atoms(struct)
+ dist = folder_path.split("/")[-1]
- ase.io.write(
- filename=f"{output_path}/{defect_name}_{'+' if charge > 0 else ''}{charge}"
- + f"/{dist}/geometry.in",
- images=atoms,
- format="aims",
- info_str=dist,
- ) # write input structure file
+ ase.io.write(
+ filename=f"{folder_path}/geometry.in",
+ images=atoms,
+ format="aims",
+ info_str=dist,
+ ) # write input structure file
- if isinstance(ase_calculator, Aims) and not write_structures_only:
- ase_calculator.write_control(
- filename=f"{output_path}/{defect_name}_{'+' if charge > 0 else ''}{charge}"
- + f"/{dist}/control.in",
- atoms=atoms,
- ) # write parameters file
+ if isinstance(ase_calculator, Aims) and not write_structures_only:
+ ase_calculator.write_control(
+ filename=f"{folder_path}/control.in",
+ atoms=atoms,
+ ) # write parameters file
return distorted_defects_dict, self.distortion_metadata
@@ -3186,7 +2998,7 @@ def from_structures(
[(defect Structure, frac_coords/index), ...] to aid site-matching.
Defect charge states (from which bond distortions are determined) are
- set to the range: 0 – {Defect oxidation state}, with a `padding`
+ set to the range: 0 - {Defect oxidation state}, with a `padding`
(default = 1) on either side of this range.
bulk (:obj:`pymatgen.core.structure.Structure`):
Bulk supercell structure, matching defect supercells.
@@ -3198,7 +3010,7 @@ def from_structures(
common oxidation states of any extrinsic species.
padding (:obj:`int`):
Defect charge states are set to the range:
- 0 – {Defect oxidation state}, with a `padding` (default = 1)
+ 0 - {Defect oxidation state}, with a `padding` (default = 1)
on either side of this range.
dict_number_electrons_user (:obj:`dict`):
Optional argument to set the number of extra/missing charge
@@ -3277,56 +3089,42 @@ def from_structures(
)
if not padding:
print(
- "Defect charge states will be set to the range: 0 – {Defect "
- "oxidation state}, with a `padding = 1` on either side of this "
- "range."
+ "Defect charge states will be set to the range: 0 - {Defect oxidation state}, "
+ "with a `padding = 1` on either side of this range."
)
else:
print(
- "Defect charge states will be set to the range: 0 – {Defect "
- "oxidation state}, "
- + f"with a `padding = {padding}` on either side of this "
- "range."
+ "Defect charge states will be set to the range: 0 - {Defect oxidation state}, "
+ "with a `padding = {padding}` on either side of this range."
)
for defect_structure in defects:
if isinstance(defect_structure, Structure):
- defect = identify_defect(
+ if defect := identify_defect(
defect_structure=defect_structure,
bulk_structure=bulk,
- )
- if defect:
+ ):
# Generate a defect entry for each charge state
defect.user_charges = defect.get_charge_states(padding=padding)
for charge in defect.user_charges:
defect_entries.append(
- _get_defect_entry_from_defect(
- defect=defect, charge_state=charge
- )
+ _get_defect_entry_from_defect(defect=defect, charge_state=charge)
)
# Check if user gives dict with structure and defect_coords/defect_index
- elif isinstance(defect_structure, tuple) or isinstance(
- defect_structure, list
- ):
+ elif isinstance(defect_structure, (tuple, list)):
if len(defect_structure) != 2:
raise ValueError(
"If an entry in `defect_entries` is a tuple/list, it must be in the "
"format: (defect Structure, frac_coords/index)"
)
- elif isinstance(defect_structure[1], int) or isinstance(
- defect_structure[1], float
- ): # defect index
+ if isinstance(defect_structure[1], (int, float)): # defect index
defect = identify_defect(
defect_structure=defect_structure[0],
bulk_structure=bulk,
defect_index=int(defect_structure[1]),
)
- elif (
- isinstance(defect_structure[1], list)
- or isinstance(defect_structure[1], tuple)
- or isinstance(defect_structure[1], np.ndarray)
- ):
+ elif isinstance(defect_structure[1], (list, tuple, np.ndarray)):
defect = identify_defect(
defect_structure=defect_structure[0],
bulk_structure=bulk,
@@ -3340,18 +3138,14 @@ def from_structures(
f" {type(defect_structure[1])} instead. "
f"Will proceed with auto-site matching."
)
- defect = identify_defect(
- defect_structure=defect_structure[0], bulk_structure=bulk
- )
+ defect = identify_defect(defect_structure=defect_structure[0], bulk_structure=bulk)
if defect:
defect.user_charges = defect.get_charge_states(padding=padding)
# Generate a defect entry for each charge state:
for charge in defect.user_charges:
defect_entries.append(
- _get_defect_entry_from_defect(
- defect=defect, charge_state=charge
- )
+ _get_defect_entry_from_defect(defect=defect, charge_state=charge)
)
else:
diff --git a/shakenbreak/io.py b/shakenbreak/io.py
index ec311b4..263fd09 100644
--- a/shakenbreak/io.py
+++ b/shakenbreak/io.py
@@ -7,7 +7,7 @@
import datetime
import os
import warnings
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union
import ase
from ase.atoms import Atoms
@@ -17,10 +17,6 @@
from pymatgen.core.units import Energy
from pymatgen.io.ase import AseAtomsAdaptor
-if TYPE_CHECKING:
- import pymatgen.core.periodic_table
- import pymatgen.core.structure
-
from shakenbreak import analysis
aaa = AseAtomsAdaptor()
@@ -31,6 +27,7 @@ def parse_energies(
path: Optional[str] = ".",
code: Optional[str] = "vasp",
filename: Optional[str] = "OUTCAR",
+ verbose: bool = False,
) -> None:
"""
Parse final energy for all distortions present in the given defect
@@ -54,6 +51,9 @@ def parse_energies(
(i.e. vasp: "OUTCAR", cp2k: "relax.out", espresso: "espresso.out",
castep: "*.castep", fhi-aims: "aims.out")
Default to the ShakeNBreak default filenames.
+ verbose (:obj:`bool`):
+ If True, print information about renamed/saved-over files.
+ Defaults to False.
Returns: energies_file path.
"""
@@ -71,7 +71,7 @@ def _match(filename, grep_string):
return None
def sort_energies(defect_energies_dict):
- """Order dict items by key (e.g. from -0.6 to 0 to +0.6)"""
+ """Order dict items by key (e.g. from -0.6 to 0 to +0.6)."""
# sort distortions
sorted_energies_dict = {
"distortions": dict(
@@ -89,7 +89,7 @@ def sort_energies(defect_energies_dict):
sorted_energies_dict["Dimer"] = defect_energies_dict["Dimer"]
return sorted_energies_dict
- def save_file(energies, defect, path):
+ def save_file(energies, defect, path, verbose=False):
"""Save yaml file with final energies for each distortion."""
# File to write energies to
filename = f"{path}/{defect}/{defect}.yaml"
@@ -98,11 +98,12 @@ def save_file(energies, defect, path):
old_file = loadfn(filename)
if old_file != energies:
current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
- print(
- f"Moving old {filename} to "
- f"{filename.replace('.yaml', '')}_{current_datetime}.yaml "
- "to avoid overwriting"
- )
+ if verbose:
+ print(
+ f"Moving old {filename} to "
+ f"{filename.replace('.yaml', '')}_{current_datetime}.yaml "
+ "to avoid overwriting"
+ )
os.rename(
filename, f"{filename.replace('.yaml', '')}_{current_datetime}.yaml"
) # Keep copy of old file
@@ -117,23 +118,17 @@ def parse_vasp_energy(defect_dir, dist, energy, outcar):
outcar = os.path.join(defect_dir, dist, "OUTCAR")
if outcar: # regrep faster than using Outcar/vasprun class
with contextlib.suppress(IndexError):
- energy = _match(
- outcar, r"entropy=.*energy\(sigma->0\)\s+=\s+([\d\-\.]+)"
- )[0][0][
+ energy = _match(outcar, r"entropy=.*energy\(sigma->0\)\s+=\s+([\d\-\.]+)")[0][0][
0
] # Energy of first match
- converged = _match(
- outcar, "required accuracy"
- ) # check if ionic relaxation converged
+ converged = _match(outcar, "required accuracy") # check if ionic relaxation converged
if not converged:
converged = _match(outcar, "considering this converged")
return converged, energy, outcar
def parse_espresso_energy(defect_dir, dist, energy, espresso_out):
"""Parse Quantum Espresso energy from espresso.out file."""
- if os.path.join(
- defect_dir, dist, "espresso.out"
- ): # Default SnB output filename
+ if os.path.join(defect_dir, dist, "espresso.out"): # Default SnB output filename
espresso_out = os.path.join(defect_dir, dist, "espresso.out")
elif os.path.exists(os.path.join(defect_dir, dist, filename)):
espresso_out = os.path.join(defect_dir, dist, filename)
@@ -152,13 +147,9 @@ def parse_cp2k_energy(defect_dir, dist, energy, cp2k_out):
elif os.path.exists(os.path.join(defect_dir, dist, filename)):
cp2k_out = os.path.join(defect_dir, dist, filename)
if cp2k_out:
- converged = _match(
- cp2k_out, "GEOMETRY OPTIMIZATION COMPLETED"
- ) # check if ionic
+ converged = _match(cp2k_out, "GEOMETRY OPTIMIZATION COMPLETED") # check if ionic
# relaxation is converged
- energy_in_Ha = _match(
- cp2k_out, r"Total energy:\s+([\d\-\.]+)"
- ) # Energy of first
+ energy_in_Ha = _match(cp2k_out, r"Total energy:\s+([\d\-\.]+)") # Energy of first
# match in Hartree
energy = float(Energy(energy_in_Ha[0][0][0], "Ha").to("eV"))
return converged, energy, cp2k_out
@@ -166,9 +157,7 @@ def parse_cp2k_energy(defect_dir, dist, energy, cp2k_out):
def parse_castep_energy(defect_dir, dist, energy, castep_out):
"""Parse CASTEP energy from .castep file."""
converged = False
- output_files = [
- file for file in os.listdir(f"{defect_dir}/{dist}") if ".castep" in file
- ]
+ output_files = [file for file in os.listdir(f"{defect_dir}/{dist}") if ".castep" in file]
if output_files and os.path.exists(f"{defect_dir}/{dist}/{output_files[0]}"):
castep_out = f"{defect_dir}/{dist}/{output_files[0]}"
elif os.path.exists(os.path.join(defect_dir, dist, filename)):
@@ -179,9 +168,7 @@ def parse_castep_energy(defect_dir, dist, energy, castep_out):
# https://www.tcm.phy.cam.ac.uk/castep/Geom_Opt/node20.html
# and https://gitlab.mpcdf.mpg.de/nomad-lab/parser-castep/-/
# blob/master/test/examples/TiO2-geom.castep
- converged = _match(
- castep_out, "Geometry optimization completed successfully."
- )
+ converged = _match(castep_out, "Geometry optimization completed successfully.")
energy = _match(castep_out, r"Final Total Energy\s+([\d\-\.]+)")[0][0][
0
] # Energy of first match in eV
@@ -195,9 +182,7 @@ def parse_fhi_aims_energy(defect_dir, dist, energy, aims_out):
elif os.path.exists(os.path.join(defect_dir, dist, filename)):
aims_out = os.path.join(defect_dir, dist, filename)
if aims_out:
- converged = _match(
- aims_out, "converged."
- ) # check if ionic relaxation is converged
+ converged = _match(aims_out, "converged.") # check if ionic relaxation is converged
# Convergence string deduced from:
# https://fhi-aims-club.gitlab.io/tutorials/basics-of-running-fhi-aims/3-Periodic-Systems/
# and https://gitlab.com/fhi-aims-club/tutorials/basics-of-running-fhi-aims/-/
@@ -210,14 +195,10 @@ def parse_fhi_aims_energy(defect_dir, dist, energy, aims_out):
energy = energy[0][0][0] # Energy of first match in eV
return converged, energy, aims_out
- energies = {
- "distortions": {}
- } # maps each distortion to the energy of the optimised structure
+ energies = {"distortions": {}} # maps each distortion to the energy of the optimised structure
if defect == os.path.basename(os.path.normpath(path)) and not [
- dir
- for dir in path
- if (os.path.isdir(dir) and os.path.basename(os.path.normpath(dir)) == defect)
+ dir for dir in path if (os.path.isdir(dir) and os.path.basename(os.path.normpath(dir)) == defect)
]: # if `defect` is in end of `path` and `path` doesn't have a subdirectory called `defect`
# then remove defect from end of path
path = os.path.dirname(path)
@@ -225,9 +206,7 @@ def parse_fhi_aims_energy(defect_dir, dist, energy, aims_out):
defect_dir = f"{path}/{defect}"
if not os.path.isdir(defect_dir):
orig_defect_name = defect
- defect = defect.replace(
- "+", ""
- ) # try removing '+' from defect name, old format
+ defect = defect.replace("+", "") # try removing '+' from defect name, old format
defect_dir = f"{path}/{defect}" # try removing '+' from defect name, old format
if not os.path.isdir(defect_dir):
@@ -235,7 +214,7 @@ def parse_fhi_aims_energy(defect_dir, dist, energy, aims_out):
f"Defect folder '{orig_defect_name}' not found in '{path}'. Please check these folders "
f"and paths."
)
- return
+ return None
dist_dirs = [
dir
@@ -269,30 +248,20 @@ def parse_fhi_aims_energy(defect_dir, dist, energy, aims_out):
energy = None
converged = False
if code.lower() == "vasp":
- converged, energy, outcar = parse_vasp_energy(
- defect_dir, dist, energy, outcar
- )
+ converged, energy, outcar = parse_vasp_energy(defect_dir, dist, energy, outcar)
elif code.lower() in [
"espresso",
"quantum_espresso",
"quantum-espresso",
"quantumespresso",
]:
- converged, energy, outcar = parse_espresso_energy(
- defect_dir, dist, energy, outcar
- )
+ converged, energy, outcar = parse_espresso_energy(defect_dir, dist, energy, outcar)
elif code.lower() == "cp2k":
- converged, energy, outcar = parse_cp2k_energy(
- defect_dir, dist, energy, outcar
- )
+ converged, energy, outcar = parse_cp2k_energy(defect_dir, dist, energy, outcar)
elif code.lower() == "castep":
- converged, energy, outcar = parse_castep_energy(
- defect_dir, dist, energy, outcar
- )
+ converged, energy, outcar = parse_castep_energy(defect_dir, dist, energy, outcar)
elif code.lower() in ["fhi-aims", "fhi_aims", "fhiaims"]:
- converged, energy, outcar = parse_fhi_aims_energy(
- defect_dir, dist, energy, outcar
- )
+ converged, energy, outcar = parse_fhi_aims_energy(defect_dir, dist, energy, outcar)
if analysis._format_distortion_names(dist) != "Label_not_recognized":
dist_name = analysis._format_distortion_names(dist)
@@ -312,20 +281,14 @@ def parse_fhi_aims_energy(defect_dir, dist, energy, aims_out):
# check if energy not found, but was previously parsed, then add to dict
if dist_name in prev_energies_dict:
energies[dist_name] = prev_energies_dict[dist_name]
- elif (
- "distortions" in prev_energies_dict
- and dist_name in prev_energies_dict["distortions"]
- ):
- energies["distortions"][dist_name] = prev_energies_dict["distortions"][
- dist_name
- ]
+ elif "distortions" in prev_energies_dict and dist_name in prev_energies_dict["distortions"]:
+ energies["distortions"][dist_name] = prev_energies_dict["distortions"][dist_name]
else:
warnings.warn(f"No output file in {dist} directory")
if energies["distortions"]:
if "Unperturbed" in energies and all(
- value - energies["Unperturbed"] > 0.1
- for value in energies["distortions"].values()
+ value - energies["Unperturbed"] > 0.1 for value in energies["distortions"].values()
):
warnings.warn(
f"All distortions parsed for {defect} are >0.1 eV higher energy than unperturbed, "
@@ -334,17 +297,14 @@ def parse_fhi_aims_energy(defect_dir, dist, energy, aims_out):
f"often this is the result of an unreasonable charge state). If both checks pass, "
f"you likely need to adjust the `stdev` rattling parameter (can occur for "
f"hard/ionic/magnetic materials); see "
- f"https://shakenbreak.readthedocs.io/en/latest/Tips.html#hard-ionic-materials. – This "
- f"often indicates a complex PES with multiple minima, thus energy-lowering distortions "
- f"particularly likely, so important to test with reduced `stdev`!"
+ f"https://shakenbreak.readthedocs.io/en/latest/Tips.html#hard-ionic-materials\n"
+ f"This often indicates a complex PES with multiple minima, thus energy-lowering "
+ f"distortions particularly likely, so important to test with reduced `stdev`!"
)
elif (
"Unperturbed" in energies
- and all(
- value - energies["Unperturbed"] < -0.1
- for value in energies["distortions"].values()
- )
+ and all(value - energies["Unperturbed"] < -0.1 for value in energies["distortions"].values())
and len(energies["distortions"]) > 2
):
warnings.warn(
@@ -391,10 +351,6 @@ def parse_fhi_aims_energy(defect_dir, dist, energy, aims_out):
f"are correct, check calculations have converged, and that distortion subfolders match "
f"ShakeNBreak naming (e.g. Bond_Distortion_xxx, Rattled, Unperturbed)"
)
- if "Unperturbed" in energies:
- # TODO: remove next two lines
- energies = sort_energies(energies)
- save_file(energies, defect, path)
# if any entries in prev_energies_dict not in energies_dict, add to energies_dict and
# warn user about output files not being parsed for this entry
@@ -417,7 +373,7 @@ def parse_fhi_aims_energy(defect_dir, dist, energy, aims_out):
energies = sort_energies(energies)
if energies and energies != {"distortions": {}}:
- save_file(energies, defect, path)
+ save_file(energies, defect, path, verbose=verbose)
return energies_file
@@ -474,23 +430,18 @@ def read_espresso_structure(
"""
# ase.io.espresso functions seem a bit buggy, so we use the following implementation
if os.path.exists(filename):
- with open(filename, "r", encoding="utf-8") as f:
+ with open(filename, encoding="utf-8") as f:
file_content = f.read()
else:
warnings.warn(
- f"{filename} file doesn't exist, storing as 'Not converged'. "
- f"Check path & relaxation"
+ f"{filename} file doesn't exist, storing as 'Not converged'. Check path & relaxation"
)
structure = "Not converged"
try:
if "Begin final coordinates" in file_content:
- file_content = file_content.split("Begin final coordinates")[
- -1
- ] # last geometry
+ file_content = file_content.split("Begin final coordinates")[-1] # last geometry
if "End final coordinates" in file_content:
- file_content = file_content.split("End final coordinates")[
- 0
- ] # last geometry
+ file_content = file_content.split("End final coordinates")[0] # last geometry
# Parse cell parameters and atomic positions
cell_lines = [
line
@@ -502,29 +453,17 @@ def read_espresso_structure(
atomic_positions = file_content.split("ATOMIC_POSITIONS (angstrom)")[1]
# Cell parameters
cell_lines_processed = [
- [float(number) for number in line.split()]
- for line in cell_lines
- if len(line.split()) == 3
+ [float(number) for number in line.split()] for line in cell_lines if len(line.split()) == 3
]
# Atomic positions
atomic_positions_processed = [
- [entry for entry in line.split()]
- for line in atomic_positions.split("\n")
- if len(line.split()) >= 4
- ]
- coordinates = [
- [float(entry) for entry in line[1:4]] for line in atomic_positions_processed
- ]
- symbols = [
- entry[0]
- for entry in atomic_positions_processed
- if entry != "" and entry != " " and entry != " "
+ line.split() for line in atomic_positions.split("\n") if len(line.split()) >= 4
]
+ coordinates = [[float(entry) for entry in line[1:4]] for line in atomic_positions_processed]
+ symbols = [entry[0] for entry in atomic_positions_processed if entry not in ["", " ", " "]]
# Check parsing is ok
for entry in coordinates:
- assert (
- len(entry) == 3
- ) # Encure 3 numbers (xyz) are parsed from coordinates section
+ assert len(entry) == 3 # Encure 3 numbers (xyz) are parsed from coordinates section
assert len(symbols) == len(coordinates) # Same number of atoms and coordinates
atoms = Atoms(
symbols=symbols,
@@ -559,21 +498,20 @@ def read_fhi_aims_structure(filename: str, format="aims") -> Union[Structure, st
:obj:`Structure`:
`pymatgen` Structure object
"""
- if os.path.exists(filename):
- try:
- aaa = AseAtomsAdaptor()
- atoms = ase.io.read(filename=filename, format=format)
- structure = aaa.get_structure(atoms)
- structure = structure.get_sorted_structure() # Sort sites by
- # electronegativity
- except Exception:
- warnings.warn(
- f"Problem parsing structure from: {filename}, storing as 'Not "
- f"converged'. Check file & relaxation"
- )
- structure = "Not converged"
- else:
+ if not os.path.exists(filename):
raise FileNotFoundError(f"File {filename} does not exist!")
+ try:
+ aaa = AseAtomsAdaptor()
+ atoms = ase.io.read(filename=filename, format=format)
+ structure = aaa.get_structure(atoms)
+ structure = structure.get_sorted_structure() # Sort sites by
+ # electronegativity
+ except Exception:
+ warnings.warn(
+ f"Problem parsing structure from: {filename}, storing as 'Not "
+ f"converged'. Check file & relaxation"
+ )
+ structure = "Not converged"
return structure
@@ -592,24 +530,23 @@ def read_cp2k_structure(
:obj:`Structure`:
`pymatgen` Structure object
"""
- if os.path.exists(filename):
- try:
- aaa = AseAtomsAdaptor()
- atoms = ase.io.read(
- filename=filename,
- format="cp2k-restart",
- )
- structure = aaa.get_structure(atoms)
- structure = structure.get_sorted_structure() # Sort sites by
- # electronegativity
- except Exception:
- warnings.warn(
- f"Problem parsing structure from: {filename}, storing as 'Not "
- f"converged'. Check file & relaxation"
- )
- structure = "Not converged"
- else:
+ if not os.path.exists(filename):
raise FileNotFoundError(f"File {filename} does not exist!")
+ try:
+ aaa = AseAtomsAdaptor()
+ atoms = ase.io.read(
+ filename=filename,
+ format="cp2k-restart",
+ )
+ structure = aaa.get_structure(atoms)
+ structure = structure.get_sorted_structure() # Sort sites by
+ # electronegativity
+ except Exception:
+ warnings.warn(
+ f"Problem parsing structure from: {filename}, storing as 'Not "
+ f"converged'. Check file & relaxation"
+ )
+ structure = "Not converged"
return structure
@@ -628,24 +565,23 @@ def read_castep_structure(
:obj:`Structure`:
`pymatgen` Structure object
"""
- if os.path.exists(filename):
- try:
- aaa = AseAtomsAdaptor()
- atoms = ase.io.read(
- filename=filename,
- format="castep-castep",
- )
- structure = aaa.get_structure(atoms)
- structure = structure.get_sorted_structure() # Sort sites by
- # electronegativity
- except Exception:
- warnings.warn(
- f"Problem parsing structure from: {filename}, storing as 'Not "
- f"converged'. Check file & relaxation"
- )
- structure = "Not converged"
- else:
+ if not os.path.exists(filename):
raise FileNotFoundError(f"File {filename} does not exist!")
+ try:
+ aaa = AseAtomsAdaptor()
+ atoms = ase.io.read(
+ filename=filename,
+ format="castep-castep",
+ )
+ structure = aaa.get_structure(atoms)
+ structure = structure.get_sorted_structure() # Sort sites by
+ # electronegativity
+ except Exception:
+ warnings.warn(
+ f"Problem parsing structure from: {filename}, storing as 'Not "
+ f"converged'. Check file & relaxation"
+ )
+ structure = "Not converged"
return structure
@@ -734,15 +670,15 @@ def parse_qe_input(path: str) -> dict:
"ATOMIC_FORCES",
"SOLVENTS",
]
- with open(path, "r", encoding="utf-8") as f:
+ with open(path, encoding="utf-8") as f:
lines = f.readlines()
params = {}
- for line in lines:
- line = line.strip().partition("#")[0] # ignore in-line comments
- if line.startswith("&") or any([sec in line for sec in sections]):
+ for raw_line in lines:
+ line = raw_line.strip().partition("#")[0] # ignore in-line comments
+ if line.startswith("&") or any(sec in line for sec in sections):
section = line.split()[0].replace("&", "")
params[section] = {}
- elif line.startswith("#") or line.startswith("!") or line.startswith("/"):
+ elif line.startswith(("#", "!", "/")):
continue
elif "=" in line:
key, value = line.split("=")
@@ -758,9 +694,7 @@ def parse_qe_input(path: str) -> dict:
value.replace('"', "")
params[section][key.strip()] = value
elif len(line.split()) > 1:
- key, value = line.split()[0], " ".join(
- [str(val) for val in line.split()[1:]]
- )
+ key, value = line.split()[0], " ".join([str(val) for val in line.split()[1:]])
params[section][key.strip()] = value
# Remove structure info (if present), as will be re-written with distorted structures
for section in [
@@ -770,7 +704,7 @@ def parse_qe_input(path: str) -> dict:
"CELL_PARAMETERS",
]:
params.pop(section, None)
- if "SYSTEM" in params.keys():
+ if "SYSTEM" in params:
for key in ["celldm(1)", "nat", "ntyp", "ibrav"]:
params["SYSTEM"].pop(key, None)
return params
@@ -789,12 +723,12 @@ def parse_fhi_aims_input(path: str) -> dict:
"""
if not os.path.exists(path):
raise FileNotFoundError(f"File {path} does not exist!")
- with open(path, "r", encoding="utf-8") as f:
+ with open(path, encoding="utf-8") as f:
lines = f.readlines()
params = {}
- for line in lines:
- line = line.strip().partition("#")[0] # ignore in-line comments
- if line.startswith("#") or line.startswith("!") or line.startswith("/"):
+ for raw_line in lines:
+ line = raw_line.strip().partition("#")[0] # ignore in-line comments
+ if line.startswith(("#", "!", "/")):
continue
if len(line.split()) > 1:
if len(line.split()) > 2:
@@ -802,9 +736,7 @@ def parse_fhi_aims_input(path: str) -> dict:
# Convent numeric values to float
# (necessary when feeding into the ASE calculator)
for i in range(len(values)):
- with contextlib.suppress(
- Exception
- ): # Convent numeric values to float
+ with contextlib.suppress(Exception): # Convent numeric values to float
values[i] = float(values[i])
else:
key, values = line.split()[0], line.split()[1]
diff --git a/shakenbreak/plotting.py b/shakenbreak/plotting.py
index b59dfab..8226a27 100644
--- a/shakenbreak/plotting.py
+++ b/shakenbreak/plotting.py
@@ -38,9 +38,7 @@ def _install_custom_font():
try:
# Copy the font file to matplotlib's True Type font directory
fonts_dir = MODULE_DIR
- ttf_fonts = [
- file_name for file_name in os.listdir(fonts_dir) if ".ttf" in file_name
- ]
+ ttf_fonts = [file_name for file_name in os.listdir(fonts_dir) if ".ttf" in file_name]
try:
for font in ttf_fonts: # must be in ttf format for matplotlib
old_path = os.path.join(fonts_dir, font)
@@ -61,9 +59,7 @@ def _install_custom_font():
if os.path.exists(fontList_path):
os.remove(fontList_path)
print("Deleted the matplotlib fontList cache.")
- if not any(
- "fontlist" in file_name.lower() for file_name in mpl_cache_dir_ls
- ):
+ if not any("fontlist" in file_name.lower() for file_name in mpl_cache_dir_ls):
print("Couldn't find matplotlib cache, so will continue.")
# Add font to MAtplotlib Fontmanager
@@ -90,7 +86,7 @@ def _get_backend(save_format: str) -> Optional[str]:
except ImportError:
warnings.warn(
"pycairo not installed. Defaulting to matplotlib's pdf backend, so default "
- "ShakeNBreak fonts may not be used – try setting `save_format` to 'png' or "
+ "ShakeNBreak fonts may not be used -- try setting `save_format` to 'png' or "
"`pip install pycairo` if you want ShakeNBreak's default font."
)
return backend
@@ -104,15 +100,10 @@ def _verify_data_directories_exist(
"""Check top-level directory (e.g. `output_path`) and defect folders exist."""
# Check directories and input
if not os.path.isdir(output_path): # if output_path does not exist, raise error
+ raise FileNotFoundError(f"Path {output_path} does not exist! Skipping {defect_species}.")
+ if not os.path.isdir(f"{output_path}/{defect_species}"): # check if defect directory exists
raise FileNotFoundError(
- f"Path {output_path} does not exist! Skipping {defect_species}."
- )
- if not os.path.isdir(
- f"{output_path}/{defect_species}"
- ): # check if defect directory exists
- raise FileNotFoundError(
- f"Path {output_path}/{defect_species} does not exist! "
- f"Skipping {defect_species}."
+ f"Path {output_path}/{defect_species} does not exist! Skipping {defect_species}."
)
@@ -121,12 +112,10 @@ def _parse_distortion_metadata(distortion_metadata, defect, charge) -> tuple:
Parse the number and type of distorted nearest neighbours for a
given defect from the distortion_metadata dictionary.
"""
- if defect in distortion_metadata["defects"].keys():
+ if defect in distortion_metadata["defects"]:
try:
# Get number and element symbol of the distorted site(s)
- num_nearest_neighbours = distortion_metadata["defects"][defect]["charges"][
- str(charge)
- ][
+ num_nearest_neighbours = distortion_metadata["defects"][defect]["charges"][str(charge)][
"num_nearest_neighbours"
] # get number of distorted neighbours
except KeyError:
@@ -134,9 +123,7 @@ def _parse_distortion_metadata(distortion_metadata, defect, charge) -> tuple:
try:
neighbour_atoms = [ # get element of the distorted site
i[1] # element symbol
- for i in distortion_metadata["defects"][defect]["charges"][str(charge)][
- "distorted_atoms"
- ]
+ for i in distortion_metadata["defects"][defect]["charges"][str(charge)]["distorted_atoms"]
]
if all(element == neighbour_atoms[0] for element in neighbour_atoms):
@@ -176,19 +163,16 @@ def _cast_energies_to_floats(
values as floats.
"""
if not all(
- isinstance(energy, float)
- for energy in list(energies_dict["distortions"].values())
+ isinstance(energy, float) for energy in list(energies_dict["distortions"].values())
) or not isinstance(energies_dict["Unperturbed"], float):
# check energies_dict values are floats
try:
- energies_dict["distortions"] = {
- k: float(v) for k, v in energies_dict["distortions"].items()
- }
+ energies_dict["distortions"] = {k: float(v) for k, v in energies_dict["distortions"].items()}
energies_dict["Unperturbed"] = float(energies_dict["Unperturbed"])
- except ValueError:
+ except ValueError as exc:
raise ValueError(
f"Values of energies_dict are not floats! Skipping {defect_species}."
- )
+ ) from exc
return energies_dict
@@ -215,10 +199,8 @@ def _change_energy_units_to_meV(
if "meV" not in y_label:
y_label = y_label.replace("eV", "meV")
if max_energy_above_unperturbed < 4: # assume eV
- max_energy_above_unperturbed = (
- max_energy_above_unperturbed * 1000
- ) # convert to meV
- for key in energies_dict["distortions"].keys(): # convert to meV
+ max_energy_above_unperturbed = max_energy_above_unperturbed * 1000 # convert to meV
+ for key in energies_dict["distortions"]: # convert to meV
energies_dict["distortions"][key] = energies_dict["distortions"][key] * 1000
energies_dict["Unperturbed"] = energies_dict["Unperturbed"] * 1000
return energies_dict, max_energy_above_unperturbed, y_label
@@ -250,12 +232,12 @@ def _purge_data_dicts(
"""
for key in list(disp_dict.keys()):
if (
- (key not in energies_dict["distortions"].keys() and key != "Unperturbed")
+ (key not in energies_dict["distortions"] and key != "Unperturbed")
or disp_dict[key] == "Not converged"
or disp_dict[key] is None
) and key not in energies_dict[
"distortions"
- ].keys(): # remove it from disp and energy dicts
+ ]: # remove it from disp and energy dicts
disp_dict.pop(key)
return disp_dict, energies_dict
@@ -287,14 +269,9 @@ def _remove_high_energy_points(
"""
for key in list(energies_dict["distortions"].keys()):
# remove high energy points
- if (
- energies_dict["distortions"][key] - energies_dict["Unperturbed"]
- > max_energy_above_unperturbed
- ):
+ if energies_dict["distortions"][key] - energies_dict["Unperturbed"] > max_energy_above_unperturbed:
energies_dict["distortions"].pop(key)
- if (
- disp_dict and key in disp_dict
- ): # only exists if user selected `add_colorbar=True`
+ if disp_dict and key in disp_dict: # only exists if user selected `add_colorbar=True`
disp_dict.pop(key)
return energies_dict, disp_dict
@@ -353,9 +330,7 @@ def _get_displacement_dict(
disp_dict = analysis.calculate_struct_comparison(
defect_structs, metric=metric
) # calculate sum of atomic displacements and maximum displacement between paired sites
- if (
- disp_dict
- ): # if struct_comparison algorithms worked (sometimes struggles matching lattices)
+ if disp_dict: # if struct_comparison algorithms worked (sometimes struggles matching lattices)
disp_dict, energies_dict = _purge_data_dicts(
disp_dict=disp_dict,
energies_dict=energies_dict,
@@ -409,17 +384,12 @@ def _format_datapoints_from_other_chargestates(
# Reformat any "X%_from_Y" or "Rattled_from_Y" distortions to corresponding
# (X) distortion factor or 0.0 for "Rattled"
keys = []
- for entry in energies_dict["distortions"].keys():
+ for entry in energies_dict["distortions"]:
if isinstance(entry, str) and "%_from_" in entry:
keys.append(float(entry.split("%")[0]) / 100)
- elif isinstance(entry, str) and (
- "Rattled_from_" in entry or "Dimer_from_" in entry
- ):
+ elif isinstance(entry, str) and ("Rattled_from_" in entry or "Dimer_from_" in entry):
keys.append(0.0) # Rattled and Dimer will be plotted at x = 0.0
- elif entry == "Rattled": # add 0.0 for Rattled
- # (to avoid problems when sorting distortions)
- keys.append(0.0)
- elif entry == "Dimer": # add 0.0 for Dimer
+ elif entry == "Rattled" or entry == "Dimer": # add 0.0 for Rattled
# (to avoid problems when sorting distortions)
keys.append(0.0)
else:
@@ -428,11 +398,7 @@ def _format_datapoints_from_other_chargestates(
if disp_dict:
# Sort displacements in same order as distortions and energies,
# for proper color mapping
- sorted_disp = [
- disp_dict[k]
- for k in energies_dict["distortions"].keys()
- if k in disp_dict.keys()
- ]
+ sorted_disp = [disp_dict[k] for k in energies_dict["distortions"] if k in disp_dict]
# Save the values of the displacements from *other charge states*
# As the displacements will be re-sorted -> we'll need to
# find the index of t
@@ -511,9 +477,7 @@ def _save_plot(
plot_filepath = f"{os.path.join(defect_dir, defect_name)}.{save_format}"
# If plot already exists, rename to _.
if os.path.exists(plot_filepath):
- current_datetime = datetime.datetime.now().strftime(
- "%Y-%m-%d-%H-%M"
- ) # keep copy of old plot file
+ current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M") # keep copy of old plot file
os.rename(
plot_filepath,
f"{os.path.join(defect_dir, defect_name)}_{current_datetime}.{save_format}",
@@ -581,9 +545,7 @@ def _format_ticks(
loc = mpl.ticker.MultipleLocator(base=tick_interval)
# want the bottom tick to be no more thant (5%)*energy_range above the minimum energy to
# allow easy visual estimation of energy lowering and scale:
- if (
- min(energies_list) + 0.05 * energy_range
- ) // tick_interval == ylim_lower // tick_interval:
+ if (min(energies_list) + 0.05 * energy_range) // tick_interval == ylim_lower // tick_interval:
# means bottom tick is at or above the minimum energy
ylim_lower = min(energies_list) - tick_interval
@@ -633,10 +595,7 @@ def _format_axis(
f"{neighbour_atom} near {defect_name})"
)
elif num_nearest_neighbours and defect_name:
- x_label = (
- f"Bond Distortion Factor (for {num_nearest_neighbours} NN near"
- f" {defect_name})"
- )
+ x_label = f"Bond Distortion Factor (for {num_nearest_neighbours} NN near {defect_name})"
else:
x_label = "Bond Distortion Factor"
ax.set_xlabel(x_label)
@@ -692,9 +651,7 @@ def _setup_colormap(
array_disp = np.array(
[val for val in disp_dict.values() if isinstance(val, float)]
) # ignore "Not converged" or None values
- colormap = sns.cubehelix_palette(
- start=0.65, rot=-0.992075, dark=0.2755, light=0.7205, as_cmap=True
- )
+ colormap = sns.cubehelix_palette(start=0.65, rot=-0.992075, dark=0.2755, light=0.7205, as_cmap=True)
# colormap extremes, mapped to min and max displacements
vmin = round(min(array_disp), 1)
vmax = round(max(array_disp), 1)
@@ -706,11 +663,12 @@ def _setup_colormap(
def _format_colorbar(
fig: mpl.figure.Figure,
ax: mpl.axes.Axes,
- im: mpl.collections.PathCollection,
metric: str,
vmin: float,
vmax: float,
vmedium: float,
+ norm: mpl.colors.Normalize,
+ cmap: mpl.colors.Colormap,
) -> mpl.figure.Figure.colorbar:
"""
Format colorbar of plot.
@@ -720,7 +678,6 @@ def _format_colorbar(
matplotlib.figure.Figure object
ax (:obj:`mpl.axes.Axes`):
current matplotlib.axes.Axes object
- im (:obj:`mpl.collections.PathCollection`)
metric (:obj:`str`):
metric to be plotted: "disp" or "max_dist"
vmin (:obj:`float`):
@@ -729,12 +686,16 @@ def _format_colorbar(
tick label for the colorbar
vmedium (:obj:`float`):
tick label for the colorbar
+ norm (:obj:`mpl.colors.Normalize`):
+ normalization for the colorbar
+ cmap (:obj:`mpl.colors.Colormap`):
+ colormap for the colorbar
Returns:
cbar (:obj:`mpl.colorbar.Colorbar`)
"""
cbar = fig.colorbar(
- im,
+ mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
ax=ax,
boundaries=None,
drawedges=False,
@@ -749,9 +710,7 @@ def _format_colorbar(
cmap_label = r"$\Sigma$ Disp $(\AA)$"
elif metric == "max_dist":
cmap_label = r"$d_{max}$ $(\AA)$"
- cbar.ax.set_title(
- cmap_label, size="medium", loc="center", ha="center", va="center", pad=20.5
- )
+ cbar.ax.set_title(cmap_label, size="medium", loc="center", ha="center", va="center", pad=20.5)
if vmin != vmax:
cbar.set_ticks([vmin, vmedium, vmax])
cbar.set_ticklabels([vmin, vmedium, vmax])
@@ -763,7 +722,7 @@ def _format_colorbar(
# Main plotting functions:
def plot_all_defects(
- defects_dict: dict,
+ defect_charges_dict: dict,
output_path: str = ".",
add_colorbar: bool = False,
metric: str = "max_dist",
@@ -775,13 +734,14 @@ def plot_all_defects(
save_plot: bool = True,
save_format: str = "png",
verbose: bool = False,
+ close_figures: bool = False,
) -> dict:
"""
Convenience function to quickly analyse a range of defects and identify those
which undergo energy-lowering distortions.
Args:
- defects_dict (:obj:`dict`):
+ defect_charges_dict (:obj:`dict`):
Dictionary matching defect names to lists of their charge states.
(e.g {"Int_Sb_1": [0,+1,+2]} etc)
output_path (:obj:`str`):
@@ -824,6 +784,11 @@ def plot_all_defects(
(Default: 'png')
verbose (:obj:`bool`):
Whether to print information about the plots (warnings and where they're saved).
+ close_figures (:obj:`bool`):
+ Whether to close matplotlib figures after saving them, to reduce memory usage.
+ Recommended to use if plotting many defects at once, in which case figures will
+ be saved to disk and not displayed.
+ (Default: False)
Returns:
:obj:`dict`:
@@ -834,25 +799,19 @@ def plot_all_defects(
raise FileNotFoundError(f"Path {output_path} does not exist!")
try:
- distortion_metadata = analysis._read_distortion_metadata(
- output_path=output_path
- )
+ distortion_metadata = analysis._read_distortion_metadata(output_path=output_path)
except FileNotFoundError:
# check if any defect_species folders have distortion_metadata.json files
defect_species_list = [
f"{defect}_{format_charge}"
- for defect in defects_dict
- for charge in defects_dict[defect]
+ for defect in defect_charges_dict
+ for charge in defect_charges_dict[defect]
for format_charge in [charge, f"+{charge}"]
] # allow for defect species names with "+" sign (in SnB > 3.1)
distortion_metadata_list = [
- analysis._read_distortion_metadata(
- output_path=os.path.join(output_path, defect_species)
- )
+ analysis._read_distortion_metadata(output_path=os.path.join(output_path, defect_species))
for defect_species in defect_species_list
- if os.path.isfile(
- os.path.join(output_path, defect_species, "distortion_metadata.json")
- )
+ if os.path.isfile(os.path.join(output_path, defect_species, "distortion_metadata.json"))
]
if distortion_metadata_list:
distortion_metadata = distortion_metadata_list
@@ -866,8 +825,8 @@ def plot_all_defects(
)
distortion_metadata = None
- figures = {}
- for defect, value in defects_dict.items():
+ defects_to_plot = {}
+ for defect, value in defect_charges_dict.items():
for charge in value:
defect_species = f"{defect}_{'+' if charge > 0 else ''}{charge}"
# Parse energies
@@ -886,9 +845,7 @@ def plot_all_defects(
f"Path {energies_file} does not exist. Skipping {defect_species}."
) # skip defect
continue
- energies_dict, energy_diff, gs_distortion = analysis._sort_data(
- energies_file, verbose=False
- )
+ energies_dict, energy_diff, gs_distortion = analysis._sort_data(energies_file, verbose=False)
if not energy_diff: # if Unperturbed calc is not converged, warn user
warnings.warn(
@@ -913,19 +870,13 @@ def plot_all_defects(
(
num_nearest_neighbours,
neighbour_atom,
- ) = _parse_distortion_metadata(
- single_distortion_metadata, defect, charge
- )
- if (
- num_nearest_neighbours is None
- ): # try pull from one of distortion_metadata_list
+ ) = _parse_distortion_metadata(single_distortion_metadata, defect, charge)
+ if num_nearest_neighbours is None: # try pull from one of distortion_metadata_list
for distortion_metadata_dict in distortion_metadata:
(
num_nearest_neighbours,
neighbour_atom,
- ) = _parse_distortion_metadata(
- distortion_metadata_dict, defect, charge
- )
+ ) = _parse_distortion_metadata(distortion_metadata_dict, defect, charge)
if num_nearest_neighbours:
break
@@ -937,24 +888,32 @@ def plot_all_defects(
num_nearest_neighbours = None
neighbour_atom = None
- figures[defect_species] = plot_defect(
- defect_species=defect_species,
- energies_dict=energies_dict,
- output_path=output_path,
- neighbour_atom=neighbour_atom,
- num_nearest_neighbours=num_nearest_neighbours,
- add_colorbar=add_colorbar,
- metric=metric,
- units=units,
- max_energy_above_unperturbed=max_energy_above_unperturbed,
- line_color=line_color,
- add_title=add_title,
- save_plot=save_plot,
- save_format=save_format,
- verbose=verbose,
- )
+ defects_to_plot[defect_species] = {
+ "energies_dict": energies_dict,
+ "num_nearest_neighbours": num_nearest_neighbours,
+ "neighbour_atom": neighbour_atom,
+ }
- return figures
+ return {
+ defect_species: plot_defect(
+ defect_species=defect_species,
+ energies_dict=info_dict.get("energies_dict"),
+ output_path=output_path,
+ neighbour_atom=info_dict.get("neighbour_atom"),
+ num_nearest_neighbours=info_dict.get("num_nearest_neighbours"),
+ add_colorbar=add_colorbar,
+ metric=metric,
+ units=units,
+ max_energy_above_unperturbed=max_energy_above_unperturbed,
+ line_color=line_color,
+ add_title=add_title,
+ save_plot=save_plot,
+ save_format=save_format,
+ verbose=verbose,
+ close_figure=close_figures,
+ )
+ for defect_species, info_dict in defects_to_plot.items()
+ }
def plot_defect(
@@ -974,6 +933,7 @@ def plot_defect(
save_plot: Optional[bool] = True,
save_format: Optional[str] = "png",
verbose: bool = False,
+ close_figure: bool = False,
) -> Optional[Figure]:
"""
Convenience function to plot energy vs distortion for a defect, to identify
@@ -1037,6 +997,11 @@ def plot_defect(
(Default: "png")
verbose (:obj:`bool`):
Whether to print information about the plot (warnings and where it's saved).
+ close_figure (:obj:`bool`):
+ Whether to close matplotlib figure after saving, to reduce memory usage.
+ Recommended to use if plotting many defects at once, in which case figure will
+ be saved to disk and not displayed.
+ (Default: False)
Returns:
:obj:`mpl.figure.Figure`:
@@ -1044,9 +1009,7 @@ def plot_defect(
"""
# Ensure necessary directories exist, and raise error if not
try:
- _verify_data_directories_exist(
- output_path=output_path, defect_species=defect_species
- )
+ _verify_data_directories_exist(output_path=output_path, defect_species=defect_species)
except FileNotFoundError:
if add_colorbar:
warnings.warn(
@@ -1055,11 +1018,10 @@ def plot_defect(
)
add_colorbar = False
- if "Unperturbed" not in energies_dict.keys():
+ if "Unperturbed" not in energies_dict:
# check if unperturbed energies exist
warnings.warn(
- f"Unperturbed energy not present in energies_dict of {defect_species}! "
- f"Skipping plot."
+ f"Unperturbed energy not present in energies_dict of {defect_species}! Skipping plot."
)
return None
@@ -1067,9 +1029,7 @@ def plot_defect(
if not neighbour_atom and not num_nearest_neighbours:
try:
try:
- distortion_metadata = analysis._read_distortion_metadata(
- output_path=output_path
- )
+ distortion_metadata = analysis._read_distortion_metadata(output_path=output_path)
except FileNotFoundError:
distortion_metadata = analysis._read_distortion_metadata(
output_path=f"{output_path}/{defect_species}" # if user moved file
@@ -1176,6 +1136,9 @@ def plot_defect(
save_format=save_format,
verbose=verbose,
)
+ if close_figure:
+ plt.close(fig)
+
return fig
@@ -1184,7 +1147,7 @@ def plot_colorbar(
disp_dict: dict,
defect_species: str,
include_site_info_in_name: Optional[bool] = False,
- num_nearest_neighbours: int = None,
+ num_nearest_neighbours: Optional[int] = None,
neighbour_atom: str = "NN",
title: Optional[str] = None,
legend_label: str = "SnB",
@@ -1315,43 +1278,35 @@ def plot_colorbar(
sorted_distortions,
sorted_energies,
sorted_disp,
- ) = _format_datapoints_from_other_chargestates(
- energies_dict=energies_dict, disp_dict=disp_dict
- )
+ ) = _format_datapoints_from_other_chargestates(energies_dict=energies_dict, disp_dict=disp_dict)
# Plotting
line = None # to later check if line was plotted, for legend formatting
with plt.style.context(f"{MODULE_DIR}/shakenbreak.mplstyle"):
- if (
- "Rattled" in energies_dict["distortions"].keys()
- and "Rattled" in disp_dict.keys()
- ):
+ if "Rattled" in energies_dict["distortions"] and "Rattled" in disp_dict:
# Plot Rattled energy
- im = ax.scatter(
+ path_col = ax.scatter(
0.0,
energies_dict["distortions"]["Rattled"],
- c=disp_dict["Rattled"],
+ c=(disp_dict["Rattled"] if isinstance(disp_dict["Rattled"], float) else "k"),
label="Rattled",
s=50,
marker="o",
- cmap=colormap,
- norm=norm,
+ cmap=colormap if isinstance(disp_dict["Rattled"], float) else None,
+ norm=norm if isinstance(disp_dict["Rattled"], float) else None,
alpha=1,
)
# Plot Dimer
- if (
- "Dimer" in energies_dict["distortions"].keys()
- and "Dimer" in disp_dict.keys()
- ):
- im = ax.scatter(
+ if "Dimer" in energies_dict["distortions"] and "Dimer" in disp_dict:
+ path_col = ax.scatter(
0.0,
energies_dict["distortions"]["Dimer"],
- c=disp_dict["Dimer"],
+ c=disp_dict["Dimer"] if isinstance(disp_dict["Dimer"], float) else "k",
s=50,
marker="s", # default_style_settings["marker"],
label="Dimer",
- cmap=colormap,
- norm=norm,
+ cmap=colormap if isinstance(disp_dict["Dimer"], float) else None,
+ norm=norm if isinstance(disp_dict["Dimer"], float) else None,
alpha=1,
)
@@ -1360,54 +1315,34 @@ def plot_colorbar(
]: # more than just Rattled
if imported_indices: # Exclude datapoints from other charge states
non_imported_sorted_indices = [
- i
- for i in range(len(sorted_distortions))
- if i not in imported_indices.values()
+ i for i in range(len(sorted_distortions)) if i not in imported_indices.values()
]
else:
non_imported_sorted_indices = range(len(sorted_distortions))
# Plot non-imported distortions
- im = ax.scatter( # Points for each distortion
- [
- sorted_distortions[i]
- for i in non_imported_sorted_indices
- if isinstance(sorted_disp[i], float)
- ],
- [
- sorted_energies[i]
- for i in non_imported_sorted_indices
- if isinstance(sorted_disp[i], float)
- ],
- c=[
- sorted_disp[i]
- for i in non_imported_sorted_indices
- if isinstance(sorted_disp[i], float)
- ],
- ls="-",
- s=50,
- marker="o",
- cmap=colormap,
- norm=norm,
- alpha=1,
- )
- ax.scatter( # plot any datapoints where disp could not be determined as black
- [
- sorted_distortions[i]
- for i in non_imported_sorted_indices
- if not isinstance(sorted_disp[i], float)
- ],
- [
- sorted_energies[i]
- for i in non_imported_sorted_indices
- if not isinstance(sorted_disp[i], float)
- ],
- c="k",
- ls="-",
- s=50,
- marker="o",
- alpha=1,
- )
+ non_imported_distortion_indices_with_disp = [
+ i for i in non_imported_sorted_indices if isinstance(sorted_disp[i], float)
+ ]
+ non_imported_distortion_indices_without_disp = [
+ i for i in non_imported_sorted_indices if not isinstance(sorted_disp[i], float)
+ ]
+ for indices_list, color_map in [
+ (non_imported_distortion_indices_without_disp, False),
+ (non_imported_distortion_indices_with_disp, True),
+ ]:
+ if indices_list:
+ path_col = ax.scatter( # plot any datapoints with undetermined disp as black
+ [sorted_distortions[i] for i in indices_list],
+ [sorted_energies[i] for i in indices_list],
+ c=[sorted_disp[i] for i in indices_list] if color_map else "k",
+ ls="-",
+ s=50,
+ marker="o",
+ cmap=colormap if color_map else None,
+ norm=norm if color_map else None,
+ alpha=1,
+ )
if len(non_imported_sorted_indices) > 1: # more than one point
# Plot line connecting points
(line,) = ax.plot(
@@ -1425,18 +1360,16 @@ def plot_colorbar(
other_charges = len(
[ # number of other charge states whose distortions have been imported
list(energies_dict["distortions"].keys())[i].split("_")[-1]
- for i in imported_indices.keys()
+ for i in imported_indices
]
)
for i, j in zip(imported_indices.keys(), range(other_charges)):
- other_charge_state = int(
- list(energies_dict["distortions"].keys())[i].split("_")[-1]
- )
+ other_charge_state = int(list(energies_dict["distortions"].keys())[i].split("_")[-1])
sorted_i = imported_indices[i] # index for the sorted dicts
- ax.scatter(
+ ax.scatter( # plot any datapoints where disp could not be determined as black
np.array(keys)[i],
sorted_energies[sorted_i],
- c=sorted_disp[sorted_i],
+ c=(sorted_disp[sorted_i] if isinstance(sorted_disp[sorted_i], float) else "k"),
edgecolors="k",
ls="-",
s=50,
@@ -1446,8 +1379,8 @@ def plot_colorbar(
j
], # different markers for different charge states
zorder=10, # make sure it's on top of the other points
- cmap=colormap,
- norm=norm,
+ cmap=(colormap if isinstance(sorted_disp[sorted_i], float) else None),
+ norm=norm if isinstance(sorted_disp[sorted_i], float) else None,
alpha=1,
label=f"From {'+' if other_charge_state > 0 else ''}{other_charge_state} "
f"charge state",
@@ -1470,17 +1403,15 @@ def plot_colorbar(
# distortion_range is sorted_distortions range, including 0 if above/below this range
distortion_range = (
- min(sorted_distortions + (0,)),
- max(sorted_distortions + (0,)),
+ min((*sorted_distortions, 0)),
+ max((*sorted_distortions, 0)),
)
# set xlim to distortion_range + 5% (matplotlib default padding), if distortion_range is
# not zero (only rattled and unperturbed)
if distortion_range[1] - distortion_range[0] > 0:
ax.set_xlim(
- distortion_range[0]
- - 0.05 * (distortion_range[1] - distortion_range[0]),
- distortion_range[1]
- + 0.05 * (distortion_range[1] - distortion_range[0]),
+ distortion_range[0] - 0.05 * (distortion_range[1] - distortion_range[0]),
+ distortion_range[1] + 0.05 * (distortion_range[1] - distortion_range[0]),
)
# Formatting of tick labels.
@@ -1488,36 +1419,39 @@ def plot_colorbar(
# 2 if deltaE > 0.1 eV, otherwise 3.
ax = _format_ticks(
ax=ax,
- energies_list=list(energies_dict["distortions"].values())
- + [
+ energies_list=[
+ *list(energies_dict["distortions"].values()),
energies_dict["Unperturbed"],
],
)
- # reformat 'line' legend handle to include 'im' datapoint handle
+ # reformat 'line' legend handle to include 'path_col' datapoint handle
handles, labels = ax.get_legend_handles_labels()
# get handle and label that corresponds to line, if line present:
if line:
- line_handle, line_label = [
- (handle, label)
- for handle, label in zip(handles, labels)
- if label == legend_label
- ][0]
+ line_handle, line_label = next(
+ (handle, label) for handle, label in zip(handles, labels) if label == legend_label
+ )
# remove line handle and label from handles and labels
handles = [handle for handle in handles if handle != line_handle]
labels = [label for label in labels if label != line_label]
# add line handle and label to handles and labels, with datapoint handle
- handles = [(im, line_handle)] + handles
- labels = [line_label] + labels
+ handles = [(path_col, line_handle), *handles]
+ labels = [line_label, *labels]
- plt.legend(
- handles, labels, scatteryoffsets=[0.5], frameon=True, framealpha=0.3
- ).set_zorder(
+ plt.legend(handles, labels, scatteryoffsets=[0.5], frameon=True, framealpha=0.3).set_zorder(
100
) # make sure it's on top of the other points
_ = _format_colorbar(
- fig=fig, ax=ax, im=im, metric=metric, vmin=vmin, vmax=vmax, vmedium=vmedium
+ fig=fig,
+ ax=ax,
+ metric=metric,
+ vmin=vmin,
+ vmax=vmax,
+ vmedium=vmedium,
+ norm=norm,
+ cmap=colormap,
) # Colorbar formatting
# Save plot?
@@ -1640,8 +1574,7 @@ def plot_datasets(
elif len(colors) < len(datasets):
if verbose:
warnings.warn(
- f"Insufficient colors provided for {len(datasets)} datasets. "
- "Using default colors."
+ f"Insufficient colors provided for {len(datasets)} datasets. Using default colors."
)
colors = _get_line_colors(number_of_colors=len(datasets))
# Title and labels of axis
@@ -1664,9 +1597,7 @@ def plot_datasets(
)
# Plot data points for each dataset
- unperturbed_energies = (
- {}
- ) # energies for unperturbed structure obtained with different methods
+ unperturbed_energies = {} # energies for unperturbed structure obtained with different methods
# all energies relative to the unperturbed energy of first dataset
for dataset_number, dataset in enumerate(datasets):
@@ -1699,9 +1630,7 @@ def plot_datasets(
if optional_style_settings: # if set by user
if isinstance(optional_style_settings, list):
try:
- default_style_settings[key] = optional_style_settings[
- dataset_number
- ]
+ default_style_settings[key] = optional_style_settings[dataset_number]
except IndexError:
default_style_settings[key] = optional_style_settings[
0
@@ -1715,11 +1644,9 @@ def plot_datasets(
keys,
sorted_distortions,
sorted_energies,
- ) = _format_datapoints_from_other_chargestates(
- energies_dict=dataset, disp_dict=None
- )
+ ) = _format_datapoints_from_other_chargestates(energies_dict=dataset, disp_dict=None)
with plt.style.context(f"{MODULE_DIR}/shakenbreak.mplstyle"):
- if "Rattled" in dataset["distortions"].keys():
+ if "Rattled" in dataset["distortions"]:
ax.scatter( # Scatter plot for Rattled (1 datapoint)
0.0,
dataset["distortions"]["Rattled"],
@@ -1730,7 +1657,7 @@ def plot_datasets(
)
# Plot Dimer
- if "Dimer" in dataset["distortions"].keys():
+ if "Dimer" in dataset["distortions"]:
ax.scatter( # Scatter plot for Rattled (1 datapoint)
0.0,
dataset["distortions"]["Dimer"],
@@ -1741,15 +1668,11 @@ def plot_datasets(
)
if len(sorted_distortions) > 0 and [
- key
- for key in dataset["distortions"]
- if (key != "Rattled" and key != "Dimer")
+ key for key in dataset["distortions"] if (key != "Rattled" and key != "Dimer")
]: # more than just Rattled
if imported_indices: # Exclude datapoints from other charge states
non_imported_sorted_indices = [
- i
- for i in range(len(sorted_distortions))
- if i not in imported_indices.values()
+ i for i in range(len(sorted_distortions)) if i not in imported_indices.values()
]
else:
non_imported_sorted_indices = range(len(sorted_distortions))
@@ -1770,14 +1693,11 @@ def plot_datasets(
if imported_indices:
other_charges = len(
[
- list(dataset["distortions"].keys())[i].split("_")[-1]
- for i in imported_indices
+ list(dataset["distortions"].keys())[i].split("_")[-1] for i in imported_indices
] # number of other charge states whose distortions have been imported
)
for i, j in zip(imported_indices, range(other_charges)):
- other_charge_state = int(
- list(dataset["distortions"].keys())[i].split("_")[-1]
- )
+ other_charge_state = int(list(dataset["distortions"].keys())[i].split("_")[-1])
ax.scatter( # distortions from other charge states
np.array(keys)[i],
list(dataset["distortions"].values())[i],
@@ -1796,15 +1716,11 @@ def plot_datasets(
f"charge state",
)
- datasets[0][
- "Unperturbed"
- ] = 0.0 # unperturbed energy of first dataset (our reference energy)
+ datasets[0]["Unperturbed"] = 0.0 # unperturbed energy of first dataset (our reference energy)
# Plot Unperturbed point for every dataset, relative to the unperturbed energy of first dataset
for key, value in unperturbed_energies.items():
- if (
- abs(value) > 0.1
- ): # Only plot if different energy from the reference Unperturbed
+ if abs(value) > 0.1: # Only plot if different energy from the reference Unperturbed
print(
f"Energies for unperturbed structures obtained with different methods "
f"({dataset_labels[key]}) differ by {value:.2f}. If testing different "
@@ -1832,7 +1748,15 @@ def plot_datasets(
)
# distortion_range is sorted_distortions range, including 0 if above/below this range
- distortion_range = (min(sorted_distortions + (0,)), max(sorted_distortions + (0,)))
+ distortion_range = (
+ min((*sorted_distortions, 0)),
+ max(
+ (
+ *sorted_distortions,
+ 0,
+ )
+ ),
+ )
# set xlim to distortion_range + 5% (matplotlib default padding)
if distortion_range[1] - distortion_range[0] > 0:
ax.set_xlim(
@@ -1845,27 +1769,21 @@ def plot_datasets(
# > 0.4 eV, 3 if E < 0.1 eV, 2 otherwise
ax = _format_ticks(
ax=ax,
- energies_list=list(datasets[0]["distortions"].values())
- + [
+ energies_list=[
+ *list(datasets[0]["distortions"].values()),
datasets[0]["Unperturbed"],
],
)
# If several datasets, check min & max energy are included
if len(datasets) > 1:
- min_energy = min(
- min(list(dataset["distortions"].values())) for dataset in datasets
- )
- max_energy = max(
- max(list(dataset["distortions"].values())) for dataset in datasets
- )
+ min_energy = min(min(list(dataset["distortions"].values())) for dataset in datasets)
+ max_energy = max(max(list(dataset["distortions"].values())) for dataset in datasets)
ax.set_ylim(
min_energy - 0.1 * (max_energy - min_energy),
max_energy + 0.1 * (max_energy - min_energy),
)
- ax.legend(frameon=True, framealpha=0.3).set_zorder(
- 100
- ) # show legend on top of all other datapoints
+ ax.legend(frameon=True, framealpha=0.3).set_zorder(100) # show legend on top of all other datapoints
if save_plot: # Save plot?
_save_plot(
diff --git a/tests/data/cp2k/vac_1_Cd_0/Bond_Distortion_30.0%/cp2k_input.inp b/tests/data/cp2k/vac_1_Cd_0/Bond_Distortion_30.0%/cp2k_input.inp
index 78c4776..113b42d 100644
--- a/tests/data/cp2k/vac_1_Cd_0/Bond_Distortion_30.0%/cp2k_input.inp
+++ b/tests/data/cp2k/vac_1_Cd_0/Bond_Distortion_30.0%/cp2k_input.inp
@@ -10,8 +10,8 @@
&DFT
BASIS_SET_FILE_NAME HFX_BASIS
POTENTIAL_FILE_NAME GTH_POTENTIALS
- SPIN_POLARIZED .TRUE.
CHARGE 0
+ SPIN_POLARIZED .TRUE.
&MGRID
CUTOFF [eV] 500 ! PW cutoff
&END MGRID
diff --git a/tests/data/cp2k/vac_1_Cd_0/Bond_Distortion_30.0%/cp2k_input_user_parameters.inp b/tests/data/cp2k/vac_1_Cd_0/Bond_Distortion_30.0%/cp2k_input_user_parameters.inp
index aab2955..7f9c879 100644
--- a/tests/data/cp2k/vac_1_Cd_0/Bond_Distortion_30.0%/cp2k_input_user_parameters.inp
+++ b/tests/data/cp2k/vac_1_Cd_0/Bond_Distortion_30.0%/cp2k_input_user_parameters.inp
@@ -10,8 +10,8 @@
&DFT
BASIS_SET_FILE_NAME HFX_BASIS
POTENTIAL_FILE_NAME GTH_POTENTIALS
- SPIN_POLARIZED .FALSE.
CHARGE 0
+ SPIN_POLARIZED .FALSE.
&MGRID
CUTOFF [eV] 800 ! PW cutoff
&END MGRID
diff --git a/tests/data/vasp/CdTe/vac_1_Cd_0/default_INCAR b/tests/data/vasp/CdTe/vac_1_Cd_0/default_INCAR
index 4bd862e..9032123 100644
--- a/tests/data/vasp/CdTe/vac_1_Cd_0/default_INCAR
+++ b/tests/data/vasp/CdTe/vac_1_Cd_0/default_INCAR
@@ -1,4 +1,4 @@
-# May want to change NCORE, KPAR, AEXX, ENCUT, IBRION, LREAL, NUPDOWN, ISPIN = Typical variable parameters
+# May want to change NCORE, KPAR, AEXX, ENCUT, IBRION, LREAL, NUPDOWN, ISPIN, MAGMOM = Typical variable parameters
# ShakeNBreak INCAR with coarse settings to maximise speed with sufficient accuracy for qualitative structure searching =
AEXX = 0.25
ALGO = Normal # change to all if zhegv, fexcp/f or zbrent errors encountered (done automatically by snb-run)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 4151b46..df7f78a 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -261,7 +261,9 @@ def test_snb_generate(self):
catch_exceptions=False,
)
print([str(warning.message) for warning in w]) # for debugging
- non_potcar_warnings = [warning for warning in w if "POTCAR" not in str(warning.message)]
+ non_potcar_warnings = [
+ warning for warning in w if "POTCAR" not in str(warning.message)
+ ]
assert not non_potcar_warnings # no warnings other than POTCAR warnings
self.assertEqual(result.exit_code, 0)
self.assertIn(
@@ -319,12 +321,11 @@ def test_snb_generate(self):
)
kpoints = Kpoints.from_file(f"{V_Cd_Bond_Distortion_folder}/KPOINTS")
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
- assert filecmp.cmp(
- f"{V_Cd_Bond_Distortion_folder}/INCAR", self.V_Cd_INCAR_file
- )
+ generated_incar = Incar.from_file(f"{V_Cd_Bond_Distortion_folder}/INCAR")
+ assert generated_incar == self.V_Cd_INCAR
# check if POTCARs have been written:
potcar = Potcar.from_file(f"{V_Cd_Bond_Distortion_folder}/POTCAR")
@@ -440,9 +441,13 @@ def test_snb_generate(self):
catch_exceptions=False,
)
print([str(warning.message) for warning in w]) # for debugging
- non_potcar_warnings = [warning for warning in w if "POTCAR" not in str(warning.message)]
+ non_potcar_warnings = [
+ warning for warning in w if "POTCAR" not in str(warning.message)
+ ]
assert len(non_potcar_warnings) == 1 # only overwriting structures warning
- assert "has the same Unperturbed defect structure" in str(non_potcar_warnings[0].message)
+ assert "has the same Unperturbed defect structure" in str(
+ non_potcar_warnings[0].message
+ )
self.assertEqual(result.exit_code, 0)
self.assertIn(f"Defect: {defect_name}", result.output)
self.assertIn("Number of missing electrons in neutral state: 2", result.output)
@@ -491,9 +496,13 @@ def test_snb_generate(self):
catch_exceptions=False,
)
print([str(warning.message) for warning in w]) # for debugging
- non_potcar_warnings = [warning for warning in w if "POTCAR" not in str(warning.message)]
+ non_potcar_warnings = [
+ warning for warning in w if "POTCAR" not in str(warning.message)
+ ]
assert len(non_potcar_warnings) == 1 # only overwriting structures warning
- assert "has the same Unperturbed defect structure" in str(non_potcar_warnings[0].message)
+ assert "has the same Unperturbed defect structure" in str(
+ non_potcar_warnings[0].message
+ )
self.assertEqual(result.exit_code, 0)
self.assertIn(f"Defect: {defect_name}", result.output)
self.assertIn("Number of missing electrons in neutral state: 2", result.output)
@@ -551,9 +560,13 @@ def test_snb_generate(self):
catch_exceptions=False,
)
print([str(warning.message) for warning in w]) # for debugging
- non_potcar_warnings = [warning for warning in w if "POTCAR" not in str(warning.message)]
+ non_potcar_warnings = [
+ warning for warning in w if "POTCAR" not in str(warning.message)
+ ]
assert len(non_potcar_warnings) == 1 # only overwriting structures warning
- assert "has the same Unperturbed defect structure" in str(non_potcar_warnings[0].message)
+ assert "has the same Unperturbed defect structure" in str(
+ non_potcar_warnings[0].message
+ )
self.assertEqual(result.exit_code, 0)
self.assertIn(f"Defect: {defect_name}", result.output)
self.assertIn("Number of missing electrons in neutral state: 2", result.output)
@@ -607,9 +620,13 @@ def test_snb_generate(self):
catch_exceptions=False,
)
print([str(warning.message) for warning in w]) # for debugging
- non_potcar_warnings = [warning for warning in w if "POTCAR" not in str(warning.message)]
+ non_potcar_warnings = [
+ warning for warning in w if "POTCAR" not in str(warning.message)
+ ]
assert len(non_potcar_warnings) == 1 # only overwriting structures warning
- assert "has the same Unperturbed defect structure" in str(non_potcar_warnings[0].message)
+ assert "has the same Unperturbed defect structure" in str(
+ non_potcar_warnings[0].message
+ )
self.assertEqual(result.exit_code, 0)
self.assertIn(f"Defect: {defect_name}", result.output)
self.assertIn("Number of missing electrons in neutral state: 2", result.output)
@@ -630,14 +647,18 @@ def test_snb_generate(self):
)
self.assertEqual(
len(
- reloaded_distortion_metadata["defects"]["v_Cd_Td_Te2.83"]["charges"]["0"][
- "distortion_parameters"
- ]["bond_distortions"]
+ reloaded_distortion_metadata["defects"]["v_Cd_Td_Te2.83"]["charges"][
+ "0"
+ ]["distortion_parameters"]["bond_distortions"]
),
25,
) # no duplication of bond distortions
- defect_folder_distortion_metadata = loadfn("v_Cd_Td_Te2.83_0/distortion_metadata.json")
- self.assertNotEqual(reloaded_distortion_metadata, defect_folder_distortion_metadata)
+ defect_folder_distortion_metadata = loadfn(
+ "v_Cd_Td_Te2.83_0/distortion_metadata.json"
+ )
+ self.assertNotEqual(
+ reloaded_distortion_metadata, defect_folder_distortion_metadata
+ )
# test defect_index option:
self.tearDown()
@@ -1070,7 +1091,7 @@ def test_snb_generate(self):
)
# check print info message:
# self.assertIn(
- # "Defect charge states will be set to the range: 0 – {Defect oxidation "
+ # "Defect charge states will be set to the range: 0 - {Defect oxidation "
# "state}, with a `padding = 1` on either side of this range.",
# result.output,
# )
@@ -1109,7 +1130,7 @@ def test_snb_generate(self):
)
# check print info message:
self.assertIn(
- "Defect charge states will be set to the range: 0 – {Defect oxidation "
+ "Defect charge states will be set to the range: 0 - {Defect oxidation "
"state}, with a `padding = 4` on either side of this range.",
result.output,
)
@@ -1169,7 +1190,7 @@ def test_snb_generate_config(self):
V_Cd_kwarged_POSCAR.structure, self.V_Cd_minus0pt5_struc_kwarged
)
kpoints = Kpoints.from_file(f"{defect_name}_0/Bond_Distortion_-50.0%/KPOINTS")
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
assert filecmp.cmp(
@@ -1499,7 +1520,7 @@ def test_snb_generate_config(self):
self.assertEqual(w[0].category, UserWarning)
self.assertEqual(
"Defect charges were specified using the CLI option, but `charges` "
- "was also specified in the `--config` file – this will be ignored!",
+ "was also specified in the `--config` file -- this will be ignored!",
str(w[0].message),
)
self.tearDown()
@@ -1613,15 +1634,13 @@ def test_snb_generate_all(self):
self.V_Cd_0pt3_local_rattled,
)
kpoints = Kpoints.from_file(f"{defect_name}_0/Bond_Distortion_30.0%/KPOINTS")
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
- assert not filecmp.cmp(
- f"{defect_name}_0/Bond_Distortion_30.0%/INCAR", self.V_Cd_INCAR_file
- )
+ v_Cd_INCAR = Incar.from_file(f"{defect_name}_0/Bond_Distortion_30.0%/INCAR")
+ assert v_Cd_INCAR != self.V_Cd_INCAR
# NELECT has changed due to POTCARs
- v_Cd_INCAR = Incar.from_file(f"{defect_name}_0/Bond_Distortion_30.0%/INCAR")
v_Cd_INCAR.pop("NELECT")
test_INCAR = self.V_Cd_INCAR.copy()
test_INCAR.pop("NELECT")
@@ -1930,7 +1949,7 @@ def test_snb_generate_all(self):
)
# check print info message:
self.assertIn(
- "Defect charge states will be set to the range: 0 – {Defect oxidation "
+ "Defect charge states will be set to the range: 0 - {Defect oxidation "
"state}, with a `padding = 4` on either side of this range.",
result.output,
)
@@ -2798,39 +2817,52 @@ def _test_OUTCAR_error(error_string):
) # new files as calc being rerun, but without changing ISPIN
def test_parse(self):
- """Test parse() function.
- Implicitly, this also tests the io.parse_energies() function"""
+ """
+ Test parse() function.
+ Implicitly, this also tests the io.parse_energies() function.
+ """
# Specifying defect to parse
# All OUTCAR's present in distortion directories
# Energies file already present
defect = "v_Ti_0"
- with open(f"{self.EXAMPLE_RESULTS}/{defect}/{defect}.yaml", "w") as f:
- f.write("")
runner = CliRunner()
- result = runner.invoke(
- snb,
- [
- "parse",
- "-d",
- defect,
- "-p",
- self.EXAMPLE_RESULTS,
- ],
- catch_exceptions=False,
- )
- self.assertIn(
- f"Moving old {self.EXAMPLE_RESULTS}/{defect}/{defect}.yaml to ",
- result.output,
- )
- energies = loadfn(f"{self.EXAMPLE_RESULTS}/{defect}/{defect}.yaml")
- test_energies = {
- "distortions": {
- -0.4: -1176.28458753,
- },
- "Unperturbed": -1173.02056574,
- } # Using dictionary here (rather than file/string), because parsing order
- # is difference on github actions
- self.assertEqual(test_energies, energies)
+
+ def _parse_v_Ti_and_check_output(verbose=False):
+ with open(f"{self.EXAMPLE_RESULTS}/{defect}/{defect}.yaml", "w") as f:
+ f.write(
+ ""
+ ) # write empty energies file, otherwise no verbose print because detects
+ # that it already exists with the same energies
+ args = ["parse", "-d", defect, "-p", self.EXAMPLE_RESULTS]
+ if verbose:
+ args.append("-v")
+ result = runner.invoke(snb, args, catch_exceptions=True)
+ print(f"Output: {result.output}")
+ if verbose:
+ self.assertIn(
+ f"Moving old {self.EXAMPLE_RESULTS}/{defect}/{defect}.yaml to ",
+ result.output,
+ )
+ else:
+ self.assertNotIn(
+ f"Moving old {self.EXAMPLE_RESULTS}/{defect}/{defect}.yaml to ",
+ result.output,
+ )
+ energies = loadfn(f"{self.EXAMPLE_RESULTS}/{defect}/{defect}.yaml")
+ test_energies = {
+ "distortions": {
+ -0.4: -1176.28458753,
+ },
+ "Unperturbed": -1173.02056574,
+ } # Using dictionary here (rather than file/string), because parsing order
+ # is difference on GitHub actions
+ self.assertEqual(test_energies, energies)
+
+ print("Testing snb-parse non verbose")
+ _parse_v_Ti_and_check_output()
+ print("Testing snb-parse, -v")
+ _parse_v_Ti_and_check_output(verbose=True)
+
[
os.remove(f"{self.EXAMPLE_RESULTS}/{defect}/{file}")
for file in os.listdir(f"{self.EXAMPLE_RESULTS}/{defect}")
@@ -3205,8 +3237,8 @@ def test_parse(self):
f"unreasonable charge state). If both checks pass, you likely need to adjust "
f"the `stdev` rattling parameter (can occur for hard/ionic/magnetic "
f"materials); see "
- f"https://shakenbreak.readthedocs.io/en/latest/Tips.html#hard-ionic-materials. "
- f"– This often indicates a complex PES with multiple minima, "
+ f"https://shakenbreak.readthedocs.io/en/latest/Tips.html#hard-ionic-materials\n"
+ f"This often indicates a complex PES with multiple minima, "
f"thus energy-lowering distortions particularly likely, so important to "
f"test with reduced `stdev`!" == str(i.message)
for i in w
@@ -3941,8 +3973,10 @@ def test_plot(self):
# test distortion_metadata parsed fine when in defect folder (not above):
with open(
- f"{self.EXAMPLE_RESULTS}/{defect_name}_defect_folder/{defect_name}/"
- f"distortion_metadata.json", "w") as f:
+ f"{self.EXAMPLE_RESULTS}/{defect_name}_defect_folder/{defect_name}/"
+ f"distortion_metadata.json",
+ "w",
+ ) as f:
f.write(json.dumps(fake_distortion_metadata, indent=4))
with warnings.catch_warnings(record=True) as w:
result = runner.invoke(
diff --git a/tests/test_input.py b/tests/test_input.py
index 391a80c..2616d72 100644
--- a/tests/test_input.py
+++ b/tests/test_input.py
@@ -1,7 +1,6 @@
import contextlib
import copy
import datetime
-import filecmp
import locale
import os
import shutil
@@ -13,13 +12,12 @@
from ase.build import bulk, make_supercell
from ase.calculators.aims import Aims
from ase.io import read
-from doped import _ignore_pmg_warnings
from doped.generation import get_defect_name_from_entry
from doped.vasp import _test_potcar_functional_choice, DefectRelaxSet
from monty.serialization import dumpfn, loadfn
from pymatgen.analysis.defects.generators import VacancyGenerator
from pymatgen.analysis.defects.thermo import DefectEntry
-from pymatgen.core.periodic_table import DummySpecies, Species
+from pymatgen.core.periodic_table import DummySpecies
from pymatgen.core.structure import Composition, PeriodicSite, Structure
from pymatgen.entries.computed_entries import ComputedStructureEntry
from pymatgen.io.ase import AseAtomsAdaptor
@@ -105,55 +103,58 @@ class InputTestCase(unittest.TestCase):
def setUp(self):
warnings.filterwarnings("ignore", category=UnknownPotcarWarning)
- self.DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
- self.VASP_DIR = os.path.join(self.DATA_DIR, "vasp")
- self.VASP_CDTE_DATA_DIR = os.path.join(self.DATA_DIR, "vasp/CdTe")
- self.CASTEP_DATA_DIR = os.path.join(self.DATA_DIR, "castep")
- self.CP2K_DATA_DIR = os.path.join(self.DATA_DIR, "cp2k")
- self.FHI_AIMS_DATA_DIR = os.path.join(self.DATA_DIR, "fhi_aims")
- self.ESPRESSO_DATA_DIR = os.path.join(self.DATA_DIR, "quantum_espresso")
- self.CdTe_bulk_struc = Structure.from_file(
- os.path.join(self.VASP_CDTE_DATA_DIR, "CdTe_Bulk_Supercell_POSCAR")
+
+ @classmethod
+ def setUpClass(cls):
+ cls.DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
+ cls.VASP_DIR = os.path.join(cls.DATA_DIR, "vasp")
+ cls.VASP_CDTE_DATA_DIR = os.path.join(cls.DATA_DIR, "vasp/CdTe")
+ cls.CASTEP_DATA_DIR = os.path.join(cls.DATA_DIR, "castep")
+ cls.CP2K_DATA_DIR = os.path.join(cls.DATA_DIR, "cp2k")
+ cls.FHI_AIMS_DATA_DIR = os.path.join(cls.DATA_DIR, "fhi_aims")
+ cls.ESPRESSO_DATA_DIR = os.path.join(cls.DATA_DIR, "quantum_espresso")
+ cls.CdTe_bulk_struc = Structure.from_file(
+ os.path.join(cls.VASP_CDTE_DATA_DIR, "CdTe_Bulk_Supercell_POSCAR")
)
- self.cdte_doped_defect_dict = loadfn( # old doped defects dict
- os.path.join(self.VASP_CDTE_DATA_DIR, "CdTe_defects_dict.json")
+ cls.cdte_doped_defect_dict = loadfn( # old doped defects dict
+ os.path.join(cls.VASP_CDTE_DATA_DIR, "CdTe_defects_dict.json")
)
- self.cdte_doped_reduced_defect_gen = loadfn(
- os.path.join(self.VASP_CDTE_DATA_DIR, "reduced_CdTe_defect_gen.json")
+ cls.cdte_doped_reduced_defect_gen = loadfn(
+ os.path.join(cls.VASP_CDTE_DATA_DIR, "reduced_CdTe_defect_gen.json")
)
- self.cdte_defects = {}
+ cls.cdte_defects = {}
# Refactor to dict of DefectEntrys objects, with doped/PyCDT names
- for defects_type, defect_dict_list in self.cdte_doped_defect_dict.items():
+ for defects_type, defect_dict_list in cls.cdte_doped_defect_dict.items():
if "bulk" not in defects_type:
for defect_dict in defect_dict_list:
- self.cdte_defects[defect_dict["name"]] = [
+ cls.cdte_defects[defect_dict["name"]] = [
input._get_defect_entry_from_defect(
defect=input.generate_defect_object(
single_defect_dict=defect_dict,
- bulk_dict=self.cdte_doped_defect_dict["bulk"],
+ bulk_dict=cls.cdte_doped_defect_dict["bulk"],
),
charge_state=charge,
)
for charge in defect_dict["charges"]
]
- self.cdte_doped_extrinsic_defects_dict = loadfn(
- os.path.join(self.VASP_CDTE_DATA_DIR, "CdTe_extrinsic_defects_dict.json")
+ cls.cdte_doped_extrinsic_defects_dict = loadfn(
+ os.path.join(cls.VASP_CDTE_DATA_DIR, "CdTe_extrinsic_defects_dict.json")
)
# Refactor to dict of DefectEntrys objects, with doped/PyCDT names
- self.cdte_extrinsic_defects = {}
+ cls.cdte_extrinsic_defects = {}
for (
defects_type,
defect_dict_list,
- ) in self.cdte_doped_extrinsic_defects_dict.items():
+ ) in cls.cdte_doped_extrinsic_defects_dict.items():
if "bulk" not in defects_type:
for defect_dict in defect_dict_list:
- self.cdte_extrinsic_defects[defect_dict["name"]] = [
+ cls.cdte_extrinsic_defects[defect_dict["name"]] = [
input._get_defect_entry_from_defect(
defect=input.generate_defect_object(
single_defect_dict=defect_dict,
- bulk_dict=self.cdte_doped_defect_dict["bulk"],
+ bulk_dict=cls.cdte_doped_defect_dict["bulk"],
),
charge_state=charge,
)
@@ -162,97 +163,97 @@ def setUp(self):
# Refactor doped defect dict to list of list of DefectEntrys() objects
# (there's a DefectEntry for each charge state)
- self.cdte_defect_list = sum(list(self.cdte_defects.values()), [])
- self.CdTe_extrinsic_defect_list = sum(
- list(self.cdte_extrinsic_defects.values()), []
+ cls.cdte_defect_list = sum(list(cls.cdte_defects.values()), [])
+ cls.CdTe_extrinsic_defect_list = sum(
+ list(cls.cdte_extrinsic_defects.values()), []
)
- self.V_Cd_dict = self.cdte_doped_defect_dict["vacancies"][0]
- self.Int_Cd_2_dict = self.cdte_doped_defect_dict["interstitials"][1]
+ cls.V_Cd_dict = cls.cdte_doped_defect_dict["vacancies"][0]
+ cls.Int_Cd_2_dict = cls.cdte_doped_defect_dict["interstitials"][1]
# Refactor to Defect() objects
- self.V_Cd = input.generate_defect_object(
- self.V_Cd_dict, self.cdte_doped_defect_dict["bulk"]
+ cls.V_Cd = input.generate_defect_object(
+ cls.V_Cd_dict, cls.cdte_doped_defect_dict["bulk"]
)
- self.V_Cd.user_charges = self.V_Cd_dict["charges"]
- self.V_Cd_entry = input._get_defect_entry_from_defect(
- self.V_Cd, self.V_Cd.user_charges[0]
+ cls.V_Cd.user_charges = cls.V_Cd_dict["charges"]
+ cls.V_Cd_entry = input._get_defect_entry_from_defect(
+ cls.V_Cd, cls.V_Cd.user_charges[0]
)
- self.V_Cd_entry_neutral = input._get_defect_entry_from_defect(
- self.V_Cd, 0
+ cls.V_Cd_entry_neutral = input._get_defect_entry_from_defect(
+ cls.V_Cd, 0
)
- self.V_Cd_entries = [
- input._get_defect_entry_from_defect(self.V_Cd, c)
- for c in self.V_Cd.user_charges
+ cls.V_Cd_entries = [
+ input._get_defect_entry_from_defect(cls.V_Cd, c)
+ for c in cls.V_Cd.user_charges
]
- self.Int_Cd_2 = input.generate_defect_object(
- self.Int_Cd_2_dict, self.cdte_doped_defect_dict["bulk"]
+ cls.Int_Cd_2 = input.generate_defect_object(
+ cls.Int_Cd_2_dict, cls.cdte_doped_defect_dict["bulk"]
)
- self.Int_Cd_2.user_charges = self.Int_Cd_2.user_charges
- self.Int_Cd_2_entry = input._get_defect_entry_from_defect(
- self.Int_Cd_2, self.Int_Cd_2.user_charges[0]
+ cls.Int_Cd_2.user_charges = cls.Int_Cd_2.user_charges
+ cls.Int_Cd_2_entry = input._get_defect_entry_from_defect(
+ cls.Int_Cd_2, cls.Int_Cd_2.user_charges[0]
)
- self.Int_Cd_2_entries = [
- input._get_defect_entry_from_defect(self.Int_Cd_2, c)
- for c in self.Int_Cd_2.user_charges
+ cls.Int_Cd_2_entries = [
+ input._get_defect_entry_from_defect(cls.Int_Cd_2, c)
+ for c in cls.Int_Cd_2.user_charges
]
# Setup structures and add oxidation states (as pymatgen-analysis-defects does it)
- self.V_Cd_struc = Structure.from_file(
- os.path.join(self.VASP_CDTE_DATA_DIR, "CdTe_V_Cd_POSCAR")
+ cls.V_Cd_struc = Structure.from_file(
+ os.path.join(cls.VASP_CDTE_DATA_DIR, "CdTe_V_Cd_POSCAR")
)
- self.V_Cd_minus0pt5_struc_rattled = Structure.from_file(
+ cls.V_Cd_minus0pt5_struc_rattled = Structure.from_file(
os.path.join(
- self.VASP_CDTE_DATA_DIR, "CdTe_V_Cd_-50%_Distortion_Rattled_POSCAR"
+ cls.VASP_CDTE_DATA_DIR, "CdTe_V_Cd_-50%_Distortion_Rattled_POSCAR"
)
)
- self.V_Cd_dimer_struc_0pt25_rattled = Structure.from_file(
+ cls.V_Cd_dimer_struc_0pt25_rattled = Structure.from_file(
os.path.join(
- self.VASP_CDTE_DATA_DIR, "CdTe_V_Cd_Dimer_Rattled_0pt25_POSCAR"
+ cls.VASP_CDTE_DATA_DIR, "CdTe_V_Cd_Dimer_Rattled_0pt25_POSCAR"
)
)
- self.V_Cd_minus0pt5_struc_0pt1_rattled = Structure.from_file(
+ cls.V_Cd_minus0pt5_struc_0pt1_rattled = Structure.from_file(
os.path.join(
- self.VASP_CDTE_DATA_DIR,
+ cls.VASP_CDTE_DATA_DIR,
"CdTe_V_Cd_-50%_Distortion_stdev0pt1_Rattled_POSCAR",
)
)
- self.V_Cd_minus0pt5_struc_kwarged = Structure.from_file(
- os.path.join(self.VASP_CDTE_DATA_DIR, "CdTe_V_Cd_-50%_Kwarged_POSCAR")
+ cls.V_Cd_minus0pt5_struc_kwarged = Structure.from_file(
+ os.path.join(cls.VASP_CDTE_DATA_DIR, "CdTe_V_Cd_-50%_Kwarged_POSCAR")
)
- self.Int_Cd_2_struc = Structure.from_file(
- os.path.join(self.VASP_CDTE_DATA_DIR, "CdTe_Int_Cd_2_POSCAR")
+ cls.Int_Cd_2_struc = Structure.from_file(
+ os.path.join(cls.VASP_CDTE_DATA_DIR, "CdTe_Int_Cd_2_POSCAR")
)
- self.Int_Cd_2_minus0pt6_struc_rattled = Structure.from_file(
+ cls.Int_Cd_2_minus0pt6_struc_rattled = Structure.from_file(
os.path.join(
- self.VASP_CDTE_DATA_DIR, "CdTe_Int_Cd_2_-60%_Distortion_Rattled_POSCAR"
+ cls.VASP_CDTE_DATA_DIR, "CdTe_Int_Cd_2_-60%_Distortion_Rattled_POSCAR"
)
)
- self.Int_Cd_2_minus0pt6_NN_10_struc_unrattled = Structure.from_file(
+ cls.Int_Cd_2_minus0pt6_NN_10_struc_unrattled = Structure.from_file(
os.path.join(
- self.VASP_CDTE_DATA_DIR, "CdTe_Int_Cd_2_-60%_Distortion_NN_10_POSCAR"
+ cls.VASP_CDTE_DATA_DIR, "CdTe_Int_Cd_2_-60%_Distortion_NN_10_POSCAR"
)
)
# get example INCAR:
- self.V_Cd_INCAR_file = os.path.join(
- self.VASP_CDTE_DATA_DIR, "vac_1_Cd_0/default_INCAR"
+ cls.V_Cd_INCAR_file = os.path.join(
+ cls.VASP_CDTE_DATA_DIR, "vac_1_Cd_0/default_INCAR"
)
- self.V_Cd_INCAR = Incar.from_file(self.V_Cd_INCAR_file)
+ cls.V_Cd_INCAR = Incar.from_file(cls.V_Cd_INCAR_file)
# Setup distortion parameters
- self.V_Cd_distortion_parameters = {
+ cls.V_Cd_distortion_parameters = {
"unique_site": np.array([0.0, 0.0, 0.0]),
"num_distorted_neighbours": 2,
"distorted_atoms": [(33, "Te"), (42, "Te")],
}
- self.Int_Cd_2_normal_distortion_parameters = {
- "unique_site": self.Int_Cd_2_dict["unique_site"].frac_coords,
+ cls.Int_Cd_2_normal_distortion_parameters = {
+ "unique_site": cls.Int_Cd_2_dict["unique_site"].frac_coords,
"num_distorted_neighbours": 2,
"distorted_atoms": [(10 + 1, "Cd"), (22 + 1, "Cd")], # +1 because
# interstitial is added at the beginning of the structure
"defect_site_index": 1,
}
- self.Int_Cd_2_NN_10_distortion_parameters = {
- "unique_site": self.Int_Cd_2_dict["unique_site"].frac_coords,
+ cls.Int_Cd_2_NN_10_distortion_parameters = {
+ "unique_site": cls.Int_Cd_2_dict["unique_site"].frac_coords,
"num_distorted_neighbours": 10,
"distorted_atoms": [
(10 + 1, "Cd"),
@@ -275,7 +276,7 @@ def setUp(self):
# also testing that the package correctly ignores these and uses the bulk bond length of
# 2.8333... for d_min in the structure rattling functions.
- self.cdte_defect_folders_old_names = [ # but with "+" for positive charges
+ cls.cdte_defect_folders_old_names = [ # but with "+" for positive charges
"as_1_Cd_on_Te_-1",
"as_1_Cd_on_Te_-2",
"as_1_Cd_on_Te_0",
@@ -337,7 +338,7 @@ def setUp(self):
"vac_2_Te_+1",
"vac_2_Te_+2",
]
- self.new_names_old_names_CdTe = {
+ cls.new_names_old_names_CdTe = {
"v_Cd": "vac_1_Cd",
"v_Te": "vac_2_Te",
"Cd_Te": "as_1_Cd_on_Te",
@@ -349,7 +350,7 @@ def setUp(self):
"Te_i_C3v": "Int_Te_2",
"Te_i_Td_Te2.83": "Int_Te_3",
}
- self.new_full_names_old_names_CdTe = {
+ cls.new_full_names_old_names_CdTe = {
"v_Cd_Td_Te2.83": "vac_1_Cd",
"v_Te_Td_Cd2.83": "vac_2_Te",
"Cd_Te_Td_Cd2.83": "as_1_Cd_on_Te",
@@ -361,7 +362,7 @@ def setUp(self):
"Te_i_C3v_Cd2.71": "Int_Te_2",
"Te_i_Td_Te2.83": "Int_Te_3",
}
- self.cdte_defect_folders = [ # different charge states!
+ cls.cdte_defect_folders = [ # different charge states!
"Cd_Te_-2",
"Cd_Te_-1",
"Cd_Te_0",
@@ -428,14 +429,14 @@ def setUp(self):
]
# Get the current locale setting
- self.original_locale = locale.getlocale(locale.LC_CTYPE) # should be UTF-8
+ cls.original_locale = locale.getlocale(locale.LC_CTYPE) # should be UTF-8
- self.Ag_Sb_AgSbTe2_m2_defect_entry = loadfn(
- f"{self.DATA_DIR}/Ag_Sb_Cs_Te2.90_-2.json"
+ cls.Ag_Sb_AgSbTe2_m2_defect_entry = loadfn(
+ f"{cls.DATA_DIR}/Ag_Sb_Cs_Te2.90_-2.json"
)
# Generate defect entry for V_Cd in CdSeTe
defect_structure = Structure.from_file(
- os.path.join(self.VASP_DIR, "CdSeTe_v_Cd.POSCAR")
+ os.path.join(cls.VASP_DIR, "CdSeTe_v_Cd.POSCAR")
)
coords = [0.986350003237154, 0.4992578370461876, 0.9995065238765345]
bulk = defect_structure.copy()
@@ -445,7 +446,7 @@ def setUp(self):
bulk_structure=bulk,
)
# Generate a defect entry for each charge state
- self.V_Cd_in_CdSeTe_entry = input._get_defect_entry_from_defect(
+ cls.V_Cd_in_CdSeTe_entry = input._get_defect_entry_from_defect(
defect=defect, charge_state=0
)
@@ -462,8 +463,8 @@ def tearDown(self) -> None:
):
if_present_rm(i)
for fname in os.listdir("./"):
- if fname.endswith("json"): # distortion_metadata and parsed_defects_dict
- os.remove(f"./{fname}")
+ if fname.endswith("json") or fname.endswith("png"):
+ os.remove(f"./{fname}") # distortion_metadata, parsed_defects_dict, left-over plots
if_present_rm("test_path") # remove test_path if present
regen_defect_folder_names = [
@@ -734,7 +735,6 @@ def test_apply_rattle_bond_distortions_kwargs(self, mock_print):
# test all possible rattling kwargs with V_Cd
rattling_atom_indices = np.arange(0, 31) # Only rattle Cd
- vac_coords = np.array([0, 0, 0]) # Cd vacancy fractional coordinates
V_Cd_kwarg_distorted_dict = input._apply_rattle_bond_distortions(
self.V_Cd_entry,
@@ -763,6 +763,7 @@ def test_apply_rattle_bond_distortions_kwargs(self, mock_print):
)
self.assertEqual(V_Cd_kwarg_distorted_dict["num_distorted_neighbours"], 2)
self.assertEqual(V_Cd_kwarg_distorted_dict.get("defect_site_index"), None)
+ vac_coords = np.array([0, 0, 0]) # Cd vacancy fractional coordinates
np.testing.assert_array_equal(
V_Cd_kwarg_distorted_dict.get("defect_frac_coords"), vac_coords
)
@@ -861,7 +862,7 @@ def test_apply_snb_distortions_V_Cd(self):
verbose=True,
seed=42, # old default
)
- V_Cd_3_neighbours_distortion_parameters = self.V_Cd_distortion_parameters.copy()
+ V_Cd_3_neighbours_distortion_parameters = copy.deepcopy(self.V_Cd_distortion_parameters)
V_Cd_3_neighbours_distortion_parameters["num_distorted_neighbours"] = 3
V_Cd_3_neighbours_distortion_parameters["distorted_atoms"] += [(52, "Te")]
np.testing.assert_equal(
@@ -962,7 +963,6 @@ def test_apply_snb_distortions_kwargs(self, mock_print):
# test all possible rattling kwargs with V_Cd
rattling_atom_indices = np.arange(0, 31) # Only rattle Cd
- vac_coords = np.array([0, 0, 0]) # Cd vacancy fractional coordinates
V_Cd_kwarg_distorted_dict = input.apply_snb_distortions(
self.V_Cd_entry,
@@ -1068,7 +1068,6 @@ def test_create_vasp_input(self):
}
self.assertFalse(os.path.exists("vac_1_Cd_0"))
with warnings.catch_warnings(record=True) as w:
- _ignore_pmg_warnings()
input._create_vasp_input(
"vac_1_Cd_0",
distorted_defect_dict=V_Cd_charged_defect_dict,
@@ -1077,12 +1076,11 @@ def test_create_vasp_input(self):
"vac_1_Cd_0/Bond_Distortion_-50.0%"
)
kpoints = Kpoints.from_file("vac_1_Cd_0/Bond_Distortion_-50.0%/KPOINTS")
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
- assert filecmp.cmp(
- "vac_1_Cd_0/Bond_Distortion_-50.0%/INCAR", self.V_Cd_INCAR_file
- )
+ generated_INCAR = Incar.from_file("vac_1_Cd_0/Bond_Distortion_-50.0%/INCAR")
+ assert generated_INCAR == self.V_Cd_INCAR
# check if POTCARs have been written:
potcar = Potcar.from_file("vac_1_Cd_0/Bond_Distortion_-50.0%/POTCAR")
@@ -1125,12 +1123,9 @@ def test_create_vasp_input(self):
V_Cd_kwarg_folder = "vac_1_Cdb_0/Bond_Distortion_-50.0%"
V_Cd_POSCAR = self._check_V_Cd_rattled_poscar(V_Cd_kwarg_folder)
kpoints = Kpoints.from_file(f"{V_Cd_kwarg_folder}/KPOINTS")
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
- assert not filecmp.cmp( # INCAR settings changed now
- f"{V_Cd_kwarg_folder}/INCAR", self.V_Cd_INCAR_file
- )
assert self.V_Cd_INCAR != Incar.from_file(f"{V_Cd_kwarg_folder}/INCAR")
kwarged_INCAR = self.V_Cd_INCAR.copy()
kwarged_INCAR.update(kwarg_incar_settings)
@@ -1161,13 +1156,9 @@ def test_create_vasp_input(self):
kpoints = Kpoints.from_file(
"test_path/vac_1_Cd_0/Bond_Distortion_-50.0%/KPOINTS"
)
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
- assert not filecmp.cmp( # INCAR settings changed now
- "test_path/vac_1_Cd_0/Bond_Distortion_-50.0%/INCAR",
- self.V_Cd_INCAR_file,
- )
assert self.V_Cd_INCAR != Incar.from_file(
"test_path/vac_1_Cd_0/Bond_Distortion_-50.0%/INCAR"
)
@@ -1649,12 +1640,9 @@ def test_write_vasp_files(self):
) # default
self.assertEqual(V_Cd_POSCAR.structure, self.V_Cd_minus0pt5_struc_rattled)
kpoints = Kpoints.from_file(f"{V_Cd_Bond_Distortion_folder}/KPOINTS")
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
- assert not filecmp.cmp( # INCAR settings changed now
- f"{V_Cd_Bond_Distortion_folder}/INCAR", self.V_Cd_INCAR_file
- )
assert self.V_Cd_INCAR != Incar.from_file(
f"{V_Cd_Bond_Distortion_folder}/INCAR"
)
@@ -1689,16 +1677,15 @@ def test_write_vasp_files(self):
Int_Cd_2_POSCAR.structure, self.Int_Cd_2_minus0pt6_struc_rattled
)
kpoints = Kpoints.from_file(f"{Int_Cd_2_Bond_Distortion_folder}/KPOINTS")
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
- assert not filecmp.cmp( # INCAR settings changed now
- f"{Int_Cd_2_Bond_Distortion_folder}/INCAR", self.V_Cd_INCAR_file
- )
kwarged_INCAR = self.V_Cd_INCAR.copy()
+ Int_Cd_2_INCAR = Incar.from_file(f"{Int_Cd_2_Bond_Distortion_folder}/INCAR")
+ assert kwarged_INCAR != Int_Cd_2_INCAR
+
kwarged_INCAR.update({"ENCUT": 212, "IBRION": 0, "EDIFF": 1e-4})
kwarged_INCAR.pop("NELECT") # different NELECT for Cd_i_+2
- Int_Cd_2_INCAR = Incar.from_file(f"{Int_Cd_2_Bond_Distortion_folder}/INCAR")
Int_Cd_2_INCAR.pop("NELECT")
assert kwarged_INCAR == Int_Cd_2_INCAR
@@ -2004,14 +1991,13 @@ def test_write_vasp_files(self):
len(_int_Cd_2_POSCAR.site_symbols), len(set(_int_Cd_2_POSCAR.site_symbols))
) # no duplicates
kpoints = Kpoints.from_file("Int_Cd_2_+1/Unperturbed/KPOINTS")
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
- assert not filecmp.cmp( # INCAR settings changed now
- "Int_Cd_2_+1/Unperturbed/INCAR", self.V_Cd_INCAR_file
- )
int_Cd_2_INCAR = Incar.from_file("Int_Cd_2_+1/Unperturbed/INCAR")
v_Cd_INCAR = self.V_Cd_INCAR.copy()
+ assert v_Cd_INCAR != int_Cd_2_INCAR
+
v_Cd_INCAR.pop("NELECT") # NELECT and NUPDOWN differs for the two defects
v_Cd_INCAR.pop("NUPDOWN")
int_Cd_2_INCAR.pop("NELECT")
@@ -2127,15 +2113,14 @@ def test_write_vasp_files_dimer_distortion(self):
d_min=d_min,
)
with patch("builtins.print") as mock_print:
- with warnings.catch_warnings(record=True) as w:
- dist.write_vasp_files()
- # check expected info printing:
- mock_print.assert_any_call(
- "Applying ShakeNBreak...",
- "Will apply the following bond distortions:",
- "['Dimer'].",
- "Then, will rattle with a std dev of 0.25 Å \n",
- )
+ dist.write_vasp_files()
+ # check expected info printing:
+ mock_print.assert_any_call(
+ "Applying ShakeNBreak...",
+ "Will apply the following bond distortions:",
+ "['Dimer'].",
+ "Then, will rattle with a std dev of 0.25 Å \n",
+ )
V_Cd_dimer_POSCAR = Structure.from_file(
"v_Cd_Td_Te2.83_0/Dimer/POSCAR"
)
@@ -2219,12 +2204,9 @@ def test_write_vasp_files_from_doped_defect_gen(self):
"-50.0% N(Distort)=2 ~[0.5,0.5,0.5]", # closest to middle
) # default
kpoints = Kpoints.from_file(f"{V_Cd_Bond_Distortion_folder}/KPOINTS")
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
- assert not filecmp.cmp( # INCAR settings changed now
- f"{V_Cd_Bond_Distortion_folder}/INCAR", self.V_Cd_INCAR_file
- )
assert self.V_Cd_INCAR != Incar.from_file(
f"{V_Cd_Bond_Distortion_folder}/INCAR"
)
@@ -2256,16 +2238,15 @@ def test_write_vasp_files_from_doped_defect_gen(self):
"-60.0% N(Distort)=2 ~[0.3,0.4,0.4]", # closest to middle
)
kpoints = Kpoints.from_file(f"{Int_Cd_2_Bond_Distortion_folder}/KPOINTS")
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
- assert not filecmp.cmp( # INCAR settings changed now
- f"{Int_Cd_2_Bond_Distortion_folder}/INCAR", self.V_Cd_INCAR_file
- )
kwarged_INCAR = self.V_Cd_INCAR.copy()
+ Int_Cd_2_INCAR = Incar.from_file(f"{Int_Cd_2_Bond_Distortion_folder}/INCAR")
+ assert kwarged_INCAR != Int_Cd_2_INCAR
+
kwarged_INCAR.update({"IVDW": 12})
kwarged_INCAR.pop("NELECT") # different NELECT for Cd_i_+2
- Int_Cd_2_INCAR = Incar.from_file(f"{Int_Cd_2_Bond_Distortion_folder}/INCAR")
Int_Cd_2_INCAR.pop("NELECT")
assert kwarged_INCAR == Int_Cd_2_INCAR
@@ -2331,7 +2312,7 @@ def _check_agsbte2_files(self, folder_name, mock_print, w, charge_state=-2):
) # no duplicates
kpoints = Kpoints.from_file(f"{folder_name}/{distortion_folder}/KPOINTS")
- self.assertEqual(kpoints.kpts, [[1, 1, 1]])
+ self.assertEqual(kpoints.kpts, [(1, 1, 1)])
if _potcars_available():
# check if POTCARs have been written:
@@ -3668,7 +3649,7 @@ def test_from_structures(self):
bulk=self.cdte_doped_defect_dict["bulk"]["supercell"]["structure"],
)
# mock_print.assert_any_call(
- # "Defect charge states will be set to the range: 0 – {Defect "
+ # "Defect charge states will be set to the range: 0 - {Defect "
# "oxidation state}, with a `padding = 1` on either side of this "
# "range."
# )
diff --git a/tests/test_plotting.py b/tests/test_plotting.py
index 10ff300..74812f3 100644
--- a/tests/test_plotting.py
+++ b/tests/test_plotting.py
@@ -1,6 +1,7 @@
import datetime
import os
import shutil
+from functools import wraps
import unittest
import warnings
from collections import OrderedDict
@@ -29,7 +30,28 @@ def if_present_rm(path):
_file_path = os.path.dirname(__file__)
_DATA_DIR = os.path.join(_file_path, "data")
+BASELINE_DIR = os.path.join(_DATA_DIR, "remote_baseline_plots")
+STYLE = os.path.join(_file_path, "../shakenbreak/shakenbreak.mplstyle")
+def custom_mpl_image_compare(filename):
+ """
+ Set our default settings for MPL image compare.
+ """
+
+ def decorator(func):
+ @wraps(func)
+ @pytest.mark.mpl_image_compare(
+ baseline_dir=BASELINE_DIR,
+ filename=filename,
+ style=STYLE,
+ savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
+ )
+ def wrapper(*args, **kwargs):
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
class PlottingDefectsTestCase(unittest.TestCase):
def setUp(self):
@@ -583,12 +605,7 @@ def _remove_current_and_saved_plots(self, defect_name, current_datetime, current
if_present_rm(f"./{defect_name}_{current_datetime}.png")
if_present_rm(f"./{defect_name}_{current_datetime_minus1min}.png")
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_max_dist.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_max_dist.png")
def test_plot_colorbar_max_distance(self):
"""Test plot_colorbar() function with metric=max_dist (default)"""
return plotting.plot_colorbar(
@@ -599,12 +616,7 @@ def test_plot_colorbar_max_distance(self):
neighbour_atom="Te",
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_unparsed_disp.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_unparsed_disp.png")
def test_plot_colorbar_unparsed_disp(self):
"""Test plot_colorbar() function with disp_dict values equal to None/'Not converged'"""
disp_dict = deepcopy(self.V_Cd_displacement_dict)
@@ -619,12 +631,7 @@ def test_plot_colorbar_unparsed_disp(self):
neighbour_atom="Te",
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_fake_defect_name.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_fake_defect_name.png")
def test_plot_colorbar_fake_defect_name(self):
"""Test plot_colorbar() function with wrong defect name"""
return plotting.plot_colorbar(
@@ -635,12 +642,7 @@ def test_plot_colorbar_fake_defect_name(self):
neighbour_atom="Te",
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_displacement.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_displacement.png")
def test_plot_colorbar_displacement(self):
"""Test plot_colorbar() function with metric=disp and num_nearest_neighbours=None"""
return plotting.plot_colorbar(
@@ -652,12 +654,7 @@ def test_plot_colorbar_displacement(self):
metric="disp",
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="Cd_Te_s32c_2_displacement.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("Cd_Te_s32c_2_displacement.png")
def test_plot_colorbar_SnB_naming_w_site_num(self):
"""Test plot_colorbar() function with SnB defect naming and
`include_site_info_in_name=True`"""
@@ -671,12 +668,7 @@ def test_plot_colorbar_SnB_naming_w_site_num(self):
metric="disp",
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_maxdist_title_linecolor_label.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_maxdist_title_linecolor_label.png")
def test_plot_colorbar_legend_label_linecolor_title_saveplot(self):
"""Test plot_colorbar() function with several keyword arguments:
line_color, title, y_label, save_format, legend_label and neighbour_atom=None"""
@@ -695,12 +687,7 @@ def test_plot_colorbar_legend_label_linecolor_title_saveplot(self):
self.assertTrue(os.path.exists(os.path.join(os.getcwd(), "vac_1_Cd_0.png")))
return fig
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="Int_Se_1_6.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("Int_Se_1_6.png")
def test_plot_colorbar_with_rattled_and_imported(self):
"""Test plot_colorbar() function with both rattled and imported charge states"""
energies_dict = OrderedDict(
@@ -742,12 +729,7 @@ def test_plot_colorbar_with_rattled_and_imported(self):
self.assertTrue(os.path.exists(os.path.join(os.getcwd(), "Int_Se_1_6.png")))
return fig
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="as_2_O_on_I_1.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("as_2_O_on_I_1.png")
def test_plot_with_multiple_imported_distortions_from_same_charge_state(self):
"""Test plot_datasets() where there are multiple imported distortions from the same charge state"""
energies_dict = {
@@ -789,12 +771,7 @@ def test_plot_with_multiple_imported_distortions_from_same_charge_state(self):
self.assertTrue(os.path.exists(os.path.join(os.getcwd(), "as_2_O_on_I_1.png")))
return fig
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_colors.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_colors.png")
def test_plot_datasets_keywords(self):
"""Test plot_datasets() function testing several keywords:
colors, save_format, title, defect_species, title, neighbour_atom, num_nearest_neighbours,
@@ -818,12 +795,7 @@ def test_plot_datasets_keywords(self):
)
return fig
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_notitle.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_notitle.png")
def test_plot_datasets_without_saving(self):
"""Test plot_datasets() function testing several keywords:
title = None, num_nearest_neighbours = None, neighbour_atom = None, save_plot = False
@@ -843,12 +815,7 @@ def test_plot_datasets_without_saving(self):
self.assertFalse(os.path.exists(os.path.join(os.getcwd(), "vac_1_Cd_0.png")))
return fig
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_not_enough_markers.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_not_enough_markers.png")
def test_plot_datasets_not_enough_markers(self):
"""Test plot_datasets() function when user does not provide enough markers and linestyles"""
return plotting.plot_datasets(
@@ -864,12 +831,7 @@ def test_plot_datasets_not_enough_markers(self):
],
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_other_charge_states.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_other_charge_states.png")
def test_plot_datasets_from_other_charge_states(self):
"""Test plot_datasets() function when energy lowering distortions from other
charge states have been tried"""
@@ -883,12 +845,7 @@ def test_plot_datasets_from_other_charge_states(self):
neighbour_atom="Te",
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_-2_only_rattled.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_-2_only_rattled.png")
def test_plot_datasets_only_rattled(self):
"""Test plot_datasets() function when the only distortion is 'Rattled'"""
return plotting.plot_datasets(
@@ -899,12 +856,7 @@ def test_plot_datasets_only_rattled(self):
defect_species="vac_1_Cd_0",
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_-2_rattled_other_charge_states.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_-2_rattled_other_charge_states.png")
def test_plot_datasets_rattled_and_dist_from_other_chargestates(self):
"""Test plot_datasets() function when the distortion is "Rattled"
and distortions from other charge states have been tried"""
@@ -916,12 +868,7 @@ def test_plot_datasets_rattled_and_dist_from_other_chargestates(self):
defect_species="vac_1_Cd_0",
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_-2_only_rattled_and_rattled_dist_from_other_charges_states.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_-2_only_rattled_and_rattled_dist_from_other_charges_states.png")
def test_plot_datasets_only_rattled_and_rattled_dist_from_other_charge_states(self):
"""Test plot_datasets() function when one of the energy lowering distortions from other
charge states is Rattled (i.e. `Rattled_from_0`)"""
@@ -969,12 +916,7 @@ def test_plot_defect_fake_output_directories(self):
)
os.remove(f"{self.VASP_CDTE_DATA_DIR}/fake_defect.png")
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="v_Ca_s0_0_plot_defect_without_colorbar.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("v_Ca_s0_0_plot_defect_without_colorbar.png")
def test_plot_defect_dimer(self):
defect_species = "v_Ca_s0_0"
defect_energies = analysis.get_energies(
@@ -1026,12 +968,7 @@ def test_plot_defect_missing_unperturbed_energy(self):
self.assertEqual(len(user_warnings), 1)
self.assertEqual(warning_message, str(w[-1].message))
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_plot_defect_add_colorbar_max_dist.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_plot_defect_add_colorbar_max_dist.png")
def test_plot_defect_add_colorbar(self):
"""Test plot_defect() function when add_colorbar = True"""
return plotting.plot_defect(
@@ -1043,12 +980,7 @@ def test_plot_defect_add_colorbar(self):
neighbour_atom="Te",
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_plot_defect_without_colorbar.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_plot_defect_without_colorbar.png")
def test_plot_defect_without_colorbar(self):
"""Test plot_defect() function when add_colorbar = False"""
return plotting.plot_defect(
@@ -1060,12 +992,7 @@ def test_plot_defect_without_colorbar(self):
neighbour_atom="Te",
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_plot_defect_with_unrecognised_name.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_plot_defect_with_unrecognised_name.png")
def test_plot_defect_unrecognised_name(self):
"""Test plot_defect() function when the name cannot be formatted (e.g. if parsing and
plotting from a renamed folder)"""
@@ -1090,12 +1017,7 @@ def test_plot_defect_unrecognised_name(self):
)
return fig
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="Te_i_Td_Te2.83_+2.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("Te_i_Td_Te2.83_+2.png")
def test_plot_defect_doped_v2(self):
"""Test plot_defect() function using doped v2 naming"""
return plotting.plot_defect(
@@ -1108,12 +1030,7 @@ def test_plot_defect_doped_v2(self):
save_plot=False,
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_include_site_info_in_name.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_include_site_info_in_name.png")
def test_plot_defect_include_site_info_in_name(self):
"""Test plot_defect() function when include_site_info_in_name = True"""
return plotting.plot_defect(
@@ -1127,12 +1044,7 @@ def test_plot_defect_include_site_info_in_name(self):
save_plot=False,
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="Te_i_Td_Te2.83_+2_include_site_info_in_name.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("Te_i_Td_Te2.83_+2_include_site_info_in_name.png")
def test_plot_defect_include_site_info_in_name_doped_v2(self):
"""Test plot_defect() function when include_site_info_in_name = True, using doped v2 naming"""
return plotting.plot_defect(
@@ -1145,12 +1057,7 @@ def test_plot_defect_include_site_info_in_name_doped_v2(self):
save_plot=False,
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_0_plot_defect_without_title_units_meV.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0_plot_defect_without_title_units_meV.png")
def test_plot_defect_without_title_units_in_meV(self):
"""Test plot_defect() function when add_title = False and units = 'meV'"""
return plotting.plot_defect(
@@ -1176,7 +1083,7 @@ def test_plot_all_defects_incorrect_output_path(self):
FileNotFoundError,
plotting.plot_all_defects,
output_path="./fake_output_path",
- defects_dict={
+ defect_charges_dict={
"vac_1_Cd": [
0,
]
@@ -1189,7 +1096,7 @@ def test_plot_all_defects_nonexistent_defect_folder(self):
with warnings.catch_warnings(record=True) as w:
plotting.plot_all_defects(
output_path=self.VASP_CDTE_DATA_DIR,
- defects_dict={
+ defect_charges_dict={
"vac_1_Cd": [
-1,
]
@@ -1201,18 +1108,13 @@ def test_plot_all_defects_nonexistent_defect_folder(self):
in str(w[-1].message)
)
- @pytest.mark.mpl_image_compare(
- baseline_dir=f"{_DATA_DIR}/remote_baseline_plots",
- filename="vac_1_Cd_-2_only_rattled.png",
- style=f"{_file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_-2_only_rattled.png")
def test_plot_defects_output(self):
"""Test output of plot_all_defects() function. Test plot still generated when
distortion_metadata.json does not contain info for a given charge state"""
fig_dict = plotting.plot_all_defects(
output_path=self.VASP_CDTE_DATA_DIR,
- defects_dict={"vac_1_Cd": [0, -2]},
+ defect_charges_dict={"vac_1_Cd": [0, -2]},
save_plot=False,
min_e_diff=0.05,
add_title=False,
@@ -1230,7 +1132,7 @@ def test_plot_all_defects_min_e_diff(self):
"""Test plot_all_defects() function with keyword min_e_diff set"""
fig_dict = plotting.plot_all_defects(
output_path=self.VASP_CDTE_DATA_DIR,
- defects_dict={"vac_1_Cd": [0, -2]},
+ defect_charges_dict={"vac_1_Cd": [0, -2]},
save_plot=False,
min_e_diff=0.15,
)
diff --git a/tests/test_shakenbreak.py b/tests/test_shakenbreak.py
index bae8891..638b0de 100644
--- a/tests/test_shakenbreak.py
+++ b/tests/test_shakenbreak.py
@@ -12,6 +12,7 @@
from shakenbreak import energy_lowering_distortions, input, io, plotting
from test_energy_lowering_distortions import assert_not_called_with
+from test_plotting import custom_mpl_image_compare
Mock.assert_not_called_with = assert_not_called_with
@@ -523,12 +524,7 @@ def parse_and_generate_defect_plot(self, defect_dir):
fig_dict = plotting.plot_all_defects(defect_charges_dict, save_format="png")
return fig_dict[defect_dir]
- @pytest.mark.mpl_image_compare(
- baseline_dir="data/remote_baseline_plots",
- filename="vac_1_Cd_-2.png",
- style=f"{file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_-2.png")
def test_plot_fake_vac_1_Cd_m2(self):
return self.parse_and_generate_defect_plot("vac_1_Cd_-2")
@@ -552,12 +548,7 @@ def test_plot_fake_vac_1_Cd_plus2_new_SnB_naming(self):
def test_plot_fake_vac_1_Cd_plus2_old_SnB_naming(self):
return self.parse_and_generate_defect_plot("vac_1_Cd_2")
- @pytest.mark.mpl_image_compare(
- baseline_dir="data/remote_baseline_plots",
- filename="vac_1_Cd_-1.png",
- style=f"{file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_-1.png")
def test_plot_fake_vac_1_Cd_m1(self):
defect_dir = "vac_1_Cd_-1"
if_present_rm(defect_dir)
@@ -597,12 +588,7 @@ def test_plot_fake_vac_1_Cd_m1(self):
fig_dict = plotting.plot_all_defects(defect_charges_dict, save_format="png")
return fig_dict["vac_1_Cd_-1"]
- @pytest.mark.mpl_image_compare(
- baseline_dir="data/remote_baseline_plots",
- filename="vac_1_Cd_0.png",
- style=f"{file_path}/../shakenbreak/shakenbreak.mplstyle",
- savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
- )
+ @custom_mpl_image_compare("vac_1_Cd_0.png")
def test_plot_fake_vac_1_Cd_0(self):
defect_dir = "vac_1_Cd_0"
if_present_rm(defect_dir)