diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..ef6ef4f15 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,8 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + day: "sunday" + time: "23:00" \ No newline at end of file diff --git a/.github/scripts/update_version.py b/.github/scripts/update_version.py index 635cf8268..a8dc61204 100644 --- a/.github/scripts/update_version.py +++ b/.github/scripts/update_version.py @@ -11,7 +11,7 @@ patch_number = line_split[2].split("'")[0].split('"')[0] # Increment patch number -patch_number = str(int(patch_number) + 1) + "'" +patch_number = str(int(patch_number) + 1) + '"\n' new_line = line_split[0] + "." + line_split[1] + "." + patch_number diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index 09b2a0fba..6ed2cc1e0 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -10,5 +10,5 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: psf/black@stable \ No newline at end of file diff --git a/.github/workflows/build-flake.yml b/.github/workflows/build-flake.yml index 3393b7908..3bf56d698 100644 --- a/.github/workflows/build-flake.yml +++ b/.github/workflows/build-flake.yml @@ -18,9 +18,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install dependencies diff --git a/.github/workflows/check_install_dev.yml b/.github/workflows/check_install_dev.yml index 4e9d16f77..262f541a4 100644 --- a/.github/workflows/check_install_dev.yml +++ b/.github/workflows/check_install_dev.yml @@ -16,7 +16,7 @@ jobs: allow_failure: [false] runs-on: [ubuntu-latest] architecture: [x86_64] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: 
["3.10", "3.11", "3.12"] # include: # - python-version: "3.12.0-beta.4" # runs-on: ubuntu-latest @@ -28,10 +28,10 @@ jobs: # runs-on: macos-latest # allow_failure: true steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install repo diff --git a/.github/workflows/check_install_main.yml b/.github/workflows/check_install_main.yml index a276cab17..e9ba98ab6 100644 --- a/.github/workflows/check_install_main.yml +++ b/.github/workflows/check_install_main.yml @@ -16,7 +16,7 @@ jobs: allow_failure: [false] runs-on: [ubuntu-latest, windows-latest, macos-latest] architecture: [x86_64] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] #include: # - python-version: "3.12.0-beta.4" # runs-on: ubuntu-latest @@ -28,10 +28,10 @@ jobs: # runs-on: macos-latest # allow_failure: true steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install repo diff --git a/.github/workflows/check_install_quick.yml b/.github/workflows/check_install_quick.yml index f83ee0b73..eadd12370 100644 --- a/.github/workflows/check_install_quick.yml +++ b/.github/workflows/check_install_quick.yml @@ -20,7 +20,7 @@ jobs: allow_failure: [false] runs-on: [ubuntu-latest] architecture: [x86_64] - python-version: ["3.9", "3.12"] + python-version: ["3.10", "3.12"] # Currently no public runners available for this but this or arm64 should work next time # include: # - python-version: "3.10" @@ -28,10 +28,10 @@ jobs: # runs-on: macos-latest # allow_failure: true steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} - uses: 
actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install repo @@ -42,4 +42,4 @@ jobs: python -c "import py4DSTEM; print(py4DSTEM.__version__)" # - name: Check machine arch # run: | - # python -c "import platform; print(platform.machine())" \ No newline at end of file + # python -c "import platform; print(platform.machine())" diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index 3e8071f6f..a83e35d30 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -11,13 +11,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: # Full git history is needed to get a proper list of changed files within `super-linter` fetch-depth: 0 - name: Lint Code Base - uses: github/super-linter@v4 + uses: github/super-linter@v5 env: VALIDATE_ALL_CODEBASE: false VALIDATE_PYTHON_FLAKE8: true diff --git a/.github/workflows/pypi_upload.yml b/.github/workflows/pypi_upload.yml index 906ec5d1a..75dd5a559 100644 --- a/.github/workflows/pypi_upload.yml +++ b/.github/workflows/pypi_upload.yml @@ -6,22 +6,22 @@ on: push: branches: - main - pull_request: - branches: - - main + # pull_request: + # branches: + # - main jobs: update_version: runs-on: ubuntu-latest name: Check if version.py is changed and update if the version.py is not changed steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 token: ${{ secrets.GH_ACTION_VERSION_UPDATE }} - name: Get changed files id: changed-files-specific - uses: tj-actions/changed-files@v41 + uses: tj-actions/changed-files@v44 with: files: | py4DSTEM/version.py @@ -39,14 +39,15 @@ jobs: git config --global user.email "ben.savitzky@gmail.com" git config --global user.name "bsavitzky" git commit -a -m "Auto-update version number (GH Action)" - git push origin main + git push origin + sync_with_dev: needs: update_version runs-on: ubuntu-latest name: 
Sync main with dev steps: - name: Sync main with dev - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: dev fetch-depth: 0 @@ -63,11 +64,11 @@ jobs: runs-on: ubuntu-latest name: Deploy to PyPI steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: ref: dev - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: 3.8 - name: Install dependencies @@ -83,3 +84,4 @@ jobs: password: ${{ secrets.PYPI_API_TOKEN }} + diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..70adeee4c --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. 
+ +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. 
+ +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +ben.savitzky@gmail.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. 
Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md new file mode 100644 index 000000000..7588a58d0 --- /dev/null +++ b/CONTRIBUTORS.md @@ -0,0 +1,126 @@ +## py4DSTEM Contributor Guidelines + +**Welcome to the py4DSTEM project!** + +We are grateful for your interest in contributing to py4DSTEM, an open-source Python library for 4D-STEM data analysis. Your contributions will help us make py4DSTEM a more powerful and versatile tool for the scientific community. + +This document outlines the guidelines and expectations for contributors to the py4DSTEM project. Please read it carefully before making any contributions. + +### Contribution Types + +There are many ways to contribute to py4DSTEM, including: + +* **Reporting bugs:** If you encounter a bug in py4DSTEM, please file a bug report on GitHub. Be sure to provide as much detail as possible about the bug, including steps to reproduce it. + +* **Submitting feature requests:** If you have a suggestion for a new feature for py4DSTEM, please submit a feature request on GitHub. 
Describe the feature in detail and explain how it would benefit users. + +* **Improving documentation:** py4DSTEM's documentation is always in need of improvement. If you have suggestions for improving the documentation, please submit a pull request or open an issue on GitHub. + +* **Developing new code:** If you are a developer, you can contribute to py4DSTEM by writing new code. Please follow the coding guidelines below. + +### Coding Guidelines + +* **Code style:** py4DSTEM uses the black code formatter and flake8 linter. All code must pass these checks without error before it can be merged. We suggest using `pre-commit` to help ensure any code commited follows these practices, checkout the [setting up developer environment section below](#install). We also try to abide by PEP8 coding style guide where possible. + +* **Documentation:** All code should be well-documented, and use Numpy style docstrings. Use docstrings to document functions and classes, add comments to explain complex code both blocks and individual lines, and use informative variable names. + +* **Testing:** Ideally all new code should be accompanied by tests using pyTest framework; at the least we require examples of old and new behaviour caused by the PR. For bug fixes this can be a block of code which currently fails and works with the proposed changes. For new workflows or extensive feature additions, please also include a Jupyter notebook demonstrating the changes for an entire workflow i.e. from loading the input data to visualizing and saving any processed results. + +* **Dependencies:** New dependencies represent a significant change to the package, and any PRs which add new dependencies will require discussion and agreement from the development team. If a new dependency is required, please prioritize adding dependencies that are actively maintained, have permissive installation requirements, and are accessible through both pip and conda. 
+ +### Pull Requests + +When submitting a pull request, please: + +* **Open a corresponding issue:** Before submitting a pull request, open an issue to discuss your changes and get feedback. + +* **Write a clear and concise pull request description:** The pull request description should clearly explain the changes you made and why they are necessary. + +* **Follow the coding guidelines:** Make sure your code follows the coding guidelines outlined above. + +* **Add tests:** If your pull request includes new code, please add unit tests. + +* **Address reviewer feedback:** Respond promptly to reviewer feedback and make changes as needed. + +### Code Review + +All pull requests will be reviewed by project maintainers before they are merged. The maintainers will provide feedback on the code and may ask for changes. Please be respectful of the maintainers' feedback and make changes as needed. + +### Code of Conduct + +py4DSTEM is committed to providing a welcoming and inclusive environment for all contributors. Please read and follow the py4DSTEM Code of Conduct. + +### Acknowledgments + +We are grateful to all of the contributors who have made py4DSTEM possible. Your contributions are invaluable to the scientific community. + +Thank you for your interest in contributing to py4DSTEM! We look forward to seeing your contributions. + + +### Setting up the Developer Environment + + +1. **Fork the Git repository:** Go to the py4DSTEM GitHub repository and click the "Fork" button. This will create your own copy of the repository on your GitHub account. + +2. **Clone the forked repository:** Open your terminal and clone your forked repository to your local machine. This will create a local copy of your fork of the repository on your computer's filesystem. 
Use the following command, replacing `` with your GitHub username and with the the directory where you'd like to copy the py4DSTEM repository onto your filesystem: + + ```bash + cd + git clone https://github.com//py4DSTEM.git + ``` +**_extra tips_:** Github has an excellent [tutorial](https://docs.github.com/en/get-started/quickstart/fork-a-repo) you can follow to learn more about for steps one and two. + +3. **Create a new branch:** Create a new branch for your development work. This will silo any edits you make to the new branch, allowing you to work and make edits to your working branch without affecting the main branch of the repository. Use the following command, replacing `` with a name for your branch. Please use an informative name describing the purpose of the branch, e.g. ``: + + ```bash + git checkout -b + ``` + +4. **Create an anaconda environment:** Environments are useful for a number of reasons. They create an isolated computing environment for each project or task, making it easy to reproduce results, manage dependencies, and collaborate with others. There are a number of different ways to use environments, including anaconda, pipenv, virtualenv - we recommend using [anaconda](https://docs.anaconda.com/free/anaconda/install/index.html) for environment management, where you can create a new environment with: + + ```bash + conda create -n py4dstem-dev python=3.10 + ``` + +You can then enter that environment using: + + ```bash + conda activate py4dstem-dev + ``` + + +5. **Install py4DSTEM in editable mode:** Navigate to the cloned repository directory and install py4DSTEM in editable mode. This will allow you to make changes to the code and test them without having to reinstall the library. Use the following command: + + ```bash + conda activate py4dstem-dev + pip install -e . + pip install flake8 black # install linter and autoformatter + ``` + +You can now make changes to the code and test them using your favorite Python IDE or editor. + +6. 
**_(Optional)_ - Using `pre-commit`**: `pre-commit` streamlines code formatting and linting. Essentially it runs black (an autoformatter) and flake8 (a linter) whenever a new commit is attempted on all staged files, and only allows the commit to proceed if they both pass. To use pre-commit, run: + + ```bash + conda activate py4dstem-dev + pip install pre-commit + cd # go to your py4DSTEM repo + pre-commit install + ``` + +This will setup pre-commit to work on this repo by creating/changing a file in .git/hooks/pre-commit, which tells `pre-commit` to automatically run flake8 and black when you try to commit code. It won't affect any other repos. + +**_extra tips_:** + +```bash +# You can call pre commit manually at any time without committing +pre-commit run # will check any staged files +pre-commit run -a # will run on all files in repo + +# you can bypass the hook and commit files without the checks +# (this isn't best practice and should be avoided, but there are times it can be useful) + +git add file # stage file as usual +git commit -m "you commit message" --no-verify # commit without running checks +git push # push to repo. +``` diff --git a/README.md b/README.md index aa102542a..3fe6cc745 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ > :warning: **py4DSTEM version 0.14 update** :warning: Warning: this is a major update and we expect some workflows to break. You can still install previous versions of py4DSTEM [as discussed here](#legacyinstall) - +> :warning: **Phase retrieval refactor version 0.14.9** :warning: Warning: The phase-retrieval modules in py4DSTEM (DPC, parallax, and ptychography) underwent a major refactor in version 0.14.9 and as such older tutorial notebooks will not work as expected. Notably, class names have been pruned to remove the trailing "Reconstruction" (`DPCReconstruction` -> `DPC` etc.), and regularization functions have dropped the `_iter` suffix (and are instead specified as boolean flags). 
We are working on updating the tutorial notebooks to reflect these changes. In the meantime, there's some more information in the relevant pull request [here](https://github.com/py4dstem/py4DSTEM/pull/597#issuecomment-1890325568). ![py4DSTEM logo](/images/py4DSTEM_logo.png) @@ -11,6 +11,7 @@ Additional information: - [Installation instructions](#install) - [The py4DSTEM documentation pages](https://py4dstem.readthedocs.io/en/latest/index.html). - [Tutorials and example code](https://github.com/py4dstem/py4DSTEM_tutorials) +- [Want to get involved?](#Contributing) - [Our open access py4DSTEM publication in Microscopy and Microanalysis](https://doi.org/10.1017/S1431927621000477) describing this project and demonstrating a variety of applications. - [Our open access 4D-STEM review in Microscopy and Microanalysis](https://doi.org/10.1017/S1431927619000497) describing this project and demonstrating a variety of applications. @@ -181,16 +182,18 @@ Links to the data used in each notebook are provided in the intro cell of each n # More information + -## Documentation +## Contributing Guide +We are grateful for your interest in contributing to py4DSTEM. There are many ways to contribute to py4DSTEM, including Reporting bugs, Submitting feature requests, Improving documentation and Developing new code -Our documentation pages are [available here](https://py4dstem.readthedocs.io/en/latest/index.html). +For more information checkout our [Contributors Guide](CONTRIBUTORS.md) +## Documentation -## For contributors +Our documentation pages are [available here](https://py4dstem.readthedocs.io/en/latest/index.html). -Please see [here](https://gist.github.com/bsavitzky/8b1ee4c1244814940e7cff4500535dba). 
## Scientific papers which use **py4DSTEM** diff --git a/docs/papers.md b/docs/papers.md index 4b1bc6b8d..4539dd943 100644 --- a/docs/papers.md +++ b/docs/papers.md @@ -5,57 +5,76 @@ Please email clophus@lbl.gov if you have used py4DSTEM for analysis and your paper is not listed below! -### 2022 (9) +### 2023 (0) -[Correlative image learning of chemo-mechanics in phase-transforming solids](https://www.nature.com/articles/s41563-021-01191-0), Nature Materials (2022) -[Correlative analysis of structure and chemistry of LixFePO4 platelets using 4D-STEM and X-ray ptychography](https://doi.org/10.1016/j.mattod.2021.10.031), Materials Today 52, 102 (2022). -[Visualizing Grain Statistics in MOCVD WSe2 through Four-Dimensional Scanning Transmission Electron Microscopy](https://doi.org/10.1021/acs.nanolett.1c04315), Nano Letters 22, 2578 (2022). +### 2022 (16) -[Electric field control of chirality](https://doi.org/10.1126/sciadv.abj8030), Science Advances 8 (2022). -[Real-Time Interactive 4D-STEM Phase-Contrast Imaging From Electron Event Representation Data: Less computation with the right representation](https://doi.org/10.1109/MSP.2021.3120981), IEEE Signal Processing Magazine 39, 25 (2022). +[Disentangling multiple scattering with deep learning: application to strain mapping from electron diffraction patterns](https://doi.org/10.1038/s41524-022-00939-9), J Munshi*, A Rakowski*, et al., npj Computational Materials 8, 254 (2022) -[Microstructural dependence of defect formation in iron-oxide thin films](https://doi.org/10.1016/j.apsusc.2022.152844), Applied Surface Science 589, 152844 (2022). 
+[Flexible CO2 Sensor Architecture with Selective Nitrogen Functionalities by One-Step Laser-Induced Conversion of Versatile Organic Ink](https://doi.org/10.1002/adfm.202207406), H Wang et al., Advanced Functional Materials 32, 2207406 (2022) -[Chemical and Structural Alterations in the Amorphous Structure of Obsidian due to Nanolites](https://doi.org/10.1017/S1431927621013957), Microscopy and Microanalysis 28, 289 (2022). +[Defect Contrast with 4D-STEM: Understanding Crystalline Order with Virtual Detectors and Beam Modification](https://doi.org/10.1093/micmic/ozad045) SM Ribet et al., Microscopy and Microanalysis 29, 1087 (2023). -[Nanoscale characterization of crystalline and amorphous phases in silicon oxycarbide ceramics using 4D-STEM](https://doi.org/10.1016/j.matchar.2021.111512), Materials Characterization 181, 111512 (2021). +[Structural heterogeneity in non-crystalline TexSe1−x thin films](https://doi.org/10.1063/5.0094600), B Sari et al., Applied Physics Letters 121, 012101 (2022) -[Disentangling multiple scattering with deep learning: application to strain mapping from electron diffraction patterns](https://arxiv.org/abs/2202.00204), arXiv:2202.00204 (2022). 
+[Cryogenic 4D-STEM analysis of an amorphouscrystalline polymer blend: Combined nanocrystalline and amorphous phase mapping](https://doi.org/10.1016/j.isci.2022.103882), J Donohue et al., iScience 25, 103882 (2022) + +[Hydrogen-assisted decohesion associated with nanosized grain boundary κ-carbides in a high-Mn lightweight steel](https://doi.org/10.1016/j.actamat.2022.118392), MN Elkot et al., Acta Materialia + 241, 118392 (2022) + +[4D-STEM Ptychography for Electron-Beam-Sensitive Materials](https://doi.org/10.1021/acscentsci.2c01137), G Li et al., ACS Central Science 8, 1579 (2022) + +[Developing a Chemical and Structural Understanding of the Surface Oxide in a Niobium Superconducting Qubit](https://doi.org/10.1021/acsnano.2c07913), AA Murthy et al., ACS Nano 16, 17257 (2022) + +[Correlative image learning of chemo-mechanics in phase-transforming solids](https://www.nature.com/articles/s41563-021-01191-0), HD Deng et al., Nature Materials (2022) + +[Correlative analysis of structure and chemistry of LixFePO4 platelets using 4D-STEM and X-ray ptychography](https://doi.org/10.1016/j.mattod.2021.10.031), LA Hughes*, BH Savitzky, et al., Materials Today 52, 102 (2022) + +[Visualizing Grain Statistics in MOCVD WSe2 through Four-Dimensional Scanning Transmission Electron Microscopy](https://doi.org/10.1021/acs.nanolett.1c04315), A Londoño-Calderon et al., Nano Letters 22, 2578 (2022) + +[Electric field control of chirality](https://doi.org/10.1126/sciadv.abj8030), P Behera et al., Science Advances 8 (2022) + +[Real-Time Interactive 4D-STEM Phase-Contrast Imaging From Electron Event Representation Data: Less computation with the right representation](https://doi.org/10.1109/MSP.2021.3120981), P Pelz et al., IEEE Signal Processing Magazine 39, 25 (2022) + +[Microstructural dependence of defect formation in iron-oxide thin films](https://doi.org/10.1016/j.apsusc.2022.152844), BK Derby et al., Applied Surface Science 589, 152844 (2022) + +[Chemical and Structural 
Alterations in the Amorphous Structure of Obsidian due to Nanolites](https://doi.org/10.1017/S1431927621013957), E Kennedy et al., Microscopy and Microanalysis 28, 289 (2022) + +[Nanoscale characterization of crystalline and amorphous phases in silicon oxycarbide ceramics using 4D-STEM](https://doi.org/10.1016/j.matchar.2021.111512), Ni Yang et al., Materials Characterization 181, 111512 (2021) ### 2021 (10) -[Cryoforged nanotwinned titanium with ultrahigh strength and ductility](https://doi.org/10.1126/science.abe7252), Science 16, 373, 1363 (2021). +[Cryoforged nanotwinned titanium with ultrahigh strength and ductility](https://doi.org/10.1126/science.abe7252), Science 16, 373, 1363 (2021) -[Strain fields in twisted bilayer graphene](https://doi.org/10.1038/s41563-021-00973-w), Nature Materials 20, 956 (2021). +[Strain fields in twisted bilayer graphene](https://doi.org/10.1038/s41563-021-00973-w), Nature Materials 20, 956 (2021) -[Determination of Grain-Boundary Structure and Electrostatic Characteristics in a SrTiO3 Bicrystal by Four-Dimensional Electron Microscopy](https://doi.org/10.1021/acs.nanolett.1c02960), Nanoletters 21, 9138 (2021). +[Determination of Grain-Boundary Structure and Electrostatic Characteristics in a SrTiO3 Bicrystal by Four-Dimensional Electron Microscopy](https://doi.org/10.1021/acs.nanolett.1c02960), Nanoletters 21, 9138 (2021) [Local Lattice Deformation of Tellurene Grain Boundaries by Four-Dimensional Electron Microscopy](https://pubs.acs.org/doi/10.1021/acs.jpcc.1c00308), Journal of Physical Chemistry C 125, 3396 (2021). -[Extreme mixing in nanoscale transition metal alloys](https://doi.org/10.1016/j.matt.2021.04.014), Matter 4, 2340 (2021). +[Extreme mixing in nanoscale transition metal alloys](https://doi.org/10.1016/j.matt.2021.04.014), Matter 4, 2340 (2021) -[Multibeam Electron Diffraction](https://doi.org/10.1017/S1431927620024770), Microscopy and Microanalysis 27, 129 (2021). 
+[Multibeam Electron Diffraction](https://doi.org/10.1017/S1431927620024770), Microscopy and Microanalysis 27, 129 (2021) -[A Fast Algorithm for Scanning Transmission Electron Microscopy Imaging and 4D-STEM Diffraction Simulations](https://doi.org/10.1017/S1431927621012083), Microscopy and Microanalysis 27, 835 (2021). +[A Fast Algorithm for Scanning Transmission Electron Microscopy Imaging and 4D-STEM Diffraction Simulations](https://doi.org/10.1017/S1431927621012083), Microscopy and Microanalysis 27, 835 (2021) -[Fast Grain Mapping with Sub-Nanometer Resolution Using 4D-STEM with Grain Classification by Principal Component Analysis and Non-Negative Matrix Factorization](https://doi.org/10.1017/S1431927621011946), Microscopy and Microanalysis 27, 794 +[Fast Grain Mapping with Sub-Nanometer Resolution Using 4D-STEM with Grain Classification by Principal Component Analysis and Non-Negative Matrix Factorization](https://doi.org/10.1017/S1431927621011946), Microscopy and Microanalysis 27, 794 (2021) -[Prismatic 2.0 – Simulation software for scanning and high resolution transmission electron microscopy (STEM and HRTEM)](https://doi.org/10.1016/j.micron.2021.103141), Micron 151, 103141 (2021). +[Prismatic 2.0 – Simulation software for scanning and high resolution transmission electron microscopy (STEM and HRTEM)](https://doi.org/10.1016/j.micron.2021.103141), Micron 151, 103141 (2021) -[4D-STEM of Beam-Sensitive Materials](https://doi.org/10.1021/acs.accounts.1c00073), Accounts of Chemical Research 54, 2543 (2021). +[4D-STEM of Beam-Sensitive Materials](https://doi.org/10.1021/acs.accounts.1c00073), Accounts of Chemical Research 54, 2543 (2021) ### 2020 (3) -[Patterned probes for high precision 4D-STEM bragg measurements](https://doi.org/10.1063/5.0015532), Ultramicroscopy 209, 112890 (2020). 
- +[Patterned probes for high precision 4D-STEM bragg measurements](https://doi.org/10.1063/5.0015532), Ultramicroscopy 209, 112890 (2020) -[Tilted fluctuation electron microscopy](https://doi.org/10.1063/5.0015532), Applied Physics Letters 117, 091903 (2020). +[Tilted fluctuation electron microscopy](https://doi.org/10.1063/5.0015532), Applied Physics Letters 117, 091903 (2020) [4D-STEM elastic stress state characterisation of a TWIP steel nanotwin](https://arxiv.org/abs/2004.03982), arXiv:2004.03982 diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 70a36dec1..30ead79a6 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -1,5 +1,5 @@ # BraggVectors methods - +from __future__ import annotations import inspect from warnings import warn @@ -813,6 +813,56 @@ def to_strainmap(self, name: str = None): return StrainMap(self, name) if name else StrainMap(self) + def plot( + self, + index: tuple[int, int] | list[int], + cal: str = "cal", + returnfig: bool = False, + **kwargs, + ): + """ + Plot Bragg vector, from a specified index. + Calls py4DSTEM.process.diffraction.plot_diffraction_pattern(braggvectors.[index], **kwargs). + Optionally can return the figure. 
+ + Parameters + ---------- + index : tuple[int,int] | list[int] + scan position for which Bragg vectors to plot + cal : str, optional + Choice to plot calibrated or raw Bragg vectors must be 'raw' or 'cal', by default 'cal' + returnfig : bool, optional + Boolean to return figure or not, by default False + + Returns + ------- + tuple (figure, axes) + matplotlib figure, axes returned if `returnfig` is True + """ + cal = cal.lower() + assert cal in ( + "cal", + "raw", + ), f"'cal' must be in ('cal', 'raw') {cal = } passed" + from py4DSTEM.process.diffraction import plot_diffraction_pattern + + if cal == "cal": + pl = self.cal[index] + else: + pl = self.raw[index] + + if returnfig: + return plot_diffraction_pattern( + pl, + returnfig=returnfig, + **kwargs, + ) + else: + plot_diffraction_pattern( + pl, + **kwargs, + ) + ######### END BraggVectorMethods CLASS ######## diff --git a/py4DSTEM/braggvectors/braggvectors.py b/py4DSTEM/braggvectors/braggvectors.py index e81eeb62f..6ccb46d1f 100644 --- a/py4DSTEM/braggvectors/braggvectors.py +++ b/py4DSTEM/braggvectors/braggvectors.py @@ -336,7 +336,7 @@ class BVects: """ Enables - >>> v.qx,v.qy,v.I + >>> v.qx,v.qy,v.I, optionally, v.h,v.k,v.l -like access to a collection of Bragg vector. """ @@ -361,13 +361,61 @@ def I(self): def data(self): return self._data + @property + def h(self): + try: + return self._data["h"] + except ValueError: + raise AttributeError("h indicies not set") + + @property + def k(self): + try: + return self._data["k"] + except ValueError: + raise AttributeError("k indicies not set") + + @property + def l(self): + try: + return self._data["l"] + except ValueError: + raise AttributeError("l indicies not set") + def __repr__(self): space = " " * len(self.__class__.__name__) + " " string = f"{self.__class__.__name__}( " string += f"A set of {len(self.data)} bragg vectors." - string += " Access data with .qx, .qy, .I, or .data.)" + string += " Access data with .qx, .qy, .I, or .data. 
Optionally .h, .k, and.l, if indexed)" return string + def plot( + self, + returnfig: bool = False, + **kwargs, + ): + """ + Plot the diffraction pattern. + Calls `py4DSTEM.process.diffraction.plot_diffraction_pattern` and passes kwargs to it. + + Parameters + ---------- + returnfig: bool + If True the figure is returned, else its ploted but not returned. Defaults to False + + Returns + ------- + figure : matplotlib object + If `returnfig` is True, the figure is returned. + + """ + from py4DSTEM.process.diffraction import plot_diffraction_pattern + + if returnfig: + return plot_diffraction_pattern(self, returnfig=returnfig, **kwargs) + else: + plot_diffraction_pattern(self, **kwargs) + class RawVectorGetter: def __init__( diff --git a/py4DSTEM/braggvectors/diskdetection_aiml.py b/py4DSTEM/braggvectors/diskdetection_aiml.py index 4d23ebf6c..8e3607edb 100644 --- a/py4DSTEM/braggvectors/diskdetection_aiml.py +++ b/py4DSTEM/braggvectors/diskdetection_aiml.py @@ -8,6 +8,8 @@ import json import shutil import numpy as np +from pathlib import Path + from scipy.ndimage import gaussian_filter from time import time @@ -437,9 +439,9 @@ def find_Bragg_disks_aiml_serial( raise ImportError("Import Error: Please install crystal4D before proceeding") # Make the peaks PointListArray - # dtype = [('qx',float),('qy',float),('intensity',float)] - peaks = BraggVectors(datacube.Rshape, datacube.Qshape) - + dtype = [("qx", float), ("qy", float), ("intensity", float)] + # peaks = BraggVectors(datacube.Rshape, datacube.Qshape) + peaks = PointListArray(dtype=dtype, shape=(datacube.R_Nx, datacube.R_Ny)) # check that the filtered DP is the right size for the probe kernel: if filter_function: assert callable(filter_function), "filter_function must be callable" @@ -518,7 +520,7 @@ def find_Bragg_disks_aiml_serial( subpixel=subpixel, upsample_factor=upsample_factor, filter_function=filter_function, - peaks=peaks.vectors_uncal.get_pointlist(Rx, Ry), + peaks=peaks.get_pointlist(Rx, Ry), 
model_path=model_path, ) t2 = time() - t0 @@ -884,7 +886,7 @@ def _get_latest_model(model_path=None): + "https://www.tensorflow.org/install" + "for more information" ) - from py4DSTEM.io.google_drive_downloader import download_file_from_google_drive + from py4DSTEM.io.google_drive_downloader import gdrive_download tf.keras.backend.clear_session() @@ -894,7 +896,12 @@ def _get_latest_model(model_path=None): except: pass # download the json file with the meta data - download_file_from_google_drive("FCU-Net", "./tmp/model_metadata.json") + gdrive_download( + "FCU-Net", + destination="./tmp/", + filename="model_metadata.json", + overwrite=True, + ) with open("./tmp/model_metadata.json") as f: metadata = json.load(f) file_id = metadata["file_id"] @@ -918,7 +925,8 @@ def _get_latest_model(model_path=None): else: print("Checking the latest model on the cloud... \n") filename = file_path + file_type - download_file_from_google_drive(file_id, filename) + filename = Path(filename) + gdrive_download(file_id, destination="./tmp", filename=filename.name) try: shutil.unpack_archive(filename, "./tmp", format="zip") except: diff --git a/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py index ffa6e891b..bd2736719 100644 --- a/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py +++ b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py @@ -124,8 +124,9 @@ def find_Bragg_disks_aiml_CUDA( """ # Make the peaks PointListArray - # dtype = [('qx',float),('qy',float),('intensity',float)] - peaks = BraggVectors(datacube.Rshape, datacube.Qshape) + dtype = [("qx", float), ("qy", float), ("intensity", float)] + # peaks = BraggVectors(datacube.Rshape, datacube.Qshape) + peaks = PointListArray(dtype=dtype, shape=(datacube.R_Nx, datacube.R_Ny)) # check that the filtered DP is the right size for the probe kernel: if filter_function: @@ -221,7 +222,7 @@ def find_Bragg_disks_aiml_CUDA( subpixel=subpixel, upsample_factor=upsample_factor, 
filter_function=filter_function, - peaks=peaks.vectors_uncal.get_pointlist(Rx, Ry), + peaks=peaks.get_pointlist(Rx, Ry), get_maximal_points=get_maximal_points, blocks=blocks, threads=threads, diff --git a/py4DSTEM/braggvectors/diskdetection_cuda.py b/py4DSTEM/braggvectors/diskdetection_cuda.py index 870b8303d..ddea4d9ad 100644 --- a/py4DSTEM/braggvectors/diskdetection_cuda.py +++ b/py4DSTEM/braggvectors/diskdetection_cuda.py @@ -156,9 +156,11 @@ def find_Bragg_disks_CUDA( patt_idx = batch_idx * batch_size + subbatch_idx rx, ry = np.unravel_index(patt_idx, (datacube.R_Nx, datacube.R_Ny)) batched_subcube[subbatch_idx, :, :] = cp.array( - datacube.data[rx, ry, :, :] - if filter_function is None - else filter_function(datacube.data[rx, ry, :, :]), + ( + datacube.data[rx, ry, :, :] + if filter_function is None + else filter_function(datacube.data[rx, ry, :, :]) + ), dtype=cp.float32, ) diff --git a/py4DSTEM/datacube/datacube.py b/py4DSTEM/datacube/datacube.py index 4d87afdd5..93a48bbdd 100644 --- a/py4DSTEM/datacube/datacube.py +++ b/py4DSTEM/datacube/datacube.py @@ -405,7 +405,9 @@ def pad_Q(self, N=None, output_size=None): d = pad_data_diffraction(self, pad_factor=N, output_size=output_size) return d - def resample_Q(self, N=None, output_size=None, method="bilinear"): + def resample_Q( + self, N=None, output_size=None, method="bilinear", conserve_array_sums=False + ): """ Resamples the data in diffraction space by resampling factor N, or to match output_size, using either 'fourier' or 'bilinear' interpolation. 
@@ -418,7 +420,11 @@ def resample_Q(self, N=None, output_size=None, method="bilinear"): from py4DSTEM.preprocess import resample_data_diffraction d = resample_data_diffraction( - self, resampling_factor=N, output_size=output_size, method=method + self, + resampling_factor=N, + output_size=output_size, + method=method, + conserve_array_sums=conserve_array_sums, ) return d @@ -479,13 +485,29 @@ def filter_hot_pixels(self, thresh, ind_compare=1, return_mask=False): """ from py4DSTEM.preprocess import filter_hot_pixels - d = filter_hot_pixels( + datacube = filter_hot_pixels( self, thresh, ind_compare, return_mask, ) - return d + return datacube + + def median_filter_masked_pixels(self, mask, kernel_width: int = 3): + """ + This function fixes a datacube where the same pixels are consistently + bad. It requires a mask that identifies all the bad pixels in the dataset. + Then for each diffraction pattern, a median kernel is applied around each + bad pixel with the specified width. + """ + from py4DSTEM.preprocess import median_filter_masked_pixels + + datacube = median_filter_masked_pixels( + self, + mask, + kernel_width, + ) + return datacube # Probe @@ -494,7 +516,7 @@ def get_vacuum_probe( ROI=None, align=True, mask=None, - threshold=0.2, + threshold=0.0, expansion=12, opening=3, verbose=False, diff --git a/py4DSTEM/io/filereaders/read_arina.py b/py4DSTEM/io/filereaders/read_arina.py index 6f7c463d2..832499d3f 100644 --- a/py4DSTEM/io/filereaders/read_arina.py +++ b/py4DSTEM/io/filereaders/read_arina.py @@ -27,7 +27,7 @@ def read_arina( dtype_bin(float): specify datatype for bin on load if need something other than uint16 flatfield (np.ndarray): - flatfield forcorrection factors + flatfield for correction factors, converts data to float Returns: DataCube @@ -51,12 +51,15 @@ def read_arina( nimages % scan_width < 1e-6 ), "scan_width must be integer multiple of x*y size" - if dtype.type is np.uint32: + if dtype.type is np.uint32 and flatfield is None: print("Dataset is 
uint32 but will be converted to uint16") dtype = np.dtype(np.uint16) if dtype_bin: array_3D = np.empty((nimages, width, height), dtype=dtype_bin) + elif flatfield is not None: + array_3D = np.empty((nimages, width, height), dtype="float32") + print("Dataset is uint32 but will be converted to float32") else: array_3D = np.empty((nimages, width, height), dtype=dtype) @@ -65,9 +68,9 @@ def read_arina( if flatfield is None: correction_factors = 1 else: - # Avoid div by 0 errors -> pixel with value 0 will be set to meadian - flatfield[flatfield == 0] = 1 correction_factors = np.median(flatfield) / flatfield + # Avoid div by 0 errors -> pixel with value 0 will be set to median + correction_factors[flatfield == 0] = 1 for dset in f["entry"]["data"]: image_index = _processDataSet( diff --git a/py4DSTEM/io/google_drive_downloader.py b/py4DSTEM/io/google_drive_downloader.py index 5b53f19ae..a50fb2c09 100644 --- a/py4DSTEM/io/google_drive_downloader.py +++ b/py4DSTEM/io/google_drive_downloader.py @@ -57,7 +57,7 @@ ), "small_dm3_3Dstack": ("small_dm3_3Dstack.dm3", "1B-xX3F65JcWzAg0v7f1aVwnawPIfb5_o"), "FCU-Net": ( - "filename.name", + "model_metadata.json", "1-KX0saEYfhZ9IJAOwabH38PCVtfXidJi", ), "small_datacube": ( @@ -221,7 +221,8 @@ def gdrive_download( kwargs = {"fuzzy": True} if id_ in file_ids: f = file_ids[id_] - filename = f[0] + # Use the name in the collection filename passed + filename = filename if filename is not None else f[0] kwargs["id"] = f[1] # if its not in the list of files we expect diff --git a/py4DSTEM/io/importfile.py b/py4DSTEM/io/importfile.py index 20a3759a2..ff3d1c37c 100644 --- a/py4DSTEM/io/importfile.py +++ b/py4DSTEM/io/importfile.py @@ -75,7 +75,7 @@ def import_file( "gatan_K2_bin", "mib", "arina", - "abTEM" + "abTEM", # "kitware_counted", ], "Error: filetype not recognized" diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/io.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/io.py index e1b7ab241..ddbea9005 100644 --- 
a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/io.py +++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/io.py @@ -361,7 +361,7 @@ def Array_to_h5(array, group): data = grp.create_dataset( "data", shape=array.data.shape, - data=array.data + data=array.data, # dtype = type(array.data) ) data.attrs.create( diff --git a/py4DSTEM/preprocess/preprocess.py b/py4DSTEM/preprocess/preprocess.py index fb4983622..f18d9b621 100644 --- a/py4DSTEM/preprocess/preprocess.py +++ b/py4DSTEM/preprocess/preprocess.py @@ -470,6 +470,61 @@ def filter_hot_pixels(datacube, thresh, ind_compare=1, return_mask=False): return datacube +def median_filter_masked_pixels(datacube, mask, kernel_width: int = 3): + """ + This function fixes a datacube where the same pixels are consistently + bad. It requires a mask that identifies all the bad pixels in the dataset. + Then for each diffraction pattern, a median kernel is applied around each + bad pixel with the specified width. + + Parameters + ---------- + datacube: + Datacube to be filtered + mask: + a boolean mask that specifies the bad pixels in the datacube + kernel_width (optional): + specifies the width of the median kernel + + Returns + ---------- + filtered datacube + """ + if kernel_width % 2 == 0: + width_max = kernel_width // 2 + width_min = kernel_width // 2 + + else: + width_max = int(kernel_width / 2 + 0.5) + width_min = int(kernel_width / 2 - 0.5) + + num_bad_pixels_indicies = np.array(np.where(mask)) + for a0 in range(num_bad_pixels_indicies.shape[1]): + index_x = num_bad_pixels_indicies[0, a0] + index_y = num_bad_pixels_indicies[1, a0] + + x_min = index_x - width_min + y_min = index_y - width_min + + x_max = index_x + width_max + y_max = index_y + width_max + + if x_min < 0: + x_min = 0 + if y_min < 0: + y_min = 0 + + if x_max > datacube.Qshape[0]: + x_max = datacube.Qshape[0] + if y_max > datacube.Qshape[1]: + y_max = datacube.Qshape[1] + + datacube.data[:, :, index_x, index_y] = np.median( + datacube.data[:, :, x_min:x_max, 
y_min:y_max], axis=(2, 3) + ) + return datacube + + def datacube_diffraction_shift( datacube, xshifts, @@ -518,7 +573,11 @@ def datacube_diffraction_shift( def resample_data_diffraction( - datacube, resampling_factor=None, output_size=None, method="bilinear" + datacube, + resampling_factor=None, + output_size=None, + method="bilinear", + conserve_array_sums=False, ): """ Performs diffraction space resampling of data by resampling_factor or to match output_size. @@ -539,7 +598,10 @@ def resample_data_diffraction( old_size = datacube.data.shape datacube.data = fourier_resample( - datacube.data, scale=resampling_factor, output_size=output_size + datacube.data, + scale=resampling_factor, + output_size=output_size, + conserve_array_sums=conserve_array_sums, ) if not resampling_factor: @@ -562,6 +624,10 @@ def resample_data_diffraction( if resampling_factor.shape == (): resampling_factor = np.tile(resampling_factor, 2) + output_size = np.round( + resampling_factor * np.array(datacube.shape[-2:]) + ).astype("int") + else: if output_size is None: raise ValueError( @@ -575,10 +641,28 @@ def resample_data_diffraction( resampling_factor = np.array(output_size) / np.array(datacube.shape[-2:]) - resampling_factor = np.concatenate(((1, 1), resampling_factor)) - datacube.data = zoom(datacube.data, resampling_factor, order=1) + output_data = np.zeros(datacube.Rshape + tuple(output_size)) + for Rx, Ry in tqdmnd( + datacube.shape[0], + datacube.shape[1], + desc="Resampling 4D datacube", + unit="DP", + unit_scale=True, + ): + output_data[Rx, Ry] = zoom( + datacube.data[Rx, Ry].astype(np.float32), + resampling_factor, + order=1, + mode="nearest", + grid_mode=True, + ) + + if conserve_array_sums: + output_data = output_data / resampling_factor.prod() + + datacube.data = output_data datacube.calibration.set_Q_pixel_size( - datacube.calibration.get_Q_pixel_size() / resampling_factor[2] + datacube.calibration.get_Q_pixel_size() / resampling_factor[0] ) else: diff --git 
a/py4DSTEM/process/calibration/origin.py b/py4DSTEM/process/calibration/origin.py index a0717e321..7f0c07a81 100644 --- a/py4DSTEM/process/calibration/origin.py +++ b/py4DSTEM/process/calibration/origin.py @@ -2,14 +2,27 @@ import functools import numpy as np +import matplotlib.pyplot as plt from scipy.ndimage import gaussian_filter from scipy.optimize import leastsq +import matplotlib.pyplot as plt from emdfile import tqdmnd, PointListArray from py4DSTEM.datacube import DataCube from py4DSTEM.process.calibration.probe import get_probe_size from py4DSTEM.process.fit import plane, parabola, bezier_two, fit_2D -from py4DSTEM.process.utils import get_CoM, add_to_2D_array_from_floats, get_maxima_2D +from py4DSTEM.process.utils import ( + get_CoM, + add_to_2D_array_from_floats, + get_maxima_2D, + upsampled_correlation, +) +from py4DSTEM.process.phase.utils import copy_to_device + +try: + import cupy as cp +except (ImportError, ModuleNotFoundError): + cp = np # @@ -309,58 +322,122 @@ def get_origin( return qx0, qy0, mask -def get_origin_single_dp_beamstop(DP: np.ndarray, mask: np.ndarray, **kwargs): +def get_origin_friedel( + datacube: DataCube, + mask=None, + upsample_factor=1, + device="cpu", + return_cpu=True, +): """ - Find the origin for a single diffraction pattern, assuming there is a beam stop. - - Args: - DP (np array): diffraction pattern - mask (np array): boolean mask which is False under the beamstop and True - in the diffraction pattern. One approach to generating this mask - is to apply a suitable threshold on the average diffraction pattern - and use binary opening/closing to remove and holes - - Returns: - qx0, qy0 (tuple) measured center position of diffraction pattern + Fit the origin for each diffraction pattern, with or without a beam stop. + The method we have developed here is a heavily modified version of masked + cross correlation, where we use Friedel symmetry of the diffraction pattern + to find the common center. 
+ + More details about how the correlation step can be found in: + https://doi.org/10.1109/TIP.2011.2181402 + + Parameters + ---------- + datacube: (DataCube) + The 4D dataset. + mask: (np array, optional) + Boolean mask which is False under the beamstop and True + in the diffraction pattern. One approach to generating this mask + is to apply a suitable threshold on the average diffraction pattern + and use binary opening/closing to remove any holes. + If no mask is provided, this method will likely not work with a beamstop. + upsample_factor: (int) + Upsample factor for subpixel fitting of the image shifts. + device: string + 'cpu' or 'gpu' to select device + return_cpu: bool + Return arrays on cpu. + + + Returns + ------- + qx0, qy0 + (tuple of np arrays) measured center position of each diffraction pattern """ - imCorr = np.real( - np.fft.ifft2( - np.fft.fft2(DP * mask) - * np.conj(np.fft.fft2(np.rot90(DP, 2) * np.rot90(mask, 2))) + # Select device + if device == "cpu": + xp = np + elif device == "gpu": + xp = cp + + # init measurement arrays + qx0 = xp.zeros(datacube.data.shape[:2]) + qy0 = xp.zeros_like(qx0) + + # pad the mask + if mask is not None: + mask = xp.asarray(mask).astype("float") + mask_pad = xp.pad( + mask, + ((0, datacube.data.shape[2]), (0, datacube.data.shape[3])), + constant_values=(1.0, 1.0), ) - ) + M = xp.fft.fft2(mask_pad) - xp, yp = np.unravel_index(np.argmax(imCorr), imCorr.shape) - - dx = ((xp + DP.shape[0] / 2) % DP.shape[0]) - DP.shape[0] / 2 - dy = ((yp + DP.shape[1] / 2) % DP.shape[1]) - DP.shape[1] / 2 - - return (DP.shape[0] + dx) / 2, (DP.shape[1] + dy) / 2 + # main loop over all probe positions + for rx, ry in tqdmnd(datacube.R_Nx, datacube.R_Ny): + if mask is None: + # pad image + im_xp = xp.asarray(datacube.data[rx, ry]) + im = xp.pad( + im_xp, + ((0, datacube.data.shape[2]), (0, datacube.data.shape[3])), + ) + G = xp.fft.fft2(im) + # Cross correlation of masked image with its inverse + cc = xp.real(xp.fft.ifft2(G**2)) -def 
get_origin_beamstop(datacube: DataCube, mask: np.ndarray, **kwargs): - """ - Find the origin for each diffraction pattern, assuming there is a beam stop. - - Args: - datacube (DataCube) - mask (np array): boolean mask which is False under the beamstop and True - in the diffraction pattern. One approach to generating this mask - is to apply a suitable threshold on the average diffraction pattern - and use binary opening/closing to remove any holes + else: + im_xp = xp.asarray(datacube.data[rx, ry, :, :]) + im = xp.pad( + im_xp, + ((0, datacube.data.shape[2]), (0, datacube.data.shape[3])), + ) - Returns: - qx0, qy0 (tuple of np arrays) measured center position of each diffraction pattern - """ + # Masked cross correlation of masked image with its inverse + term1 = xp.real(xp.fft.ifft2(xp.fft.fft2(im) ** 2) * xp.fft.ifft2(M**2)) + term2 = xp.real(xp.fft.ifft2(xp.fft.fft2(im**2) * M)) + term3 = xp.real(xp.fft.ifft2(xp.fft.fft2(im * mask_pad))) + cc = (term1 - term3) / (term2 - term3) - qx0 = np.zeros(datacube.data.shape[:2]) - qy0 = np.zeros_like(qx0) + # get correlation peak + x, y = xp.unravel_index(xp.argmax(cc), im.shape) - for rx, ry in tqdmnd(datacube.R_Nx, datacube.R_Ny): - x, y = get_origin_single_dp_beamstop(datacube.data[rx, ry, :, :], mask) + # half pixel upsampling / parabola subpixel fitting + dx = (cc[x + 1, y] - cc[x - 1, y]) / ( + 4.0 * cc[x, y] - 2.0 * cc[x + 1, y] - 2.0 * cc[x - 1, y] + ) + dy = (cc[x, y + 1] - cc[x, y - 1]) / ( + 4.0 * cc[x, y] - 2.0 * cc[x, y + 1] - 2.0 * cc[x, y - 1] + ) + # xp += np.round(dx*2.0)/2.0 + # yp += np.round(dy*2.0)/2.0 + x = x.astype("float") + dx + y = y.astype("float") + dy + + # upsample peak if needed + if upsample_factor > 1: + x, y = upsampled_correlation( + xp.fft.fft2(cc), + upsampleFactor=upsample_factor, + xyShift=xp.array((x, y)), + device=device, + ) - qx0[rx, ry] = x - qy0[rx, ry] = y + # Correlation peak, moved to image center shift + qx0[rx, ry] = (x / 2) % datacube.data.shape[2] + qy0[rx, ry] = (y / 2) 
% datacube.data.shape[3] - return qx0, qy0 + if return_cpu: + return copy_to_device(qx0), copy_to_device(qy0) + else: + return qx0, qy0 diff --git a/py4DSTEM/process/diffraction/WK_scattering_factors.py b/py4DSTEM/process/diffraction/WK_scattering_factors.py index 70110a977..eb964de96 100644 --- a/py4DSTEM/process/diffraction/WK_scattering_factors.py +++ b/py4DSTEM/process/diffraction/WK_scattering_factors.py @@ -221,9 +221,7 @@ def RI1(BI, BJ, G): ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ)) sub = np.logical_and(eps <= 0.1, G > 0.0) - temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log( - BJ / (BI + BJ) - ) + temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ)) temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2 temp -= 0.5 * (BI - BJ) ** 2 ri1[sub] += np.pi * G[sub] ** 2 * temp diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index c3a7519d0..e0fe59eee 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1,11 +1,13 @@ # Functions for calculating diffraction patterns, matching them to experiments, and creating orientation and phase maps. import numpy as np +from scipy.ndimage import gaussian_filter import matplotlib.pyplot as plt from matplotlib.patches import Circle from fractions import Fraction from typing import Union, Optional import sys +import warnings from emdfile import PointList from py4DSTEM.process.utils import single_atom_scatter, electron_wavelength_angstrom @@ -66,6 +68,7 @@ def __init__( positions, numbers, cell, + occupancy=None, ): """ Args: @@ -76,7 +79,7 @@ def __init__( 3 numbers: the three lattice parameters for an orthorhombic cell 6 numbers: the a,b,c lattice parameters and ɑ,β,ɣ angles for any cell 3x3 array: row vectors containing the (u,v,w) lattice vectors. - + occupancy (np.array): Partial occupancy values for each atomic site. 
Must match the length of positions """ # Initialize Crystal self.positions = np.asarray(positions) #: fractional atomic coordinates @@ -131,6 +134,17 @@ def __init__( else: raise Exception("Cell cannot contain " + np.size(cell) + " entries") + # occupancy + if occupancy is not None: + self.occupancy = np.array(occupancy) + # check the occupancy shape makes sense + if self.occupancy.shape[0] != self.positions.shape[0]: + raise Warning( + f"Number of occupancies ({self.occupancy.shape[0]}) and atomic positions ({self.positions.shape[0]}) do not match" + ) + else: + self.occupancy = np.ones(self.positions.shape[0], dtype=np.float32) + # pymatgen flag if "pymatgen" in sys.modules: self.pymatgen_available = True @@ -257,7 +271,70 @@ def get_strained_crystal( else: return crystal_strained - def from_CIF(CIF, conventional_standard_structure=True): + @staticmethod + def from_ase( + atoms, + ): + """ + Create a py4DSTEM Crystal object from an ASE atoms object + + Args: + atoms (ase.Atoms): an ASE atoms object + + """ + # get the occupancies from the atoms object + occupancies = ( + atoms.arrays["occupancies"] + if "occupancies" in atoms.arrays.keys() + else None + ) + + if "occupancy" in atoms.info.keys(): + warnings.warn( + "This Atoms object contains occupancy information but it will be ignored." 
+ ) + + xtal = Crystal( + positions=atoms.get_scaled_positions(), # fractional coords + numbers=atoms.numbers, + cell=atoms.cell.array, + occupancy=occupancies, + ) + return xtal + + @staticmethod + def from_prismatic(filepath): + """ + Create a py4DSTEM Crystal object from an prismatic style xyz co-ordinate file + + Args: + filepath (str|Pathlib.Path): path to the prismatic format xyz file + + """ + + from ase import io + + # read the atoms using ase + atoms = io.read(filepath, format="prismatic") + + # get the occupancies from the atoms object + occupancies = ( + atoms.arrays["occupancies"] + if "occupancies" in atoms.arrays.keys() + else None + ) + xtal = Crystal( + positions=atoms.get_scaled_positions(), # fractional coords + numbers=atoms.numbers, + cell=atoms.cell.array, + occupancy=occupancies, + ) + return xtal + + @staticmethod + def from_CIF( + CIF, primitive: bool = True, conventional_standard_structure: bool = True + ): """ Create a Crystal object from a CIF file, using pymatgen to import the CIF @@ -273,12 +350,13 @@ def from_CIF(CIF, conventional_standard_structure=True): parser = CifParser(CIF) - structure = parser.get_structures(False)[0] + structure = parser.get_structures(primitive=primitive)[0] return Crystal.from_pymatgen_structure( structure, conventional_standard_structure=conventional_standard_structure ) + @staticmethod def from_pymatgen_structure( structure=None, formula=None, @@ -375,8 +453,6 @@ def from_pymatgen_structure( else selected["structure"] ) - positions = structure.frac_coords #: fractional atomic coordinates - cell = np.array( [ structure.lattice.a, @@ -388,10 +464,22 @@ def from_pymatgen_structure( ] ) - numbers = np.array([s.species.elements[0].Z for s in structure]) + site_data = np.array( + [ + (*site.frac_coords, elem.number, comp) + for site in structure + for elem, comp in site.species.items() + ] + ) + positions = site_data[:, :3] + numbers = site_data[:, 3] + occupancies = site_data[:, 4] - return Crystal(positions, 
numbers, cell) + return Crystal( + positions=positions, numbers=numbers, cell=cell, occupancy=occupancies + ) + @staticmethod def from_unitcell_parameters( latt_params, elements, @@ -575,10 +663,14 @@ def calculate_structure_factors( # Calculate structure factors self.struct_factors = np.zeros(np.size(self.g_vec_leng, 0), dtype="complex64") for a0 in range(self.positions.shape[0]): - self.struct_factors += f_all[:, a0] * np.exp( - (2j * np.pi) - * np.sum( - self.hkl * np.expand_dims(self.positions[a0, :], axis=1), axis=0 + self.struct_factors += ( + f_all[:, a0] + * self.occupancy[a0] + * np.exp( + (2j * np.pi) + * np.sum( + self.hkl * np.expand_dims(self.positions[a0, :], axis=1), axis=0 + ) ) ) @@ -876,6 +968,259 @@ def generate_ring_pattern( if return_calc is True: return radii_unique, intensity_unique + def generate_projected_potential( + self, + im_size=(256, 256), + pixel_size_angstroms=0.1, + potential_radius_angstroms=3.0, + sigma_image_blur_angstroms=0.1, + thickness_angstroms=100, + power_scale=1.0, + plot_result=False, + figsize=(6, 6), + orientation: Optional[Orientation] = None, + ind_orientation: Optional[int] = 0, + orientation_matrix: Optional[np.ndarray] = None, + zone_axis_lattice: Optional[np.ndarray] = None, + proj_x_lattice: Optional[np.ndarray] = None, + zone_axis_cartesian: Optional[np.ndarray] = None, + proj_x_cartesian: Optional[np.ndarray] = None, + ): + """ + Generate an image of the projected potential of crystal in real space, + using cell tiling, and a lookup table of the atomic potentials. + Note that we round atomic positions to the nearest pixel for speed. + + TODO - fix scattering prefactor so that output units are sensible. + + Parameters + ---------- + im_size: tuple, list, np.array + (2,) vector specifying the output size in pixels. + pixel_size_angstroms: float + Pixel size in Angstroms. 
+ potential_radius_angstroms: float + Radius in Angstroms for how far to integrate the atomic potentials + sigma_image_blur_angstroms: float + Image blurring in Angstroms. + thickness_angstroms: float + Thickness of the sample in Angstroms. + Set thickness_thickness_angstroms = 0 to skip thickness projection. + power_scale: float + Power law scaling of potentials. Set to 2.0 to approximate Z^2 images. + plot_result: bool + Plot the projected potential image. + figsize: + (2,) vector giving the size of the output. + + orientation: Orientation + An Orientation class object + ind_orientation: int + If input is an Orientation class object with multiple orientations, + this input can be used to select a specific orientation. + orientation_matrix: array + (3,3) orientation matrix, where columns represent projection directions. + zone_axis_lattice: array + (3,) projection direction in lattice indices + proj_x_lattice: array) + (3,) x-axis direction in lattice indices + zone_axis_cartesian: array + (3,) cartesian projection direction + proj_x_cartesian: array + (3,) cartesian projection direction + + Returns + -------- + im_potential: (np.array) + Output image of the projected potential. 
+ + """ + + # Determine image size in Angstroms + im_size = np.array(im_size) + im_size_Ang = im_size * pixel_size_angstroms + + # Parse orientation inputs + if orientation is not None: + if ind_orientation is None: + orientation_matrix = orientation.matrix[0] + else: + orientation_matrix = orientation.matrix[ind_orientation] + elif orientation_matrix is None: + orientation_matrix = self.parse_orientation( + zone_axis_lattice, proj_x_lattice, zone_axis_cartesian, proj_x_cartesian + ) + + # Rotate unit cell into projection direction + lat_real = self.lat_real.copy() @ orientation_matrix + + # Determine unit cell axes to tile over, by selecting 2/3 with largest in-plane component + inds_tile = np.argsort(np.linalg.norm(lat_real[:, 0:2], axis=1))[1:3] + m_tile = lat_real[inds_tile, :] + # Vector projected along optic axis + m_proj = np.squeeze(np.delete(lat_real, inds_tile, axis=0)) + + # Thickness + if thickness_angstroms > 0: + num_proj = np.round(thickness_angstroms / np.abs(m_proj[2])).astype("int") + if num_proj > 1: + vec_proj = m_proj[:2] / pixel_size_angstroms + shifts = np.arange(num_proj).astype("float") + shifts -= np.mean(shifts) + x_proj = shifts * vec_proj[0] + y_proj = shifts * vec_proj[1] + else: + num_proj = 1 + else: + num_proj = 1 + + # Determine tiling range + if thickness_angstroms > 0: + # include the cell height + dz = m_proj[2] * num_proj * 0.5 + p_corners = np.array( + [ + [-im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, dz], + [im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, dz], + [-im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, dz], + [im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, dz], + [-im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, -dz], + [im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, -dz], + [-im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, -dz], + [im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, -dz], + ] + ) + else: + p_corners = np.array( + [ + [-im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], + [im_size_Ang[0] * 0.5, -im_size_Ang[1] * 0.5, 0.0], + 
[-im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], + [im_size_Ang[0] * 0.5, im_size_Ang[1] * 0.5, 0.0], + ] + ) + + ab = np.linalg.lstsq(m_tile[:, :2].T, p_corners[:, :2].T, rcond=None)[0] + ab = np.floor(ab) + a_range = np.array((np.min(ab[0]) - 1, np.max(ab[0]) + 2)) + b_range = np.array((np.min(ab[1]) - 1, np.max(ab[1]) + 2)) + + # Tile unit cell + a_ind, b_ind, atoms_ind = np.meshgrid( + np.arange(a_range[0], a_range[1]), + np.arange(b_range[0], b_range[1]), + np.arange(self.positions.shape[0]), + ) + abc_atoms = self.positions[atoms_ind.ravel(), :] + abc_atoms[:, inds_tile[0]] += a_ind.ravel() + abc_atoms[:, inds_tile[1]] += b_ind.ravel() + xyz_atoms_ang = abc_atoms @ lat_real + atoms_ID_all_0 = self.numbers[atoms_ind.ravel()] + + # Center atoms on image plane + x0 = xyz_atoms_ang[:, 0] / pixel_size_angstroms + im_size[0] / 2.0 + y0 = xyz_atoms_ang[:, 1] / pixel_size_angstroms + im_size[1] / 2.0 + + # if needed, tile atoms in the projection direction + if num_proj > 1: + x = (x0[:, None] + x_proj[None, :]).ravel() + y = (y0[:, None] + y_proj[None, :]).ravel() + atoms_ID_all = np.tile(atoms_ID_all_0, (num_proj, 1)) + else: + x = x0 + y = y0 + atoms_ID_all = atoms_ID_all_0 + # print(x.shape, y.shape) + + # delete atoms outside the field of view + bound = potential_radius_angstroms / pixel_size_angstroms + atoms_del = np.logical_or.reduce( + ( + x <= -bound, + y <= -bound, + x >= im_size[0] + bound, + y >= im_size[1] + bound, + ) + ) + x = np.delete(x, atoms_del) + y = np.delete(y, atoms_del) + atoms_ID_all = np.delete(atoms_ID_all, atoms_del) + + # Coordinate system for atomic projected potentials + potential_radius = np.ceil(potential_radius_angstroms / pixel_size_angstroms) + R = np.arange(0.5 - potential_radius, potential_radius + 0.5) + R_ind = R.astype("int") + R_2D = np.sqrt(R[:, None] ** 2 + R[None, :] ** 2) + + # Lookup table for atomic projected potentials + atoms_ID = np.unique(self.numbers) + atoms_lookup = np.zeros( + ( + atoms_ID.shape[0], + 
R_2D.shape[0], + R_2D.shape[1], + ) + ) + for a0 in range(atoms_ID.shape[0]): + atom_sf = single_atom_scatter([atoms_ID[a0]]) + atoms_lookup[a0, :, :] = atom_sf.projected_potential(atoms_ID[a0], R_2D) + + # if needed, apply gaussian blurring to each atom + if sigma_image_blur_angstroms > 0: + atoms_lookup[a0, :, :] = gaussian_filter( + atoms_lookup[a0, :, :], + sigma_image_blur_angstroms / pixel_size_angstroms, + mode="nearest", + ) + atoms_lookup **= power_scale + + # initialize potential + im_potential = np.zeros(im_size) + + # Add atoms to potential image + for a0 in range(atoms_ID_all.shape[0]): + ind = np.argmin(np.abs(atoms_ID - atoms_ID_all[a0])) + + x_ind = np.round(x[a0]).astype("int") + R_ind + y_ind = np.round(y[a0]).astype("int") + R_ind + x_sub = np.logical_and( + x_ind >= 0, + x_ind < im_size[0], + ) + y_sub = np.logical_and( + y_ind >= 0, + y_ind < im_size[1], + ) + im_potential[x_ind[x_sub][:, None], y_ind[y_sub][None, :]] += atoms_lookup[ + ind + ][x_sub][:, y_sub] + + if thickness_angstroms > 0: + im_potential /= num_proj + + if plot_result: + # quick plotting of the result + int_vals = np.sort(im_potential.ravel()) + int_range = np.array( + ( + int_vals[np.round(0.02 * int_vals.size).astype("int")], + int_vals[np.round(0.999 * int_vals.size).astype("int")], + ) + ) + + fig, ax = plt.subplots(figsize=figsize) + ax.imshow( + im_potential, + cmap="gray", + vmin=int_range[0], + vmax=int_range[1], + ) + # ax.scatter(y,x,c='r') # for testing + ax.set_axis_off() + ax.set_aspect("equal") + + return im_potential + # Vector conversions and other utilities for Crystal classes def cartesian_to_lattice(self, vec_cartesian): vec_lattice = self.lat_inv @ vec_cartesian diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 83dff29e1..fc6e691c3 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -76,6 +76,13 @@ def orientation_plan( progress_bar 
(bool): If false no progress bar is displayed """ + # Check to make sure user has calculated the structure factors if needed + if calculate_correlation_array: + if not hasattr(self, "g_vec_leng"): + raise ValueError( + "Run .calculate_structure_factors() before calculating an orientation plan." + ) + # Store inputs self.accel_voltage = np.asarray(accel_voltage) self.orientation_kernel_size = np.asarray(corr_kernel_size) @@ -579,25 +586,6 @@ def orientation_plan( 0, 2 * np.pi, self.orientation_in_plane_steps, endpoint=False ) - # Determine the radii of all spherical shells - radii_test = np.round(self.g_vec_leng / tol_distance) * tol_distance - radii = np.unique(radii_test) - # Remove zero beam - keep = np.abs(radii) > tol_distance - self.orientation_shell_radii = radii[keep] - - # init - self.orientation_shell_index = -1 * np.ones(self.g_vec_all.shape[1], dtype="int") - self.orientation_shell_count = np.zeros(self.orientation_shell_radii.size) - - # Assign each structure factor point to a radial shell - for a0 in range(self.orientation_shell_radii.size): - sub = np.abs(self.orientation_shell_radii[a0] - radii_test) <= tol_distance / 2 - - self.orientation_shell_index[sub] = a0 - self.orientation_shell_count[a0] = np.sum(sub) - self.orientation_shell_radii[a0] = np.mean(self.g_vec_leng[sub]) - # init storage arrays self.orientation_rotation_angles = np.zeros((self.orientation_num_zones, 3)) self.orientation_rotation_matrices = np.zeros((self.orientation_num_zones, 3, 3)) @@ -685,7 +673,32 @@ def orientation_plan( k0 = np.array([0.0, 0.0, -1.0 / self.wavelength]) n = np.array([0.0, 0.0, -1.0]) + # Remaining calculations are only required if we are computing the correlation array. 
if calculate_correlation_array: + # Determine the radii of all spherical shells + radii_test = np.round(self.g_vec_leng / tol_distance) * tol_distance + radii = np.unique(radii_test) + # Remove zero beam + keep = np.abs(radii) > tol_distance + self.orientation_shell_radii = radii[keep] + + # init + self.orientation_shell_index = -1 * np.ones( + self.g_vec_all.shape[1], dtype="int" + ) + self.orientation_shell_count = np.zeros(self.orientation_shell_radii.size) + + # Assign each structure factor point to a radial shell + for a0 in range(self.orientation_shell_radii.size): + sub = ( + np.abs(self.orientation_shell_radii[a0] - radii_test) + <= tol_distance / 2 + ) + + self.orientation_shell_index[sub] = a0 + self.orientation_shell_count[a0] = np.sum(sub) + self.orientation_shell_radii[a0] = np.mean(self.g_vec_leng[sub]) + # initialize empty correlation array self.orientation_ref = np.zeros( ( @@ -2025,6 +2038,9 @@ def calculate_strain( tol_intensity: float = 1e-4, k_max: Optional[float] = None, min_num_peaks=5, + intensity_weighting=False, + robust=True, + robust_thresh=3.0, rotation_range=None, mask_from_corr=True, corr_range=(0, 2), @@ -2039,24 +2055,46 @@ def calculate_strain( TODO: add robust fitting? - Args: - bragg_peaks_array (PointListArray): All Bragg peaks - orientation_map (OrientationMap): Orientation map generated from ACOM - corr_kernel_size (float): Correlation kernel size - if user does - not specify, uses self.corr_kernel_size. - sigma_excitation_error (float): sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms - tol_excitation_error_mult (float): tolerance in units of sigma for s_g inclusion - tol_intensity (np float): tolerance in intensity units for inclusion of diffraction spots - k_max (float): Maximum scattering vector - min_num_peaks (int): Minimum number of peaks required. - rotation_range (float): Maximum rotation range in radians (for symmetry reduction). 
- progress_bar (bool): Show progress bar - mask_from_corr (bool): Use ACOM correlation signal for mask - corr_range (np.ndarray): Range of correlation signals for mask - corr_normalize (bool): Normalize correlation signal before masking + Parameters + ---------- + bragg_peaks_array (PointListArray): + All Bragg peaks + orientation_map (OrientationMap): + Orientation map generated from ACOM + corr_kernel_size (float): + Correlation kernel size - if user does + not specify, uses self.corr_kernel_size. + sigma_excitation_error (float): + sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms + tol_excitation_error_mult (float): + tolerance in units of sigma for s_g inclusion + tol_intensity (np float): + tolerance in intensity units for inclusion of diffraction spots + k_max (float): + Maximum scattering vector + min_num_peaks (int): + Minimum number of peaks required. + intensity_weighting: bool + Set to True to weight least squares by experimental peak intensity. + robust_fitting: bool + Set to True to use robust fitting, which performs outlier rejection. + robust_thresh: float + Threshold for robust fitting weights. + rotation_range (float): + Maximum rotation range in radians (for symmetry reduction). 
+ progress_bar (bool): + Show progress bar + mask_from_corr (bool): + Use ACOM correlation signal for mask + corr_range (np.ndarray): + Range of correlation signals for mask + corr_normalize (bool): + Normalize correlation signal before masking - Returns: - strain_map (RealSlice): strain tensor + Returns + -------- + strain_map (RealSlice): + strain tensor """ @@ -2143,16 +2181,44 @@ def calculate_strain( (p_ref.data["qx"][inds_match[keep]], p_ref.data["qy"][inds_match[keep]]) ).T - # Apply intensity weighting from experimental measurements - qxy *= p.data["intensity"][keep, None] - qxy_ref *= p.data["intensity"][keep, None] - # Fit transformation matrix # Note - not sure about transpose here # (though it might not matter if rotation isn't included) - m = lstsq(qxy_ref, qxy, rcond=None)[0].T - - # Get the infinitesimal strain matrix + if intensity_weighting: + weights = np.sqrt(p.data["intensity"][keep, None]) * 0 + 1 + m = lstsq( + qxy_ref * weights, + qxy * weights, + rcond=None, + )[0].T + else: + m = lstsq( + qxy_ref, + qxy, + rcond=None, + )[0].T + + # Robust fitting + if robust: + for a0 in range(5): + # calculate new weights + qxy_fit = qxy_ref @ m + diff2 = np.sum((qxy_fit - qxy) ** 2, axis=1) + + weights = np.exp( + diff2 / ((-2 * robust_thresh**2) * np.median(diff2)) + )[:, None] + if intensity_weighting: + weights *= np.sqrt(p.data["intensity"][keep, None]) + + # calculate new fits + m = lstsq( + qxy_ref * weights, + qxy * weights, + rcond=None, + )[0].T + + # Set values into the infinitesimal strain matrix strain_map.get_slice("e_xx").data[rx, ry] = 1 - m[0, 0] strain_map.get_slice("e_yy").data[rx, ry] = 1 - m[1, 1] strain_map.get_slice("e_xy").data[rx, ry] = -(m[0, 1] + m[1, 0]) / 2.0 @@ -2160,7 +2226,7 @@ def calculate_strain( # Add finite rotation from ACOM orientation map. # I am not sure about the relative signs here. - # Also, I need to add in the mirror operator. + # Also, maybe I need to add in the mirror operator? 
if orientation_map.mirror[rx, ry, 0]: strain_map.get_slice("theta").data[rx, ry] += ( orientation_map.angles[rx, ry, 0, 0] @@ -2192,6 +2258,7 @@ def save_ang_file( pixel_units="px", transpose_xy=True, flip_x=False, + flip_y=False, ): """ This function outputs an ascii text file in the .ang format, containing @@ -2211,8 +2278,10 @@ def save_ang_file( nothing """ - - from orix.io.plugins.ang import file_writer + try: + from orix.io.plugins.ang import file_writer + except ImportError: + raise Exception("orix failed to import; try pip installing separately") xmap = self.orientation_map_to_orix_CrystalMap( orientation_map, @@ -2222,6 +2291,7 @@ def save_ang_file( return_color_key=False, transpose_xy=transpose_xy, flip_x=flip_x, + flip_y=flip_y, ) file_writer(file_name, xmap) @@ -2235,6 +2305,7 @@ def orientation_map_to_orix_CrystalMap( pixel_units="px", transpose_xy=True, flip_x=False, + flip_y=False, return_color_key=False, ): try: @@ -2260,12 +2331,20 @@ def orientation_map_to_orix_CrystalMap( import warnings - # Get orientation matrices + # Get orientation matrices and correlation signal (will be used as iq and ci) orientation_matrices = orientation_map.matrix[:, :, ind_orientation].copy() + corr_values = orientation_map.corr[:, :, ind_orientation].copy() + + # Check for transpose if transpose_xy: orientation_matrices = np.transpose(orientation_matrices, (1, 0, 2, 3)) + corr_values = np.transpose(corr_values, (1, 0)) if flip_x: orientation_matrices = np.flip(orientation_matrices, axis=0) + corr_values = np.flip(corr_values, axis=0) + if flip_y: + orientation_matrices = np.flip(orientation_matrices, axis=1) + corr_values = np.flip(corr_values, axis=1) # Convert the orientation matrices into Euler angles # suppress Gimbal lock warnings @@ -2327,8 +2406,8 @@ def fxn(): y=coords["y"], phase_list=PhaseList(phase), prop={ - "iq": orientation_map.corr[:, :, ind_orientation].ravel(), - "ci": orientation_map.corr[:, :, ind_orientation].ravel(), + "iq": corr_values.ravel(), + 
"ci": corr_values.ravel(), }, scan_unit=pixel_units, ) diff --git a/py4DSTEM/process/diffraction/crystal_bloch.py b/py4DSTEM/process/diffraction/crystal_bloch.py index 6a3c9b1ac..ce8bb8622 100644 --- a/py4DSTEM/process/diffraction/crystal_bloch.py +++ b/py4DSTEM/process/diffraction/crystal_bloch.py @@ -27,7 +27,6 @@ def calculate_dynamical_structure_factors( tol_structure_factor: float = 0.0, recompute_kinematic_structure_factors=True, g_vec_precision=None, - verbose=True, ): """ Calculate and store the relativistic corrected structure factors used for Bloch computations @@ -92,7 +91,7 @@ def calculate_dynamical_structure_factors( # Calculate the reciprocal lattice points to include based on k_max - k_max = np.asarray(k_max) + k_max: np.ndarray = np.asarray(k_max) if recompute_kinematic_structure_factors: if hasattr(self, "struct_factors"): @@ -215,7 +214,9 @@ def get_f_e(q, Z, thermal_sigma, method): # Calculate structure factors struct_factors = np.sum( - f_e * np.exp(2.0j * np.pi * np.squeeze(self.positions[:, None, :] @ hkl)), + f_e + * self.occupancy[:, None] + * np.exp(2.0j * np.pi * np.squeeze(self.positions[:, None, :] @ hkl)), axis=0, ) diff --git a/py4DSTEM/process/diffraction/crystal_calibrate.py b/py4DSTEM/process/diffraction/crystal_calibrate.py index c068bf79e..b15015c62 100644 --- a/py4DSTEM/process/diffraction/crystal_calibrate.py +++ b/py4DSTEM/process/diffraction/crystal_calibrate.py @@ -21,7 +21,7 @@ def calibrate_pixel_size( k_max=None, k_step=0.002, k_broadening=0.002, - fit_all_intensities=True, + fit_all_intensities=False, set_calibration_in_place=False, verbose=True, plot_result=False, @@ -50,7 +50,7 @@ def calibrate_pixel_size( k_broadening (float): Initial guess for Gaussian broadening of simulated pattern (Å^-1) fit_all_intensities (bool): Set to true to allow all peak intensities to - change independently False forces a single intensity scaling. + change independently. False forces a single intensity scaling for all peaks. 
set_calibration (bool): if True, set the fit pixel size to the calibration metadata, and calibrate bragg_peaks verbose (bool): Output the calibrated pixel size. @@ -138,7 +138,7 @@ def fit_profile(k, *coefs): if returnfig: fig, ax = self.plot_scattering_intensity( - bragg_peaks=bragg_peaks, + bragg_peaks=bragg_peaks_cali, figsize=figsize, k_broadening=k_broadening, int_power_scale=1.0, @@ -151,7 +151,7 @@ def fit_profile(k, *coefs): ) else: self.plot_scattering_intensity( - bragg_peaks=bragg_peaks, + bragg_peaks=bragg_peaks_cali, figsize=figsize, k_broadening=k_broadening, int_power_scale=1.0, diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 94cf75b8c..47df2e6ca 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -3,6 +3,8 @@ from matplotlib.axes import Axes import matplotlib.tri as mtri from mpl_toolkits.mplot3d import Axes3D, art3d +from mpl_toolkits.mplot3d.art3d import Poly3DCollection + from scipy.signal import medfilt from scipy.ndimage import gaussian_filter from scipy.ndimage import distance_transform_edt @@ -91,18 +93,26 @@ def plot_structure( # Fractional atomic coordinates pos = self.positions + occ = self.occupancy + # x tile sub = pos[:, 0] < tol_distance pos = np.vstack([pos, pos[sub, :] + np.array([1, 0, 0])]) ID = np.hstack([ID, ID[sub]]) + if occ is not None: + occ = np.hstack([occ, occ[sub]]) # y tile sub = pos[:, 1] < tol_distance pos = np.vstack([pos, pos[sub, :] + np.array([0, 1, 0])]) ID = np.hstack([ID, ID[sub]]) + if occ is not None: + occ = np.hstack([occ, occ[sub]]) # z tile sub = pos[:, 2] < tol_distance pos = np.vstack([pos, pos[sub, :] + np.array([0, 0, 1])]) ID = np.hstack([ID, ID[sub]]) + if occ is not None: + occ = np.hstack([occ, occ[sub]]) # Cartesian atomic positions xyz = pos @ self.lat_real @@ -141,17 +151,109 @@ def plot_structure( # atoms ID_all = np.unique(ID) - for ID_plot in ID_all: - sub = ID == ID_plot - 
ax.scatter( - xs=xyz[sub, 1], # + d[0], - ys=xyz[sub, 0], # + d[1], - zs=xyz[sub, 2], # + d[2], - s=size_marker, - linewidth=2, - facecolors=atomic_colors(ID_plot), - edgecolor=[0, 0, 0], - ) + if occ is None: + for ID_plot in ID_all: + sub = ID == ID_plot + ax.scatter( + xs=xyz[sub, 1], # + d[0], + ys=xyz[sub, 0], # + d[1], + zs=xyz[sub, 2], # + d[2], + s=size_marker, + linewidth=2, + facecolors=atomic_colors(ID_plot), + edgecolor=[0, 0, 0], + ) + else: + # init + tol = 1e-4 + num_seg = 180 + radius = 0.7 + zp = np.zeros(num_seg + 1) + + mark = np.ones(xyz.shape[0], dtype="bool") + for a0 in range(xyz.shape[0]): + if mark[a0]: + xyz_plot = xyz[a0, :] + inds = np.argwhere(np.sum((xyz - xyz_plot) ** 2, axis=1) < tol) + occ_plot = occ[inds] + mark[inds] = False + ID_plot = ID[inds] + + if np.sum(occ_plot) < 1.0: + occ_plot = np.append(occ_plot, 1 - np.sum(occ_plot)) + ID_plot = np.append(ID_plot, -1) + else: + occ_plot = occ_plot[0] + ID_plot = ID_plot[0] + + # Plot site as series of filled arcs + theta0 = 0 + for a1 in range(occ_plot.shape[0]): + theta1 = theta0 + occ_plot[a1] * 2.0 * np.pi + theta = np.linspace(theta0, theta1, num_seg + 1) + xp = np.cos(theta) * radius + yp = np.sin(theta) * radius + + # Rotate towards camera + xyz_rot = np.vstack((xp.ravel(), yp.ravel(), zp.ravel())) + if occ_plot[a1] < 1.0: + xyz_rot = np.append( + xyz_rot, np.array((0, 0, 0))[:, None], axis=1 + ) + xyz_rot = orientation_matrix @ xyz_rot + + # add to plot + verts = [ + list( + zip( + xyz_rot[1, :] + xyz_plot[1], + xyz_rot[0, :] + xyz_plot[0], + xyz_rot[2, :] + xyz_plot[2], + ) + ) + ] + # ax.add_collection3d( + # Poly3DCollection( + # verts + # ) + # ) + collection = Poly3DCollection( + verts, + linewidths=2.0, + alpha=1.0, + edgecolors="k", + ) + face_color = [ + 0.5, + 0.5, + 1, + ] # alternative: matplotlib.colors.rgb2hex([0.5, 0.5, 1]) + if ID_plot[a1] == -1: + collection.set_facecolor((1.0, 1.0, 1.0)) + else: + collection.set_facecolor(atomic_colors(ID_plot[a1])) + 
ax.add_collection3d(collection) + + # update start point + if a1 < occ_plot.size: + theta0 = theta1 + + # for ID_plot in ID_all: + # sub = ID == ID_plot + # ax.scatter( + # xs=xyz[sub, 1], # + d[0], + # ys=xyz[sub, 0], # + d[1], + # zs=xyz[sub, 2], # + d[2], + # s=size_marker, + # linewidth=2, + # facecolors='none', + # edgecolor=[0, 0, 0], + # ) + # poly = PolyCollection( + # verts, + # facecolors=['r', 'g', 'b', 'y'], + # alpha = 0.6) + # ax.add_collection3d(poly, zs=zs, zdir='y') # plot limit if plot_limit is None: @@ -352,8 +454,7 @@ def plot_scattering_intensity( int_sf_plot = calc_1D_profile( k, self.g_vec_leng, - (self.struct_factors_int**int_power_scale) - * (self.g_vec_leng**k_power_scale), + (self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale), remove_origin=True, k_broadening=k_broadening, int_scale=int_scale, diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 9973ff79f..5c2d56a3c 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -169,8 +169,7 @@ def polar_gaussian_2D( # t2 = np.min(np.vstack([t,1-t])) t2 = np.square(t - mu_t) return ( - I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) - + C + I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C ) diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 1005a619d..ecfeaa1d2 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -2,15 +2,15 @@ _emd_hook = True -from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction -from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction -from 
py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction -from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction -from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction -from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction -from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction +from py4DSTEM.process.phase.dpc import DPC +from py4DSTEM.process.phase.magnetic_ptychographic_tomography import MagneticPtychographicTomography +from py4DSTEM.process.phase.magnetic_ptychography import MagneticPtychography +from py4DSTEM.process.phase.mixedstate_multislice_ptychography import MixedstateMultislicePtychography +from py4DSTEM.process.phase.mixedstate_ptychography import MixedstatePtychography +from py4DSTEM.process.phase.multislice_ptychography import MultislicePtychography +from py4DSTEM.process.phase.parallax import Parallax +from py4DSTEM.process.phase.ptychographic_tomography import PtychographicTomography +from py4DSTEM.process.phase.singleslice_ptychography import SingleslicePtychography from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/dpc.py similarity index 85% rename from py4DSTEM/process/phase/iterative_dpc.py rename to py4DSTEM/process/phase/dpc.py index 11adc0c70..a9468a002 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -19,12 +19,11 @@ from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction +from 
py4DSTEM.visualize.vis_special import return_scaled_histogram_ordering -warnings.simplefilter(action="always", category=UserWarning) - -class DPCReconstruction(PhaseReconstruction): +class DPC(PhaseReconstruction): """ Iterative Differential Phase Constrast Reconstruction Class. @@ -42,7 +41,11 @@ class DPCReconstruction(PhaseReconstruction): verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + Device calculation will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls name: str, optional Class name """ @@ -54,24 +57,17 @@ def __init__( energy: float = None, verbose: bool = True, device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "dpc_reconstruction", ): Custom.__init__(self, name=name) - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter + if storage is None: + storage = device - self._gaussian_filter = gaussian_filter - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) self.set_save_defaults() @@ -82,7 +78,6 @@ def __init__( # Metadata self._energy = energy self._verbose = verbose - self._device = device self._preprocessed = False def to_h5(self, group): @@ -198,6 +193,8 @@ def _get_constructor_args(cls, group): "name": instance_md["name"], "verbose": True, # for compatibility "device": "cpu", # for compatibility + "storage": "cpu", # for compatibility + "clear_fft_cache": True, 
# for compatibility } return kwargs @@ -234,15 +231,18 @@ def preprocess( self, dp_mask: np.ndarray = None, padding_factor: float = 2, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, maximize_divergence: bool = False, fit_function: str = "plane", force_com_rotation: float = None, force_com_transpose: bool = None, force_com_shifts: Union[Sequence[np.ndarray], Sequence[float]] = None, + vectorized_com_calculation: bool = True, force_com_measured: Sequence[np.ndarray] = None, plot_center_of_mass: str = "default", plot_rotation: bool = True, + device: str = None, + clear_fft_cache: bool = None, **kwargs, ): """ @@ -271,6 +271,8 @@ def preprocess( Force whether diffraction intensities need to be transposed. force_com_shifts: tuple of ndarrays (CoMx, CoMy) Force CoM fitted shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) Force CoM measured shifts plot_center_of_mass: str, optional @@ -284,7 +286,12 @@ def preprocess( self: DPCReconstruction Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device + storage = self._storage # set additional metadata self._dp_mask = dp_mask @@ -303,7 +310,7 @@ def preprocess( data=np.empty(force_com_measured[0].shape + (1, 1)) ) - self._intensities = self._extract_intensities_and_calibrations_from_datacube( + _intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=False, ) @@ -316,10 +323,11 @@ def preprocess( self._com_normalized_x, self._com_normalized_y, ) = self._calculate_intensities_center_of_mass( - self._intensities, + _intensities, dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, com_measured=force_com_measured, ) 
@@ -328,8 +336,6 @@ def preprocess( self._rotation_best_transpose, self._com_x, self._com_y, - self.com_x, - self.com_y, ) = self._solve_for_center_of_mass_relative_rotation( self._com_measured_x, self._com_measured_y, @@ -344,11 +350,23 @@ def preprocess( **kwargs, ) + # explicitly transfer arrays to storage + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + ] + self.copy_attributes_to_device(attrs, storage) + # Object Initialization padded_object_shape = np.round( np.array(self._grid_scan_shape) * padding_factor ).astype("int") self._padded_object_phase = xp.zeros(padded_object_shape, dtype=xp.float32) + if self._object_phase is not None: self._padded_object_phase[ : self._grid_scan_shape[0], : self._grid_scan_shape[1] @@ -357,20 +375,23 @@ def preprocess( self._padded_object_phase_initial = self._padded_object_phase.copy() # Fourier coordinates and operators - kx = xp.fft.fftfreq(padded_object_shape[0], d=self._scan_sampling[0]) - ky = xp.fft.fftfreq(padded_object_shape[1], d=self._scan_sampling[1]) + kx = xp.fft.fftfreq(padded_object_shape[0], d=self._scan_sampling[0]).astype( + xp.float32 + ) + ky = xp.fft.fftfreq(padded_object_shape[1], d=self._scan_sampling[1]).astype( + xp.float32 + ) kya, kxa = xp.meshgrid(ky, kx) + k_den = kxa**2 + kya**2 k_den[0, 0] = np.inf k_den = 1 / k_den + self._kx_op = -1j * 0.25 * kxa * k_den self._ky_op = -1j * 0.25 * kya * k_den self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -413,6 +434,7 @@ def _forward( """ xp = self._xp + asnumpy = self._asnumpy dx, dy = self._scan_sampling # centered finite-differences @@ -431,8 +453,9 @@ def _forward( obj_dx[mask_inv] = 0 obj_dy[mask_inv] = 0 - new_error = xp.mean(obj_dx[mask] ** 2 + obj_dy[mask] ** 2) / ( - xp.mean(self._com_x.ravel() ** 2 + 
self._com_y.ravel() ** 2) + new_error = asnumpy( + xp.mean(obj_dx[mask] ** 2 + obj_dy[mask] ** 2) + / (xp.mean(self._com_x.ravel() ** 2 + self._com_y.ravel() ** 2)) ) return obj_dx, obj_dy, new_error, step_size @@ -516,9 +539,9 @@ def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): constrained_object: np.ndarray Constrained object estimate """ - gaussian_filter = self._gaussian_filter - + gaussian_filter = self._scipy.ndimage.gaussian_filter gaussian_filter_sigma /= self.sampling[0] + current_object = gaussian_filter(current_object, gaussian_filter_sigma) return current_object @@ -558,44 +581,13 @@ def _object_butterworth_constraint( if q_lowpass: env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - current_object_mean = xp.mean(current_object) + current_object_mean = xp.mean(current_object, axis=(-2, -1), keepdims=True) current_object -= current_object_mean current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env) current_object += current_object_mean return xp.real(current_object) - def _object_anti_gridding_contraint(self, current_object): - """ - Zero outer pixels of object fft to remove gridding artifacts - - Parameters - -------- - current_object: np.ndarray - Current object estimate - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - # find indices to zero - width_x = current_object.shape[0] - width_y = current_object.shape[1] - ind_min_x = int(xp.floor(width_x / 2) - 2) - ind_max_x = int(xp.ceil(width_x / 2) + 2) - ind_min_y = int(xp.floor(width_y / 2) - 2) - ind_max_y = int(xp.ceil(width_y / 2) + 2) - - # zero pixels - object_fft = xp.fft.fft2(current_object) - object_fft[ind_min_x:ind_max_x] = 0 - object_fft[:, ind_min_y:ind_max_y] = 0 - - return xp.real(xp.fft.ifft2(object_fft)) - def _constraints( self, current_object, @@ -605,7 +597,6 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, - anti_gridding, ): """ DPC constraints operator. 
@@ -626,9 +617,6 @@ def _constraints( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter - anti_gridding: bool - If true, zero outer pixels of object fft to remove - gridding artifacts Returns -------- @@ -648,11 +636,6 @@ def _constraints( butterworth_order, ) - if anti_gridding: - current_object = self._object_anti_gridding_contraint( - current_object, - ) - return current_object def reconstruct( @@ -664,13 +647,14 @@ def reconstruct( backtrack: bool = True, progress_bar: bool = True, gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - butterworth_filter_iter: int = np.inf, + gaussian_filter: bool = True, + butterworth_filter: bool = True, q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, - anti_gridding: float = True, store_iterations: bool = False, + device: str = None, + clear_fft_cache: bool = None, ): """ Performs Iterative DPC Reconstruction: @@ -693,21 +677,22 @@ def reconstruct( If True, reconstruction progress bar will be printed gaussian_filter_sigma: float, optional Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. 
Smaller gives a smoother filter - anti_gridding: bool - If true, zero outer pixels of object fft to remove - gridding artifacts store_iterations: bool, optional If True, all reconstruction iterations will be stored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -715,18 +700,34 @@ def reconstruct( Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + xp = self._xp + device = self._device asnumpy = self._asnumpy # Restart if store_iterations and (not hasattr(self, "object_phase_iterations") or reset): self.object_phase_iterations = [] - if reset: + if reset is True: self.error = np.inf self.error_iterations = [] self._step_size = step_size if step_size is not None else 0.5 self._padded_object_phase = self._padded_object_phase_initial.copy() + elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -772,8 +773,6 @@ def reconstruct( if (new_error > self.error) and backtrack: self._padded_object_phase = previous_iteration self._step_size /= 2 - if self._verbose: - print(f"Iteration {a0}, step reduced to {self._step_size}") continue self.error = new_error @@ -788,18 +787,17 @@ def reconstruct( # constraints self._padded_object_phase = self._constraints( self._padded_object_phase, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter + butterworth_filter=butterworth_filter and (q_lowpass is not None or 
q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, - anti_gridding=anti_gridding, ) self.error_iterations.append(self.error.item()) + if store_iterations: self.object_phase_iterations.append( asnumpy( @@ -822,9 +820,7 @@ def reconstruct( ] self.object_phase = asnumpy(self._object_phase) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(self._device, self._clear_fft_cache) return self @@ -846,6 +842,8 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (5, 6)) cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) if plot_convergence: spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) @@ -863,10 +861,15 @@ def _visualize_last_iteration( ] ax1 = fig.add_subplot(spec[0]) - im = ax1.imshow(self.object_phase, extent=extent, cmap=cmap, **kwargs) + + obj, vmin, vmax = return_scaled_histogram_ordering( + self.object_phase, vmin, vmax + ) + im = ax1.imshow(obj, extent=extent, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) + ax1.set_ylabel(f"x [{self._scan_units[0]}]") ax1.set_xlabel(f"y [{self._scan_units[1]}]") - ax1.set_title(f"DPC phase reconstruction - NMSE error: {self.error:.3e}") + ax1.set_title("Reconstructed object phase") if cbar: divider = make_axes_locatable(ax1) @@ -878,10 +881,12 @@ def _visualize_last_iteration( errors = self.error_iterations ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -906,7 +911,6 @@ def _visualize_all_iterations( iterations_grid: Tuple[int,int] Grid dimensions to plot reconstruction iterations """ - if not hasattr(self, "object_phase_iterations"): raise ValueError( ( @@ -915,31 +919,41 @@ def 
_visualize_all_iterations( ) ) - if iterations_grid == "auto": - num_iter = len(self.error_iterations) + num_iter = len(self.object_phase_iterations) + if iterations_grid == "auto": if num_iter == 1: return self._visualize_last_iteration( + fig=fig, plot_convergence=plot_convergence, cbar=cbar, **kwargs, ) + else: iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + else: + if iterations_grid[0] * iterations_grid[1] > num_iter: + raise ValueError() + auto_figsize = ( (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) if plot_convergence else (3 * iterations_grid[1], 3 * iterations_grid[0]) ) + figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + max_iter = num_iter - 1 total_grids = np.prod(iterations_grid) - errors = self.error_iterations - phases = self.object_phase_iterations - max_iter = len(phases) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) + grid_range = np.arange(0, max_iter + 1, max_iter // (total_grids - 1)) + + errors = np.array(self.error_iterations)[-num_iter:] + objects = [self.object_phase_iterations[n] for n in grid_range] extent = [ 0, @@ -966,25 +980,30 @@ def _visualize_all_iterations( ) for n, ax in enumerate(grid): + obj, vmin_n, vmax_n = return_scaled_histogram_ordering( + objects[n], vmin=vmin, vmax=vmax + ) im = ax.imshow( - phases[grid_range[n]], + obj, extent=extent, cmap=cmap, + vmin=vmin_n, + vmax=vmax_n, **kwargs, ) + ax.set_ylabel(f"x [{self._scan_units[0]}]") ax.set_xlabel(f"y [{self._scan_units[1]}]") + ax.set_title(f"Iter: {grid_range[n]} phase") + if cbar: grid.cbar_axes[n].colorbar(im) - ax.set_title( - f"Iteration: {grid_range[n]}\nNMSE error: {errors[grid_range[n]]:.3e}" - ) if plot_convergence: ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(len(errors)), errors, **kwargs) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_xlabel("Iteration number") - 
ax2.set_ylabel("Log NMSE error") + ax2.set_ylabel("NMSE error") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -1030,6 +1049,8 @@ def visualize( **kwargs, ) + self.clear_device_mem(self._device, self._clear_fft_cache) + return self @property diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py deleted file mode 100644 index 10dc40e00..000000000 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ /dev/null @@ -1,3705 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely multislice ptychography. -""" - -import warnings -from typing import Mapping, Sequence, Tuple, Union - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex - -try: - import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = None - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, - spatial_frequencies, -) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar -from scipy.ndimage import rotate - -warnings.simplefilter(action="always", category=UserWarning) - - -class MixedstateMultislicePtychographicReconstruction(PtychographicReconstruction): - """ - Mixed-State Multislice Ptychographic Reconstruction Class. 
- - Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (N,Sx,Sy) - Reconstructed object dimensions : (T,Px,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes - and (Px,Py) is the padded-object size we position our ROI around in - each of the T slices. - - Parameters - ---------- - energy: float - The electron energy of the wave functions in eV - num_probes: int, optional - Number of mixed-state probes - num_slices: int - Number of slices to use in the forward model - slice_thicknesses: float or Sequence[float] - Slice thicknesses in angstroms. If float, all slices are assigned the same thickness - datacube: DataCube, optional - Input 4D diffraction pattern intensities - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py) - If None, initialized to 1.0j - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: np.ndarray, optional - Probe positions in Å for each diffraction intensity - If None, initialized to a grid scan - theta_x: float - x tilt of propagator (in degrees) - theta_y: float - y tilt of propagator (in degrees) - middle_focus: bool - if True, adds half the sample thickness to the defocus - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. - """ - - # Class-specific Metadata - _class_specific_metadata = ("_num_probes", "_num_slices", "_slice_thicknesses") - - def __init__( - self, - energy: float, - num_slices: int, - slice_thicknesses: Union[float, Sequence[float]], - num_probes: int = None, - datacube: DataCube = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: np.ndarray = None, - theta_x: float = 0, - theta_y: float = 0, - middle_focus: bool = False, - object_type: str = "complex", - positions_mask: np.ndarray = None, - verbose: bool = True, - device: str = "cpu", - name: str = "multi-slice_ptychographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): 
- if num_probes is None: - raise ValueError( - ( - "If initial_probe_guess is None, or a ComplexProbe object, " - "num_probes must be specified." - ) - ) - else: - if len(initial_probe_guess.shape) != 3: - raise ValueError( - "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." - ) - num_probes = initial_probe_guess.shape[0] - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - if np.isscalar(slice_thicknesses): - mean_slice_thickness = slice_thicknesses - else: - mean_slice_thickness = np.mean(slice_thicknesses) - - if middle_focus: - if "defocus" in kwargs: - kwargs["defocus"] += mean_slice_thickness * num_slices / 2 - elif "C10" in kwargs: - kwargs["C10"] -= mean_slice_thickness * num_slices / 2 - elif polar_parameters is not None and "defocus" in polar_parameters: - polar_parameters["defocus"] = ( - polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 - ) - elif polar_parameters is not None and "C10" in polar_parameters: - polar_parameters["C10"] = ( - polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 - ) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - slice_thicknesses = np.array(slice_thicknesses) - if slice_thicknesses.shape == 
(): - slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) - elif slice_thicknesses.shape[0] != (num_slices - 1): - raise ValueError( - ( - f"slice_thicknesses must have length {num_slices - 1}, " - f"not {slice_thicknesses.shape[0]}." - ) - ) - - if object_type != "potential" and object_type != "complex": - raise ValueError( - f"object_type must be either 'potential' or 'complex', not {object_type}" - ) - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._positions_mask = positions_mask - self._object_padding_px = object_padding_px - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._num_probes = num_probes - self._num_slices = num_slices - self._slice_thicknesses = slice_thicknesses - self._theta_x = theta_x - self._theta_y = theta_y - - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - theta_x: float, - theta_y: float, - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. 
- - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - theta_x: float - x tilt of propagator (in degrees) - theta_y: float - y tilt of propagator (in degrees) - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - - theta_x = np.deg2rad(theta_x) - theta_y = np.deg2rad(theta_y) - - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. 
- - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_center_of_mass: str = "default", - plot_rotation: bool = True, - maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), - plot_probe_overlaps: bool = True, - force_com_rotation: float = None, - force_com_transpose: float = None, - force_com_shifts: float = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. - Calls the base class methods: - - _extract_intensities_and_calibrations_from_datacube, - _compute_center_of_mass(), - _solve_CoM_rotation(), - _normalize_diffraction_intensities() - _calculate_scan_positions_in_px() - - Additionally, it initializes an (T,Px,Py) array of 1.0j - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. 
- If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' - plot_center_of_mass: str, optional - If 'default', the corrected CoM arrays will be displayed - If 'all', the computed and fitted CoM arrays will be displayed - plot_rotation: bool, optional - If True, the CoM curl minimization search result will be displayed - maximize_divergence: bool, optional - If True, the divergence of the CoM gradient vector field is maximized - rotation_angles_deg: np.darray, optional - Array of angles in degrees to perform curl minimization over - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - force_com_rotation: float (degrees), optional - Force relative rotation angle between real and reciprocal space - force_com_transpose: bool, optional - Force whether diffraction intensities need to be transposed. - force_com_shifts: tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. 
Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: MixedstateMultislicePtychographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." - ) - ) - - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - self._positions_mask = np.asarray(self._positions_mask, dtype="bool") - - ( - self._datacube, - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts, - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube, - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts, - ) - - self._intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - self._com_measured_x, - self._com_measured_y, - self._com_fitted_x, - self._com_fitted_y, - self._com_normalized_x, - self._com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - self._intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts, - ) - - ( - 
self._rotation_best_rad, - self._rotation_best_transpose, - self._com_x, - self._com_y, - self.com_x, - self.com_y, - ) = self._solve_for_center_of_mass_relative_rotation( - self._com_measured_x, - self._com_measured_y, - self._com_normalized_x, - self._com_normalized_y, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=plot_center_of_mass, - maximize_divergence=maximize_divergence, - force_com_rotation=force_com_rotation, - force_com_transpose=force_com_transpose, - **kwargs, - ) - - ( - self._amplitudes, - self._mean_diffraction_intensity, - ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, - crop_patterns, - self._positions_mask, - ) - - # explicitly delete namespace - self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities - - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask - ) - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = 
xp.asarray(self._object, dtype=xp.complex64) - - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape[-2:] - - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - self._positions_px_initial = self._positions_px.copy() - self._positions_initial = self._positions_px_initial.copy() - self._positions_initial[:, 0] *= self.sampling[0] - self._positions_initial[:, 1] *= self.sampling[1] - - # Vectorized Patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Probe Initialization - if self._probe is None or isinstance(self._probe, ComplexProbe): - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - _probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - else: - if self._probe._gpts != 
self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - _probe = self._probe._array - else: - self._probe._xp = xp - _probe = self._probe.build()._array - - self._probe = xp.zeros( - (self._num_probes,) + tuple(self._region_of_interest_shape), - dtype=xp.complex64, - ) - sx, sy = self._region_of_interest_shape - self._probe[0] = _probe - - # Randomly shift phase of other probes - for i_probe in range(1, self._num_probes): - shift_x = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) - ) - shift_y = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) - ) - self._probe[i_probe] = ( - self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = None # Doesn't really make sense for mixed-state - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # Precomputed propagator arrays - self._propagator_arrays = self._precompute_propagator_arrays( - self._region_of_interest_shape, - self.sampling, - self._energy, - self._slice_thicknesses, - self._theta_x, - self._theta_y, - ) - - # overlaps - shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) - else: - 
self._object_fov_mask = np.asarray(object_fov_mask) - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (13, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered[0], - power=2, - chroma_boost=chroma_boost, - ) - - # propagated - propagated_probe = self._probe[0].copy() - - for s in range(self._num_slices - 1): - propagated_probe = self._propagate_array( - propagated_probe, self._propagator_arrays[s] - ) - complex_propagated_rgb = Complex2RGB( - asnumpy(self._return_centered_probe(propagated_probe)), - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, - chroma_boost=chroma_boost, - ) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe[0] intensity") - - ax2.imshow( - complex_propagated_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax2) - cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax2, chroma_boost=chroma_boost) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_title("Propagated probe[0] intensity") - - ax3.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - ax3.scatter( - self.positions[:, 1], - self.positions[:, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax3.set_ylabel("x [A]") - ax3.set_xlabel("y [A]") - ax3.set_xlim((extent[0], extent[1])) - ax3.set_ylim((extent[2], extent[3])) - 
ax3.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ] - - num_probe_positions = object_patches.shape[1] - - propagated_shape = ( - self._num_slices, - num_probe_positions, - self._num_probes, - self._region_of_interest_shape[0], - self._region_of_interest_shape[1], - ) - propagated_probes = xp.empty(propagated_shape, dtype=object_patches.dtype) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = ( - xp.expand_dims(object_patches[s], axis=1) * propagated_probes[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. 
- - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - intensity_norm[intensity_norm == 0.0] = np.inf - amplitude_modification = amplitudes / intensity_norm - - fourier_modified_overlap = amplitude_modification[:, None] * fourier_exit_waves - modified_exit_wave = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * exit_waves - ) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - intensity_norm_projected = xp.sqrt( - xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) - ) - intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf - - amplitude_modification = amplitudes / intensity_norm_projected - fourier_projected_factor *= amplitude_modification[:, None] - - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. 
- Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection(current_object, current_probe) - - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = xp.zeros_like(current_object[s]) - object_update = xp.zeros_like(current_object[s]) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(probe[:, i_probe]) ** 2 - ) - - if self._object_type == "potential": - object_update += ( - step_size - * self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(obj) - * xp.conj(probe[:, i_probe]) - * exit_waves[:, i_probe] - ) - ) - ) - else: - object_update += ( - step_size - * self._sum_overlapping_patches_bincounts( - xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] - ) - ) - - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] += object_update * probe_normalization - - # back-transmit - exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - 
object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = xp.zeros_like(current_object[s]) - object_update = xp.zeros_like(current_object[s]) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(probe[:, i_probe]) ** 2 - ) - - if self._object_type == "potential": - object_update += self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(obj) - * xp.conj(probe[:, i_probe]) - * exit_waves_copy[:, i_probe] - ) - ) - else: - object_update += self._sum_overlapping_patches_bincounts( - 
xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] - ) - - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] = object_update * probe_normalization - - # back-transmit - exit_waves_copy *= xp.expand_dims( - xp.conj(obj), axis=1 - ) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - # dx - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - # dy - obj_rolled_patches = complex_object[ - :, - 
self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) 
+ y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - - def _probe_center_of_mass_constraint(self, current_probe): - """ - Ptychographic center of mass constraint. - Used for centering corner-centered probe intensity. - - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - probe_intensity = xp.abs(current_probe[0]) ** 2 - - probe_x0, probe_y0 = get_CoM( - probe_intensity, device=self._device, corner_centered=True - ) - shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) - - return shifted_probe - - def _probe_orthogonalization_constraint(self, current_probe): - """ - Ptychographic probe-orthogonalization constraint. - Used to ensure mixed states are orthogonal to each other. 
- Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 - - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Orthogonalized probe estimate - """ - xp = self._xp - n_probes = self._num_probes - - # compute upper half of P* @ P - pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) - - for i in range(n_probes): - for j in range(i, n_probes): - pairwise_dot_product[i, j] = xp.sum( - current_probe[i].conj() * current_probe[j] - ) - - # compute eigenvectors (effectively cheaper way of computing V* from SVD) - _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") - current_probe = xp.tensordot(evecs.T, current_probe, axes=1) - - # sort by real-space intensity - intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) - intensities_order = xp.argsort(intensities, axis=None)[::-1] - return current_probe[intensities_order] - - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - 2D Butterworth filter - Used for low/high-pass filtering object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qya, qxa = xp.meshgrid(qy, qx) - qra = xp.sqrt(qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) - current_object += current_object_mean - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_kz_regularization_constraint( - self, current_object, kz_regularization_gamma - ): - """ - Arctan regularization filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - kz_regularization_gamma: float - Slice regularization strength - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - current_object = xp.pad( - current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" - ) - - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) - - kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] - - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qz2 = qza**2 * kz_regularization_gamma**2 - qr2 = qxa**2 + qya**2 - - w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) - - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) - current_object = current_object[1:] - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return 
current_object - - def _object_identical_slices_constraint(self, current_object): - """ - Strong regularization forcing all slices to be identical - - Parameters - -------- - current_object: np.ndarray - Current object estimate - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - object_mean = current_object.mean(0, keepdims=True) - current_object[:] = object_mean - - return current_object - - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - iterations: float - Number of iterations to run in denoising algorithm. - `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - 
- current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv - - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - kz_regularization_filter, - kz_regularization_gamma, - identical_slices, - object_positivity, - shrinkage_rad, - object_mask, - pure_phase_object, - tv_denoise_chambolle, - tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - orthogonalize_probe, - ): - """ - Ptychographic constraints operator. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool - If True, probe Fourier amplitude is replaced by initial_probe_aperture - initial_probe_aperture: np.ndarray - Initial probe aperture to use in replacing probe Fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter in A - gaussian_filter_sigma: float - Standard deviation of gaussian kernel - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - kz_regularization_filter: bool - If True, applies fourier-space arctan regularization filter - kz_regularization_gamma: float - Slice regularization strength - identical_slices: bool - If True, forces all object slices to be identical - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - pure_phase_object: bool - If True, object amplitude is set to unity - tv_denoise_chambolle: bool - If True, performs TV denoising along z - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - orthogonalize_probe: bool - If True, probe will be orthogonalized - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if identical_slices: - current_object = self._object_identical_slices_constraint(current_object) - elif kz_regularization_filter: - current_object = self._object_kz_regularization_constraint( - current_object, kz_regularization_gamma - ) - elif tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - ) - elif tv_denoise_chambolle: - current_object = self._object_denoise_tv_chambolle( - current_object, - tv_denoise_weight_chambolle, - axis=0, - pad_object=tv_denoise_pad_chambolle, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # These constraints don't _really_ make sense for mixed-state - if fix_probe_aperture: - raise NotImplementedError() - elif constrain_probe_fourier_amplitude: - raise NotImplementedError() - if fit_probe_aberrations: - raise NotImplementedError() - if 
constrain_probe_amplitude: - raise NotImplementedError() - - if orthogonalize_probe: - current_probe = self._probe_orthogonalization_constraint(current_probe) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - fix_com: bool = True, - orthogonalize_probe: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass: float = None, - q_highpass: float = None, - butterworth_order: float = 2, - kz_regularization_filter_iter: int = np.inf, - kz_regularization_gamma: Union[float, np.ndarray] = None, - identical_slices_iter: int 
= 0, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - pure_phase_object_iter: int = 0, - tv_denoise_iter_chambolle=np.inf, - tv_denoise_weight_chambolle=None, - tv_denoise_pad_chambolle=True, - tv_denoise_iter=np.inf, - tv_denoise_weights=None, - tv_denoise_inner_iter=40, - switch_object_iter: int = np.inf, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. 
- max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. 
- fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma: float, optional - Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - kz_regularization_filter_iter: int, optional - Number of iterations to run using kz regularization filter - kz_regularization_gamma, float, optional - kz regularization strength - identical_slices_iter: int, optional - Number of iterations to run using identical slices - object_positivity: bool, optional - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - fix_potential_baseline: bool - If true, the potential mean outside the FOV is forced to zero at each iteration - pure_phase_object_iter: int, optional - Number of iterations where object amplitude is set to unity - tv_denoise_iter_chambolle: bool - Number of iterations with TV denoisining - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: MultislicePtychographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - 
"'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." 
- ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": - self._object_type = "potential" - self._object = xp.angle(self._object) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - positions_px = self._positions_px.copy()[shuffled_indices] - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] - - # forward operator - ( - propagated_probes, - object_patches, - self._transmitted_probes, - self._exit_waves, - batch_error, - ) = self._forward( - self._object, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - self._object, self._probe = self._adjoint( - self._object, - self._probe, - object_patches, - propagated_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - self._object, - self._probe[0], - self._transmitted_probes[:, 0], - amplitudes, - self._positions_px, - 
positions_step_size, - constrain_position_distance, - ) - - error += batch_error - - # Normalize Error - error /= self._mean_diffraction_intensity * self._num_diffraction_patterns - - # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] - self._object, self._probe, self._positions_px = self._constraints( - self._object, - self._probe, - self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - kz_regularization_filter=a0 < kz_regularization_filter_iter - and kz_regularization_gamma is not None, - kz_regularization_gamma=kz_regularization_gamma[a0] - if kz_regularization_gamma is not None - 
and isinstance(kz_regularization_gamma, np.ndarray) - else kz_regularization_gamma, - identical_slices=a0 < identical_slices_iter, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 - else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle - and tv_denoise_weight_chambolle is not None, - tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - orthogonalize_probe=orthogonalize_probe, - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - padding : int, optional - Pixels to pad by post rotating-cropping object - - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - 
width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - ax = fig.add_subplot(spec[0, 1]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual[0] - else: - probe_array = self.probe_fourier[0] - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe[0], power=2, chroma_boost=chroma_boost - ) - ax.set_title("Reconstructed probe[0] intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - 
divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." 
- ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - object_type.append("phase") - else: - object_type.append("potential") - objects.append( - self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) - ) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * 
self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or plot_fourier_probe) - else iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]][0], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - probes[grid_range[n]][0], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe[0]") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - 
probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - padding: int = 0, - **kwargs, - ): - """ - Displays reconstructed object and probe. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - 
plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - return self - - def show_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - cbar=True, - scalebar=True, - pixelsize=None, - pixelunits=None, - **kwargs, - ): - """ - Plot probe in fourier space - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses the `probe_fourier` property - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - scalebar: bool, optional - if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 - pixelsize: float, optional - default is probe reciprocal sampling - """ - asnumpy = self._asnumpy - - if probe is None: - probe = list( - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - ) - else: - if isinstance(probe, np.ndarray) and probe.ndim == 2: - probe = [probe] - probe = [ - asnumpy( - self._return_fourier_probe( - pr, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for pr in probe - ] - - if pixelsize is None: - pixelsize = self._reciprocal_sampling[1] - if pixelunits is None: - pixelunits = r"$\AA^{-1}$" - - chroma_boost = kwargs.pop("chroma_boost", 1) - - show_complex( - probe if len(probe) > 1 else probe[0], - cbar=cbar, - scalebar=scalebar, - pixelsize=pixelsize, - pixelunits=pixelunits, - ticks=False, - chroma_boost=chroma_boost, - **kwargs, - ) - - def show_transmitted_probe( - self, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations=False, - **kwargs, - ): - """ - Plots the min, max, and mean transmitted probe after propagation and transmission. 
- - Parameters - ---------- - plot_fourier_probe: boolean, optional - If True, the transmitted probes are also plotted in Fourier space - kwargs: - Passed to show_complex - """ - - xp = self._xp - asnumpy = self._asnumpy - - transmitted_probe_intensities = xp.sum( - xp.abs(self._transmitted_probes[:, 0]) ** 2, axis=(-2, -1) - ) - min_intensity_transmitted = self._transmitted_probes[ - xp.argmin(transmitted_probe_intensities), 0 - ] - max_intensity_transmitted = self._transmitted_probes[ - xp.argmax(transmitted_probe_intensities), 0 - ] - mean_transmitted = self._transmitted_probes[:, 0].mean(0) - probes = [ - asnumpy(self._return_centered_probe(probe)) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - title = [ - "Mean Transmitted Probe", - "Min Intensity Transmitted Probe", - "Max Intensity Transmitted Probe", - ] - - if plot_fourier_probe: - bottom_row = [ - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - probes = [probes, bottom_row] - - title += [ - "Mean Transmitted Fourier Probe", - "Min Intensity Transmitted Fourier Probe", - "Max Intensity Transmitted Fourier Probe", - ] - - title = kwargs.get("title", title) - show_complex( - probes, - title=title, - **kwargs, - ) - - def show_slices( - self, - ms_object=None, - cbar: bool = True, - common_color_scale: bool = True, - padding: int = 0, - num_cols: int = 3, - show_fft: bool = False, - **kwargs, - ): - """ - Displays reconstructed slices of object - - Parameters - -------- - ms_object: nd.array, optional - Object to plot slices of. 
If None, uses current object - cbar: bool, optional - If True, displays a colorbar - padding: int, optional - Padding to leave uncropped - num_cols: int, optional - Number of GridSpec columns - show_fft: bool, optional - if True, plots fft of object slices - """ - - if ms_object is None: - ms_object = self._object - - rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) - if show_fft: - rotated_object = np.abs( - np.fft.fftshift( - np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) - ) - ) - rotated_shape = rotated_object.shape - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - - extent = [ - 0, - self.sampling[1] * rotated_shape[2], - self.sampling[0] * rotated_shape[1], - 0, - ] - - num_rows = np.ceil(self._num_slices / num_cols).astype("int") - wspace = 0.35 if cbar else 0.15 - - axsize = kwargs.pop("axsize", (3, 3)) - cmap = kwargs.pop("cmap", "magma") - - if common_color_scale: - vals = np.sort(rotated_object.ravel()) - ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") - ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") - ind_vmin = np.max([0, ind_vmin]) - ind_vmax = np.min([len(vals) - 1, ind_vmax]) - vmin = vals[ind_vmin] - vmax = vals[ind_vmax] - if vmax == vmin: - vmin = vals[0] - vmax = vals[-1] - else: - vmax = None - vmin = None - vmin = kwargs.pop("vmin", vmin) - vmax = kwargs.pop("vmax", vmax) - - spec = GridSpec( - ncols=num_cols, - nrows=num_rows, - hspace=0.15, - wspace=wspace, - ) - - figsize = (axsize[0] * num_cols, axsize[1] * num_rows) - fig = plt.figure(figsize=figsize) - - for flat_index, obj_slice in enumerate(rotated_object): - row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) - ax = fig.add_subplot(spec[row_index, col_index]) - im = ax.imshow( - obj_slice, - cmap=cmap, - vmin=vmin, - vmax=vmax, - extent=extent, - **kwargs, - ) - - ax.set_title(f"Slice index: {flat_index}") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = 
divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if row_index < num_rows - 1: - ax.set_xticks([]) - else: - ax.set_xlabel("y [A]") - - if col_index > 0: - ax.set_yticks([]) - else: - ax.set_ylabel("x [A]") - - spec.tight_layout(fig) - - def show_depth( - self, - x1: float, - x2: float, - y1: float, - y2: float, - specify_calibrated: bool = False, - gaussian_filter_sigma: float = None, - ms_object=None, - cbar: bool = False, - aspect: float = None, - plot_line_profile: bool = False, - **kwargs, - ): - """ - Displays line profile depth section - - Parameters - -------- - x1, x2, y1, y2: floats (pixels) - Line profile for depth section runs from (x1,y1) to (x2,y2) - Specified in pixels unless specify_calibrated is True - specify_calibrated: bool (optional) - If True, specify x1, x2, y1, y2 in A values instead of pixels - gaussian_filter_sigma: float (optional) - Standard deviation of gaussian kernel in A - ms_object: np.array - Object to plot slices of. 
If None, uses current object - cbar: bool, optional - If True, displays a colorbar - aspect: float, optional - aspect ratio for depth profile plot - plot_line_profile: bool - If True, also plots line profile showing where depth profile is taken - """ - if ms_object is not None: - ms_obj = ms_object - else: - ms_obj = self.object_cropped - - if specify_calibrated: - x1 /= self.sampling[0] - x2 /= self.sampling[0] - y1 /= self.sampling[1] - y2 /= self.sampling[1] - - if x2 == x1: - angle = 0 - elif y2 == y1: - angle = np.pi / 2 - else: - angle = np.arctan((x2 - x1) / (y2 - y1)) - - x0 = ms_obj.shape[1] / 2 - y0 = ms_obj.shape[2] / 2 - - if ( - x1 > ms_obj.shape[1] - or x2 > ms_obj.shape[1] - or y1 > ms_obj.shape[2] - or y2 > ms_obj.shape[2] - ): - raise ValueError("depth section must be in field of view of object") - - from py4DSTEM.process.phase.utils import rotate_point - - x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) - x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) - - rotated_object = np.roll( - rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), - int(x1_0), - axis=1, - ) - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - if gaussian_filter_sigma is not None: - from scipy.ndimage import gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) - - plot_im = rotated_object[ - :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) - ] - - extent = [ - 0, - self.sampling[1] * plot_im.shape[1], - self._slice_thicknesses[0] * plot_im.shape[0], - 0, - ] - - figsize = kwargs.pop("figsize", (6, 6)) - if not plot_line_profile: - fig, ax = plt.subplots(figsize=figsize) - im = ax.imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax.set_aspect(aspect) - ax.set_xlabel("r [A]") - ax.set_ylabel("z [A]") - ax.set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax) - 
ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - else: - extent2 = [ - 0, - self.sampling[1] * ms_obj.shape[2], - self.sampling[0] * ms_obj.shape[1], - 0, - ] - fig, ax = plt.subplots(2, 1, figsize=figsize) - ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) - ax[0].plot( - [y1 * self.sampling[0], y2 * self.sampling[1]], - [x1 * self.sampling[0], x2 * self.sampling[1]], - color="red", - ) - ax[0].set_xlabel("y [A]") - ax[0].set_ylabel("x [A]") - ax[0].set_title("Multislice depth profile location") - - im = ax[1].imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax[1].set_aspect(aspect) - ax[1].set_xlabel("r [A]") - ax[1].set_ylabel("z [A]") - ax[1].set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax[1]) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - plt.tight_layout() - - def tune_num_slices_and_thicknesses( - self, - num_slices_guess=None, - thicknesses_guess=None, - num_slices_step_size=1, - thicknesses_step_size=20, - num_slices_values=3, - num_thicknesses_values=3, - update_defocus=False, - max_iter=5, - plot_reconstructions=True, - plot_convergence=True, - return_values=False, - **kwargs, - ): - """ - Run reconstructions over a parameters space of number of slices - and slice thicknesses. Should be run after the preprocess step. 
- - Parameters - ---------- - num_slices_guess: float, optional - initial starting guess for number of slices, rounds to nearest integer - if None, uses current initialized values - thicknesses_guess: float (A), optional - initial starting guess for thicknesses of slices assuming same - thickness for each slice - if None, uses current initialized values - num_slices_step_size: float, optional - size of change of number of slices for each step in parameter space - thicknesses_step_size: float (A), optional - size of change of slice thicknesses for each step in parameter space - num_slices_values: int, optional - number of number of slice values to test, must be >= 1 - num_thicknesses_values: int,optional - number of thicknesses values to test, must be >= 1 - update_defocus: bool, optional - if True, updates defocus based on estimated total thickness - max_iter: int, optional - number of iterations to run in ptychographic reconstruction - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, plots error for each iteration for each reconstruction - return_values: bool, optional - if True, returns objects, convergence - - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - - # calculate number of slices and thicknesses values to test - if num_slices_guess is None: - num_slices_guess = self._num_slices - if thicknesses_guess is None: - thicknesses_guess = np.mean(self._slice_thicknesses) - - if num_slices_values == 1: - num_slices_step_size = 0 - - if num_thicknesses_values == 1: - thicknesses_step_size = 0 - - num_slices = np.linspace( - num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_values, - ) - - thicknesses = np.linspace( - thicknesses_guess - - thicknesses_step_size * (num_thicknesses_values 
- 1) / 2, - thicknesses_guess - + thicknesses_step_size * (num_thicknesses_values - 1) / 2, - num_thicknesses_values, - ) - - if return_values: - convergence = [] - objects = [] - - # current initialized values - current_verbose = self._verbose - current_num_slices = self._num_slices - current_thicknesses = self._slice_thicknesses - current_rotation_deg = self._rotation_best_rad * 180 / np.pi - current_transpose = self._rotation_best_transpose - current_defocus = -self._polar_parameters["C10"] - - # Gridspec to plot on - if plot_reconstructions: - if plot_convergence: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values * 2, - height_ratios=[1, 1 / 4] * num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) - ) - else: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) - ) - - fig = plt.figure(figsize=figsize) - - progress_bar = kwargs.pop("progress_bar", False) - # run loop and plot along the way - self._verbose = False - for flat_index, (slices, thickness) in enumerate( - tqdmnd(num_slices, thicknesses, desc="Tuning angle and defocus") - ): - slices = int(slices) - self._num_slices = slices - self._slice_thicknesses = np.tile(thickness, slices - 1) - self._probe = None - self._object = None - if update_defocus: - defocus = current_defocus + slices / 2 * thickness - self._polar_parameters["C10"] = -defocus - - self.preprocess( - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - ) - self.reconstruct( - reset=True, - store_iterations=True if plot_convergence else False, - max_iter=max_iter, - progress_bar=progress_bar, - **kwargs, - ) - - if plot_reconstructions: - row_index, col_index = 
np.unravel_index( - flat_index, (num_slices_values, num_thicknesses_values) - ) - - if plot_convergence: - object_ax = fig.add_subplot(spec[row_index * 2, col_index]) - convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=convergence_ax, - cbar=True, - ) - convergence_ax.yaxis.tick_right() - else: - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=None, - cbar=True, - ) - - object_ax.set_title( - f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}" - ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - objects.append(self.object) - convergence.append(self.error_iterations.copy()) - - # initialize back to pre-tuning values - self._probe = None - self._object = None - self._num_slices = current_num_slices - self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) - self._polar_parameters["C10"] = -current_defocus - self.preprocess( - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - self._verbose = current_verbose - - if plot_reconstructions: - spec.tight_layout(fig) - - if return_values: - return objects, convergence - - def _return_object_fft( - self, - obj=None, - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - """ - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - - obj = asnumpy(obj) - if np.iscomplexobj(obj): - obj = np.angle(obj) - - obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - 
"""Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start:end] - - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - - # Normalized mean-squared errors - batch_errors = xp.sum( - xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped).sum(0) - else: - projected_cropped_potential = self.object_cropped.sum(0) - - return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py deleted file mode 100644 index 880858f30..000000000 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ /dev/null @@ -1,2430 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely mixed-state ptychography. 
-""" - -import warnings -from typing import Mapping, Tuple - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex - -try: - import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = np - -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, -) -from py4DSTEM.process.utils import get_CoM, get_shifted_ar - -warnings.simplefilter(action="always", category=UserWarning) - - -class MixedstatePtychographicReconstruction(PtychographicReconstruction): - """ - Mixed-State Ptychographic Reconstruction Class. - - Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (N,Sx,Sy) - Reconstructed object dimensions : (Px,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes - and (Px,Py) is the padded-object size we position our ROI around in. - - Parameters - ---------- - energy: float - The electron energy of the wave functions in eV - datacube: DataCube - Input 4D diffraction pattern intensities - num_probes: int, optional - Number of mixed-state probes - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. 
All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py) - If None, initialized to 1.0j - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: np.ndarray, optional - Probe positions in Å for each diffraction intensity - If None, initialized to a grid scan - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. 
- """ - - # Class-specific Metadata - _class_specific_metadata = ("_num_probes",) - - def __init__( - self, - energy: float, - datacube: DataCube = None, - num_probes: int = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: np.ndarray = None, - object_type: str = "complex", - positions_mask: np.ndarray = None, - verbose: bool = True, - device: str = "cpu", - name: str = "mixed-state_ptychographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): - if num_probes is None: - raise ValueError( - ( - "If initial_probe_guess is None, or a ComplexProbe object, " - "num_probes must be specified." - ) - ) - else: - if len(initial_probe_guess.shape) != 3: - raise ValueError( - "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." 
- ) - num_probes = initial_probe_guess.shape[0] - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - if object_type != "potential" and object_type != "complex": - raise ValueError( - f"object_type must be either 'potential' or 'complex', not {object_type}" - ) - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._object_padding_px = object_padding_px - self._positions_mask = positions_mask - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._num_probes = num_probes - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - 
fit_function: str = "plane", - plot_center_of_mass: str = "default", - plot_rotation: bool = True, - maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), - plot_probe_overlaps: bool = True, - force_com_rotation: float = None, - force_com_transpose: float = None, - force_com_shifts: float = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. - Calls the base class methods: - - _extract_intensities_and_calibrations_from_datacube, - _compute_center_of_mass(), - _solve_CoM_rotation(), - _normalize_diffraction_intensities() - _calculate_scan_positions_in_px() - - Additionally, it initializes an (Px,Py) array of 1.0j - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. - If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. 
One of 'plane','parabola','bezier_two' - plot_center_of_mass: str, optional - If 'default', the corrected CoM arrays will be displayed - If 'all', the computed and fitted CoM arrays will be displayed - plot_rotation: bool, optional - If True, the CoM curl minimization search result will be displayed - maximize_divergence: bool, optional - If True, the divergence of the CoM gradient vector field is maximized - rotation_angles_deg: np.darray, optional - Array of angles in degrees to perform curl minimization over - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - force_com_rotation: float (degrees), optional - Force relative rotation angle between real and reciprocal space - force_com_transpose: bool, optional - Force whether diffraction intensities need to be transposed. - force_com_shifts: tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. 
Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." - ) - ) - - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - self._positions_mask = np.asarray(self._positions_mask, dtype="bool") - - ( - self._datacube, - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts, - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube, - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts, - ) - - self._intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - self._com_measured_x, - self._com_measured_y, - self._com_fitted_x, - self._com_fitted_y, - self._com_normalized_x, - self._com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - self._intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts, - ) - - ( - 
self._rotation_best_rad, - self._rotation_best_transpose, - self._com_x, - self._com_y, - self.com_x, - self.com_y, - ) = self._solve_for_center_of_mass_relative_rotation( - self._com_measured_x, - self._com_measured_y, - self._com_normalized_x, - self._com_normalized_y, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=plot_center_of_mass, - maximize_divergence=maximize_divergence, - force_com_rotation=force_com_rotation, - force_com_transpose=force_com_transpose, - **kwargs, - ) - - ( - self._amplitudes, - self._mean_diffraction_intensity, - ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, - crop_patterns, - self._positions_mask, - ) - - # explicitly delete namespace - self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities - - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask - ) - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.asarray(self._object, dtype=xp.complex64) - - 
self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape - - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - self._positions_px_initial = self._positions_px.copy() - self._positions_initial = self._positions_px_initial.copy() - self._positions_initial[:, 0] *= self.sampling[0] - self._positions_initial[:, 1] *= self.sampling[1] - - # Vectorized Patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Probe Initialization - if self._probe is None or isinstance(self._probe, ComplexProbe): - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - _probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - else: - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if 
hasattr(self._probe, "_array"): - _probe = self._probe._array - else: - self._probe._xp = xp - _probe = self._probe.build()._array - - self._probe = xp.zeros( - (self._num_probes,) + tuple(self._region_of_interest_shape), - dtype=xp.complex64, - ) - sx, sy = self._region_of_interest_shape - self._probe[0] = _probe - - # Randomly shift phase of other probes - for i_probe in range(1, self._num_probes): - shift_x = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) - ) - shift_y = xp.exp( - -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) - ) - self._probe[i_probe] = ( - self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = None # Doesn't really make sense for mixed-state - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # overlaps - shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = 
Complex2RGB( - self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, axs = plt.subplots(1, self._num_probes + 1, figsize=figsize) - - for i in range(self._num_probes): - axs[i].imshow( - complex_probe_rgb[i], - extent=probe_extent, - ) - axs[i].set_ylabel("x [A]") - axs[i].set_xlabel("y [A]") - axs[i].set_title(f"Initial probe[{i}] intensity") - - divider = make_axes_locatable(axs[i]) - cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax, chroma_boost=chroma_boost) - - axs[-1].imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - axs[-1].scatter( - self.positions[:, 1], - self.positions[:, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - axs[-1].set_ylabel("x [A]") - axs[-1].set_xlabel("y [A]") - axs[-1].set_xlim((extent[0], extent[1])) - axs[-1].set_ylim((extent[2], extent[3])) - axs[-1].set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - overlap = shifted_probes * xp.expand_dims(object_patches, axis=1) - - return shifted_probes, object_patches, overlap - - def _gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - intensity_norm[intensity_norm == 0.0] = np.inf - amplitude_modification = amplitudes / intensity_norm - - fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = modified_overlap - overlap - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = overlap.copy() - - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) - - factor_to_be_projected = projection_c * overlap + projection_y * exit_waves - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - intensity_norm_projected = xp.sqrt( - xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) - ) - intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf - - amplitude_modification = amplitudes / intensity_norm_projected - fourier_projected_factor *= amplitude_modification[:, None] - - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * overlap - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - object * probe overlap - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, overlap - ) - - return shifted_probes, object_patches, overlap, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = xp.zeros_like(current_object) - object_update = xp.zeros_like(current_object) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes[:, i_probe]) ** 2 - ) - if self._object_type == "potential": - object_update += step_size * self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes[:, i_probe]) - * exit_waves[:, i_probe] - ) - ) - else: - object_update += step_size * self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object += object_update * probe_normalization - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += step_size * ( - xp.sum( - xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - 
- def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = xp.zeros_like(current_object) - current_object = xp.zeros_like(current_object) - - for i_probe in range(self._num_probes): - probe_normalization += self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes[:, i_probe]) ** 2 - ) - if self._object_type == "potential": - current_object += self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes[:, i_probe]) - * exit_waves[:, i_probe] - ) - ) - else: - current_object += self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe] - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object *= probe_normalization - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - 
current_probe = ( - xp.sum( - xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, - axis=0, - ) - * object_normalization[None] - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _probe_center_of_mass_constraint(self, current_probe): - """ - Ptychographic center of mass constraint. - Used for centering corner-centered probe intensity. 
- - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - probe_intensity = xp.abs(current_probe[0]) ** 2 - - probe_x0, probe_y0 = get_CoM( - probe_intensity, device=self._device, corner_centered=True - ) - shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) - - return shifted_probe - - def _probe_orthogonalization_constraint(self, current_probe): - """ - Ptychographic probe-orthogonalization constraint. - Used to ensure mixed states are orthogonal to each other. - Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 - - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Orthogonalized probe estimate - """ - xp = self._xp - n_probes = self._num_probes - - # compute upper half of P* @ P - pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) - - for i in range(n_probes): - for j in range(i, n_probes): - pairwise_dot_product[i, j] = xp.sum( - current_probe[i].conj() * current_probe[j] - ) - - # compute eigenvectors (effectively cheaper way of computing V* from SVD) - _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") - current_probe = xp.tensordot(evecs.T, current_probe, axes=1) - - # sort by real-space intensity - intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) - intensities_order = xp.argsort(intensities, axis=None)[::-1] - return current_probe[intensities_order] - - def _constraints( - self, - current_object, - current_probe, - current_positions, - pure_phase_object, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - 
constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - tv_denoise, - tv_denoise_weight, - tv_denoise_inner_iter, - orthogonalize_probe, - object_positivity, - shrinkage_rad, - object_mask, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe fourier amplitude is replaced by initial probe aperture. 
- initial_probe_aperture: np.ndarray - initial probe aperture to use in replacing probe fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - orthogonalize_probe: bool - If True, probe will be orthogonalized - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, tv_denoise_weight, tv_denoise_inner_iter - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - 
current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - # These constraints don't _really_ make sense for mixed-state - if fix_probe_aperture: - raise NotImplementedError() - elif constrain_probe_fourier_amplitude: - raise NotImplementedError() - if fit_probe_aberrations: - raise NotImplementedError() - if constrain_probe_amplitude: - raise NotImplementedError() - - if orthogonalize_probe: - current_probe = self._probe_orthogonalization_constraint(current_probe) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - pure_phase_object_iter: int = 0, - fix_com: bool = True, - orthogonalize_probe: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - 
constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - global_affine_transformation: bool = True, - constrain_position_distance: float = None, - gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass: float = None, - q_highpass: float = None, - butterworth_order: float = 2, - tv_denoise_iter: int = np.inf, - tv_denoise_weight: float = None, - tv_denoise_inner_iter: float = 40, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - switch_object_iter: int = np.inf, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. 
- max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - pure_phase_object_iter: int, optional - Number of iterations where object amplitude is set to unity - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. 
- fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float - Distance to constrain position correction within original field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma: float, optional - Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise_iter: int, optional - Number of iterations to run using tv denoise filter on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool, optional - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - fix_potential_baseline: bool - If true, the potential mean outside the FOV is forced to zero at each iteration - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - 
"'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." 
- ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = self._object_initial.copy() - self.error_iterations = [] - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": - self._object_type = "potential" - self._object = xp.angle(self._object) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - positions_px = self._positions_px.copy()[shuffled_indices] - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] - - # forward operator - ( - shifted_probes, - object_patches, - overlap, - self._exit_waves, - batch_error, - ) = self._forward( - self._object, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - self._object, self._probe = self._adjoint( - self._object, - self._probe, - object_patches, - shifted_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - self._object, - shifted_probes[:, 0], - overlap[:, 0], - amplitudes, - self._positions_px, - positions_step_size, - 
constrain_position_distance, - ) - - error += batch_error - - # Normalize Error - error /= self._mean_diffraction_intensity * self._num_diffraction_patterns - - # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] - self._object, self._probe, self._positions_px = self._constraints( - self._object, - self._probe, - self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - orthogonalize_probe=orthogonalize_probe, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_inner_iter=tv_denoise_inner_iter, - 
object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 - else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax: None, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. - - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", 
None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. - - Parameters - -------- - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - 
height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - ax = fig.add_subplot(spec[0, 1]) - - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual[0] - else: - probe_array = self.probe_fourier[0] - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe[0], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed probe[0] intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - ax = fig.add_subplot(spec[0]) - im 
= ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." - ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - 
object_type.append("phase") - else: - object_type.append("potential") - objects.append(self._crop_rotate_object_fov(obj, padding=padding)) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or plot_fourier_probe) - else iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or 
plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]][0], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - probes[grid_range[n]][0], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - padding: int = 0, - **kwargs, - ): - """ - Displays reconstructed object and probe. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - - return self - - def show_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - cbar=True, - scalebar=True, - pixelsize=None, - pixelunits=None, - **kwargs, - ): - """ - Plot probe in fourier space - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses the `probe_fourier` property - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - scalebar: bool, optional - if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, 
default is A^-1 - pixelsize: float, optional - default is probe reciprocal sampling - """ - asnumpy = self._asnumpy - - if probe is None: - probe = list( - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - ) - else: - if isinstance(probe, np.ndarray) and probe.ndim == 2: - probe = [probe] - probe = [ - asnumpy( - self._return_fourier_probe( - pr, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for pr in probe - ] - - if pixelsize is None: - pixelsize = self._reciprocal_sampling[1] - if pixelunits is None: - pixelunits = r"$\AA^{-1}$" - - chroma_boost = kwargs.pop("chroma_boost", 1) - - show_complex( - probe if len(probe) > 1 else probe[0], - cbar=cbar, - scalebar=scalebar, - pixelsize=pixelsize, - pixelunits=pixelunits, - ticks=False, - chroma_boost=chroma_boost, - **kwargs, - ) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start:end] - - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - - # Normalized mean-squared errors - batch_errors = xp.sum( 
- xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py deleted file mode 100644 index 39cb62fdd..000000000 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ /dev/null @@ -1,3465 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely multislice ptychography. -""" - -import warnings -from typing import Mapping, Sequence, Tuple, Union - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex - -try: - import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = np - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, - spatial_frequencies, -) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar -from scipy.ndimage import rotate - -warnings.simplefilter(action="always", category=UserWarning) - - -class MultislicePtychographicReconstruction(PtychographicReconstruction): - """ - Multislice Ptychographic Reconstruction Class. 
- - Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (Sx,Sy) - Reconstructed object dimensions : (T,Px,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our probe - and (Px,Py) is the padded-object size we position our ROI around in - each of the T slices. - - Parameters - ---------- - energy: float - The electron energy of the wave functions in eV - num_slices: int - Number of slices to use in the forward model - slice_thicknesses: float or Sequence[float] - Slice thicknesses in angstroms. If float, all slices are assigned the same thickness - datacube: DataCube, optional - Input 4D diffraction pattern intensities - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py) - If None, initialized to 1.0j - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: np.ndarray, optional - Probe positions in Å for each diffraction intensity - If None, initialized to a grid scan - theta_x: float - x tilt of propagator (in degrees) - theta_y: float - y tilt of propagator (in degrees) - middle_focus: bool - if True, adds half the sample thickness to the defocus - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. - """ - - # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_slice_thicknesses") - - def __init__( - self, - energy: float, - num_slices: int, - slice_thicknesses: Union[float, Sequence[float]], - datacube: DataCube = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: np.ndarray = None, - theta_x: float = 0, - theta_y: float = 0, - middle_focus: bool = False, - object_type: str = "complex", - positions_mask: np.ndarray = None, - verbose: bool = True, - device: str = "cpu", - name: str = "multi-slice_ptychographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - 
self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - if np.isscalar(slice_thicknesses): - mean_slice_thickness = slice_thicknesses - else: - mean_slice_thickness = np.mean(slice_thicknesses) - - if middle_focus: - if "defocus" in kwargs: - kwargs["defocus"] += mean_slice_thickness * num_slices / 2 - elif "C10" in kwargs: - kwargs["C10"] -= mean_slice_thickness * num_slices / 2 - elif polar_parameters is not None and "defocus" in polar_parameters: - polar_parameters["defocus"] = ( - polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 - ) - elif polar_parameters is not None and "C10" in polar_parameters: - polar_parameters["C10"] = ( - polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 - ) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - slice_thicknesses = np.array(slice_thicknesses) - if slice_thicknesses.shape == (): - slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) - elif slice_thicknesses.shape[0] != (num_slices - 1): - raise ValueError( - ( - f"slice_thicknesses must have length {num_slices - 1}, " - f"not {slice_thicknesses.shape[0]}." 
- ) - ) - - if object_type != "potential" and object_type != "complex": - raise ValueError( - f"object_type must be either 'potential' or 'complex', not {object_type}" - ) - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._positions_mask = positions_mask - self._object_padding_px = object_padding_px - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._num_slices = num_slices - self._slice_thicknesses = slice_thicknesses - self._theta_x = theta_x - self._theta_y = theta_y - - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - theta_x: float, - theta_y: float, - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. 
- - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - theta_x: float - x tilt of propagator (in degrees) - theta_y: float - y tilt of propagator (in degrees) - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - - theta_x = np.deg2rad(theta_x) - theta_y = np.deg2rad(theta_y) - - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) - ) - propagators[i] *= xp.exp( - 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. 
- - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_center_of_mass: str = "default", - plot_rotation: bool = True, - maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), - plot_probe_overlaps: bool = True, - force_com_rotation: float = None, - force_com_transpose: float = None, - force_com_shifts: float = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. - Calls the base class methods: - - _extract_intensities_and_calibrations_from_datacube, - _compute_center_of_mass(), - _solve_CoM_rotation(), - _normalize_diffraction_intensities() - _calculate_scan_positions_in_px() - - Additionally, it initializes an (T,Px,Py) array of 1.0j - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. 
- If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' - plot_center_of_mass: str, optional - If 'default', the corrected CoM arrays will be displayed - If 'all', the computed and fitted CoM arrays will be displayed - plot_rotation: bool, optional - If True, the CoM curl minimization search result will be displayed - maximize_divergence: bool, optional - If True, the divergence of the CoM gradient vector field is maximized - rotation_angles_deg: np.darray, optional - Array of angles in degrees to perform curl minimization over - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - force_com_rotation: float (degrees), optional - Force relative rotation angle between real and reciprocal space - force_com_transpose: bool, optional - Force whether diffraction intensities need to be transposed. - force_com_shifts: tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. 
Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: MultislicePtychographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." - ) - ) - - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - self._positions_mask = np.asarray(self._positions_mask, dtype="bool") - - ( - self._datacube, - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts, - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube, - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts, - ) - - self._intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - self._com_measured_x, - self._com_measured_y, - self._com_fitted_x, - self._com_fitted_y, - self._com_normalized_x, - self._com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - self._intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts, - ) - - ( - 
self._rotation_best_rad, - self._rotation_best_transpose, - self._com_x, - self._com_y, - self.com_x, - self.com_y, - ) = self._solve_for_center_of_mass_relative_rotation( - self._com_measured_x, - self._com_measured_y, - self._com_normalized_x, - self._com_normalized_y, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=plot_center_of_mass, - maximize_divergence=maximize_divergence, - force_com_rotation=force_com_rotation, - force_com_transpose=force_com_transpose, - **kwargs, - ) - - ( - self._amplitudes, - self._mean_diffraction_intensity, - ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, - crop_patterns, - self._positions_mask, - ) - - # explicitly delete namespace - self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities - - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask - ) - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = 
xp.asarray(self._object, dtype=xp.complex64) - - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape[-2:] - - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - self._positions_px_initial = self._positions_px.copy() - self._positions_initial = self._positions_px_initial.copy() - self._positions_initial[:, 0] *= self.sampling[0] - self._positions_initial[:, 1] *= self.sampling[1] - - # Vectorized Patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = 
xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - self._mean_diffraction_intensity / probe_intensity - ) - - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # Precomputed propagator arrays - self._propagator_arrays = self._precompute_propagator_arrays( - self._region_of_interest_shape, - self.sampling, - self._energy, - self._slice_thicknesses, - self._theta_x, - self._theta_y, - ) - - # overlaps - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (13, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - 
- # propagated - propagated_probe = self._probe.copy() - - for s in range(self._num_slices - 1): - propagated_probe = self._propagate_array( - propagated_probe, self._propagator_arrays[s] - ) - complex_propagated_rgb = Complex2RGB( - asnumpy(self._return_centered_probe(propagated_probe)), - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax1, chroma_boost=chroma_boost) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe intensity") - - ax2.imshow( - complex_propagated_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax2) - cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax2, - chroma_boost=chroma_boost, - ) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_title("Propagated probe intensity") - - ax3.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - ax3.scatter( - self.positions[:, 1], - self.positions[:, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax3.set_ylabel("x [A]") - ax3.set_xlabel("y [A]") - ax3.set_xlim((extent[0], extent[1])) - ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ] - - propagated_probes = xp.empty_like(object_patches) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = object_patches[s] * propagated_probes[s] - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. 
- - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - modified_exit_wave = xp.fft.ifft2( - amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves)) - ) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. - Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * exit_waves - 
) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection(current_object, current_probe) - - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - 
propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object[s] += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object[s] += step_size * ( - self._sum_overlapping_patches_bincounts(xp.conj(probe) * exit_waves) - * probe_normalization - ) - - # back-transmit - exit_waves *= xp.conj(obj) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - 
normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object[s] = ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object[s] = ( - self._sum_overlapping_patches_bincounts( - xp.conj(probe) * exit_waves_copy - 
) - * probe_normalization - ) - - # back-transmit - exit_waves_copy *= xp.conj(obj) # / xp.abs(obj) ** 2 - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - 
current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - # dx - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - 
transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - # dy - obj_rolled_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * 
positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - 2D Butterworth filter - Used for low/high-pass filtering object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qya, qxa = xp.meshgrid(qy, qx) - qra = xp.sqrt(qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) - current_object += current_object_mean - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_kz_regularization_constraint( - self, current_object, kz_regularization_gamma - ): - """ - Arctan regularization filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - kz_regularization_gamma: float - Slice regularization strength - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - current_object = xp.pad( - current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" - ) - - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) - - kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] - - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qz2 = qza**2 * kz_regularization_gamma**2 - qr2 = qxa**2 + qya**2 - - w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) - - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) - current_object = current_object[1:] - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return 
current_object - - def _object_identical_slices_constraint(self, current_object): - """ - Strong regularization forcing all slices to be identical - - Parameters - -------- - current_object: np.ndarray - Current object estimate - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - object_mean = current_object.mean(0, keepdims=True) - current_object[:] = object_mean - - return current_object - - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - iterations: float - Number of iterations to run in denoising algorithm. - `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - 
- current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv - - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - kz_regularization_filter, - kz_regularization_gamma, - identical_slices, - object_positivity, - shrinkage_rad, - object_mask, - pure_phase_object, - tv_denoise_chambolle, - tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - ): - """ - Ptychographic constraints operator. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool - If True, probe Fourier amplitude is replaced by initial_probe_aperture - initial_probe_aperture: np.ndarray - Initial probe aperture to use in replacing probe Fourier amplitude - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter in A - gaussian_filter_sigma: float - Standard deviation of gaussian kernel - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - kz_regularization_filter: bool - If True, applies fourier-space arctan regularization filter - kz_regularization_gamma: float - Slice regularization strength - identical_slices: bool - If True, forces all object slices to be identical - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - pure_phase_object: bool - If True, object amplitude is set to unity - tv_denoise_chambolle: bool - If True, performs TV denoising along z - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if identical_slices: - current_object = self._object_identical_slices_constraint(current_object) - elif kz_regularization_filter: - current_object = self._object_kz_regularization_constraint( - current_object, kz_regularization_gamma - ) - elif tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - ) - elif tv_denoise_chambolle: - current_object = self._object_denoise_tv_chambolle( - current_object, - tv_denoise_weight_chambolle, - axis=0, - pad_object=tv_denoise_pad_chambolle, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) - elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - 
constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - fix_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - 
butterworth_filter_iter: int = np.inf, - q_lowpass: float = None, - q_highpass: float = None, - butterworth_order: float = 2, - kz_regularization_filter_iter: int = np.inf, - kz_regularization_gamma: Union[float, np.ndarray] = None, - identical_slices_iter: int = 0, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - pure_phase_object_iter: int = 0, - tv_denoise_iter_chambolle=np.inf, - tv_denoise_weight_chambolle=None, - tv_denoise_pad_chambolle=True, - tv_denoise_iter=np.inf, - tv_denoise_weights=None, - tv_denoise_inner_iter=40, - switch_object_iter: int = np.inf, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. 
- max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. 
- fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float - Distance to constrain position correction within original field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma: float, optional - Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - kz_regularization_filter_iter: int, optional - Number of iterations to run using kz regularization filter - kz_regularization_gamma, float, optional - kz regularization strength - identical_slices_iter: int, optional - Number of iterations to run using identical slices - object_positivity: bool, optional - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - fix_potential_baseline: bool - If true, the potential mean outside the FOV is forced to zero at each iteration - pure_phase_object_iter: int, optional - Number of iterations where object amplitude is set to unity - tv_denoise_iter_chambolle: bool - Number of iterations with TV denoisining - tv_denoise_weight_chambolle: float - weight of tv denoising constraint - tv_denoise_pad_chambolle: bool - if True, pads object at top and bottom with zeros before applying denoising - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: MultislicePtychographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - 
"'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." 
- ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": - self._object_type = "potential" - self._object = xp.angle(self._object) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - positions_px = self._positions_px.copy()[shuffled_indices] - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] - - # forward operator - ( - propagated_probes, - object_patches, - self._transmitted_probes, - self._exit_waves, - batch_error, - ) = self._forward( - self._object, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - self._object, self._probe = self._adjoint( - self._object, - self._probe, - object_patches, - propagated_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - self._object, - self._probe, - self._transmitted_probes, - amplitudes, - self._positions_px, - 
positions_step_size, - constrain_position_distance, - ) - - error += batch_error - - # Normalize Error - error /= self._mean_diffraction_intensity * self._num_diffraction_patterns - - # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] - self._object, self._probe, self._positions_px = self._constraints( - self._object, - self._probe, - self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - kz_regularization_filter=a0 < kz_regularization_filter_iter - and kz_regularization_gamma is not None, - kz_regularization_gamma=kz_regularization_gamma[a0] - if kz_regularization_gamma is not None - 
and isinstance(kz_regularization_gamma, np.ndarray) - else kz_regularization_gamma, - identical_slices=a0 < identical_slices_iter, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 - else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle - and tv_denoise_weight_chambolle is not None, - tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, - tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - padding : int, optional - Pixels to pad by post rotating-cropping object - - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov( - np.sum(obj, axis=0), padding=padding - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - 
width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - ax = fig.add_subplot(spec[0, 1]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual - else: - probe_array = self.probe_fourier - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe, power=2, chroma_boost=chroma_boost - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - ax = fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = 
make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." 
- ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - object_type.append("phase") - else: - object_type.append("potential") - objects.append( - self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) - ) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * 
self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or plot_fourier_probe) - else iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - probes[grid_range[n]], power=2, chroma_boost=chroma_boost - ) - ax.set_title(f"Iter: {grid_range[n]} probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, 
- extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - padding: int = 0, - **kwargs, - ): - """ - Displays reconstructed object and probe. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - 
iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - return self - - def show_transmitted_probe( - self, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations=False, - **kwargs, - ): - """ - Plots the min, max, and mean transmitted probe after propagation and transmission. - - Parameters - ---------- - plot_fourier_probe: boolean, optional - If True, the transmitted probes are also plotted in Fourier space - kwargs: - Passed to show_complex - """ - - xp = self._xp - asnumpy = self._asnumpy - - transmitted_probe_intensities = xp.sum( - xp.abs(self._transmitted_probes) ** 2, axis=(-2, -1) - ) - min_intensity_transmitted = self._transmitted_probes[ - xp.argmin(transmitted_probe_intensities) - ] - max_intensity_transmitted = self._transmitted_probes[ - xp.argmax(transmitted_probe_intensities) - ] - mean_transmitted = self._transmitted_probes.mean(0) - probes = [ - asnumpy(self._return_centered_probe(probe)) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - title = [ - "Mean Transmitted Probe", - "Min Intensity Transmitted Probe", - "Max Intensity Transmitted Probe", - ] - - if plot_fourier_probe: - bottom_row = [ - asnumpy( - self._return_fourier_probe( - probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - for probe in [ - mean_transmitted, - min_intensity_transmitted, - max_intensity_transmitted, - ] - ] - probes = [probes, bottom_row] - - title += [ - "Mean Transmitted Fourier Probe", - "Min Intensity Transmitted Fourier Probe", - "Max Intensity Transmitted Fourier Probe", - ] - - title = kwargs.get("title", title) - show_complex( - probes, - title=title, - **kwargs, - ) - - def show_slices( - self, - ms_object=None, - cbar: bool = True, - common_color_scale: bool = True, - padding: int = 0, - 
num_cols: int = 3, - show_fft: bool = False, - **kwargs, - ): - """ - Displays reconstructed slices of object - - Parameters - -------- - ms_object: nd.array, optional - Object to plot slices of. If None, uses current object - cbar: bool, optional - If True, displays a colorbar - padding: int, optional - Padding to leave uncropped - num_cols: int, optional - Number of GridSpec columns - show_fft: bool, optional - if True, plots fft of object slices - """ - - if ms_object is None: - ms_object = self._object - - rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) - if show_fft: - rotated_object = np.abs( - np.fft.fftshift( - np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) - ) - ) - rotated_shape = rotated_object.shape - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - - extent = [ - 0, - self.sampling[1] * rotated_shape[2], - self.sampling[0] * rotated_shape[1], - 0, - ] - - num_rows = np.ceil(self._num_slices / num_cols).astype("int") - wspace = 0.35 if cbar else 0.15 - - axsize = kwargs.pop("axsize", (3, 3)) - cmap = kwargs.pop("cmap", "magma") - - if common_color_scale: - vals = np.sort(rotated_object.ravel()) - ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") - ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") - ind_vmin = np.max([0, ind_vmin]) - ind_vmax = np.min([len(vals) - 1, ind_vmax]) - vmin = vals[ind_vmin] - vmax = vals[ind_vmax] - if vmax == vmin: - vmin = vals[0] - vmax = vals[-1] - else: - vmax = None - vmin = None - vmin = kwargs.pop("vmin", vmin) - vmax = kwargs.pop("vmax", vmax) - - spec = GridSpec( - ncols=num_cols, - nrows=num_rows, - hspace=0.15, - wspace=wspace, - ) - - figsize = (axsize[0] * num_cols, axsize[1] * num_rows) - fig = plt.figure(figsize=figsize) - - for flat_index, obj_slice in enumerate(rotated_object): - row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) - ax = fig.add_subplot(spec[row_index, col_index]) - im = 
ax.imshow( - obj_slice, - cmap=cmap, - vmin=vmin, - vmax=vmax, - extent=extent, - **kwargs, - ) - - ax.set_title(f"Slice index: {flat_index}") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if row_index < num_rows - 1: - ax.set_xticks([]) - else: - ax.set_xlabel("y [A]") - - if col_index > 0: - ax.set_yticks([]) - else: - ax.set_ylabel("x [A]") - - spec.tight_layout(fig) - - def show_depth( - self, - x1: float, - x2: float, - y1: float, - y2: float, - specify_calibrated: bool = False, - gaussian_filter_sigma: float = None, - ms_object=None, - cbar: bool = False, - aspect: float = None, - plot_line_profile: bool = False, - **kwargs, - ): - """ - Displays line profile depth section - - Parameters - -------- - x1, x2, y1, y2: floats (pixels) - Line profile for depth section runs from (x1,y1) to (x2,y2) - Specified in pixels unless specify_calibrated is True - specify_calibrated: bool (optional) - If True, specify x1, x2, y1, y2 in A values instead of pixels - gaussian_filter_sigma: float (optional) - Standard deviation of gaussian kernel in A - ms_object: np.array - Object to plot slices of. 
If None, uses current object - cbar: bool, optional - If True, displays a colorbar - aspect: float, optional - aspect ratio for depth profile plot - plot_line_profile: bool - If True, also plots line profile showing where depth profile is taken - """ - if ms_object is not None: - ms_obj = ms_object - else: - ms_obj = self.object_cropped - - if specify_calibrated: - x1 /= self.sampling[0] - x2 /= self.sampling[0] - y1 /= self.sampling[1] - y2 /= self.sampling[1] - - if x2 == x1: - angle = 0 - elif y2 == y1: - angle = np.pi / 2 - else: - angle = np.arctan((x2 - x1) / (y2 - y1)) - - x0 = ms_obj.shape[1] / 2 - y0 = ms_obj.shape[2] / 2 - - if ( - x1 > ms_obj.shape[1] - or x2 > ms_obj.shape[1] - or y1 > ms_obj.shape[2] - or y2 > ms_obj.shape[2] - ): - raise ValueError("depth section must be in field of view of object") - - from py4DSTEM.process.phase.utils import rotate_point - - x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) - x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) - - rotated_object = np.roll( - rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), - -int(x1_0), - axis=1, - ) - - if np.iscomplexobj(rotated_object): - rotated_object = np.angle(rotated_object) - if gaussian_filter_sigma is not None: - from scipy.ndimage import gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) - - plot_im = rotated_object[ - :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) - ] - - extent = [ - 0, - self.sampling[1] * plot_im.shape[1], - self._slice_thicknesses[0] * plot_im.shape[0], - 0, - ] - figsize = kwargs.pop("figsize", (6, 6)) - if not plot_line_profile: - fig, ax = plt.subplots(figsize=figsize) - im = ax.imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax.set_aspect(aspect) - ax.set_xlabel("r [A]") - ax.set_ylabel("z [A]") - ax.set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax) - 
ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - else: - extent2 = [ - 0, - self.sampling[1] * ms_obj.shape[2], - self.sampling[0] * ms_obj.shape[1], - 0, - ] - - fig, ax = plt.subplots(2, 1, figsize=figsize) - ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) - ax[0].plot( - [y1 * self.sampling[0], y2 * self.sampling[1]], - [x1 * self.sampling[0], x2 * self.sampling[1]], - color="red", - ) - ax[0].set_xlabel("y [A]") - ax[0].set_ylabel("x [A]") - ax[0].set_title("Multislice depth profile location") - - im = ax[1].imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax[1].set_aspect(aspect) - ax[1].set_xlabel("r [A]") - ax[1].set_ylabel("z [A]") - ax[1].set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax[1]) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - plt.tight_layout() - - def tune_num_slices_and_thicknesses( - self, - num_slices_guess=None, - thicknesses_guess=None, - num_slices_step_size=1, - thicknesses_step_size=20, - num_slices_values=3, - num_thicknesses_values=3, - update_defocus=False, - max_iter=5, - plot_reconstructions=True, - plot_convergence=True, - return_values=False, - **kwargs, - ): - """ - Run reconstructions over a parameters space of number of slices - and slice thicknesses. Should be run after the preprocess step. 
- - Parameters - ---------- - num_slices_guess: float, optional - initial starting guess for number of slices, rounds to nearest integer - if None, uses current initialized values - thicknesses_guess: float (A), optional - initial starting guess for thicknesses of slices assuming same - thickness for each slice - if None, uses current initialized values - num_slices_step_size: float, optional - size of change of number of slices for each step in parameter space - thicknesses_step_size: float (A), optional - size of change of slice thicknesses for each step in parameter space - num_slices_values: int, optional - number of number of slice values to test, must be >= 1 - num_thicknesses_values: int,optional - number of thicknesses values to test, must be >= 1 - update_defocus: bool, optional - if True, updates defocus based on estimated total thickness - max_iter: int, optional - number of iterations to run in ptychographic reconstruction - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, plots error for each iteration for each reconstruction - return_values: bool, optional - if True, returns objects, convergence - - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - - # calculate number of slices and thicknesses values to test - if num_slices_guess is None: - num_slices_guess = self._num_slices - if thicknesses_guess is None: - thicknesses_guess = np.mean(self._slice_thicknesses) - - if num_slices_values == 1: - num_slices_step_size = 0 - - if num_thicknesses_values == 1: - thicknesses_step_size = 0 - - num_slices = np.linspace( - num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, - num_slices_values, - ) - - thicknesses = np.linspace( - thicknesses_guess - - thicknesses_step_size * (num_thicknesses_values 
- 1) / 2, - thicknesses_guess - + thicknesses_step_size * (num_thicknesses_values - 1) / 2, - num_thicknesses_values, - ) - - if return_values: - convergence = [] - objects = [] - - # current initialized values - current_verbose = self._verbose - current_num_slices = self._num_slices - current_thicknesses = self._slice_thicknesses - current_rotation_deg = self._rotation_best_rad * 180 / np.pi - current_transpose = self._rotation_best_transpose - current_defocus = -self._polar_parameters["C10"] - - # Gridspec to plot on - if plot_reconstructions: - if plot_convergence: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values * 2, - height_ratios=[1, 1 / 4] * num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) - ) - else: - spec = GridSpec( - ncols=num_thicknesses_values, - nrows=num_slices_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) - ) - - fig = plt.figure(figsize=figsize) - - progress_bar = kwargs.pop("progress_bar", False) - # run loop and plot along the way - self._verbose = False - for flat_index, (slices, thickness) in enumerate( - tqdmnd(num_slices, thicknesses, desc="Tuning angle and defocus") - ): - slices = int(slices) - self._num_slices = slices - self._slice_thicknesses = np.tile(thickness, slices - 1) - self._probe = None - self._object = None - if update_defocus: - defocus = current_defocus + slices / 2 * thickness - self._polar_parameters["C10"] = -defocus - - self.preprocess( - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - ) - self.reconstruct( - reset=True, - store_iterations=True if plot_convergence else False, - max_iter=max_iter, - progress_bar=progress_bar, - **kwargs, - ) - - if plot_reconstructions: - row_index, col_index = 
np.unravel_index( - flat_index, (num_slices_values, num_thicknesses_values) - ) - - if plot_convergence: - object_ax = fig.add_subplot(spec[row_index * 2, col_index]) - convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=convergence_ax, - cbar=True, - ) - convergence_ax.yaxis.tick_right() - else: - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=None, - cbar=True, - ) - - object_ax.set_title( - f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}" - ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - objects.append(self.object) - convergence.append(self.error_iterations.copy()) - - # initialize back to pre-tuning values - self._probe = None - self._object = None - self._num_slices = current_num_slices - self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) - self._polar_parameters["C10"] = -current_defocus - self.preprocess( - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - self._verbose = current_verbose - - if plot_reconstructions: - spec.tight_layout(fig) - - if return_values: - return objects, convergence - - def _return_object_fft( - self, - obj=None, - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - """ - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - - obj = asnumpy(obj) - if np.iscomplexobj(obj): - obj = np.angle(obj) - - obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - - def _return_projected_cropped_potential( - self, - ): - """Utility function to 
accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped).sum(0) - else: - projected_cropped_potential = self.object_cropped.sum(0) - - return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py deleted file mode 100644 index 670ea5e40..000000000 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ /dev/null @@ -1,3389 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely overlap magnetic tomography. -""" - -import warnings -from typing import Mapping, Sequence, Tuple - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import make_axes_locatable -from py4DSTEM.visualize import show -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg -from scipy.ndimage import rotate as rotate_np - -try: - import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = np - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception - -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, - project_vector_field_divergence, - spatial_frequencies, -) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar - -warnings.simplefilter(action="always", category=UserWarning) - - -class OverlapMagneticTomographicReconstruction(PtychographicReconstruction): - """ - Overlap Magnetic Tomographic Reconstruction Class. 
- - List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (Sx,Sy) - Reconstructed object dimensions : (Px,Py,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our probe - and (Px,Py,Py) is the padded-object electrostatic potential volume, - where x-axis is the tilt. - - Parameters - ---------- - datacube: List of DataCubes - Input list of 4D diffraction pattern intensities for different tilts - energy: float - The electron energy of the wave functions in eV - num_slices: int - Number of slices to use in the forward model - tilt_angles_deg: Sequence[float] - List of (\alpha, \beta) tilt angle tuple in degrees, - with the following Euler-angle convention: - - \alpha tilt around z-axis - - \beta tilt around x-axis - - -\alpha tilt around z-axis - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py,Py) - If None, initialized to 1.0 - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: list of np.ndarray, optional - Probe positions in Å for each diffraction intensity per tilt - If None, initialized to a grid scan centered along tilt axis - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. - """ - - # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_angles_deg") - - def __init__( - self, - energy: float, - num_slices: int, - tilt_angles_deg: Sequence[Tuple[float, float]], - datacube: Sequence[DataCube] = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - object_type: str = "potential", - positions_mask: np.ndarray = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: Sequence[np.ndarray] = None, - verbose: bool = True, - device: str = "cpu", - name: str = "overlap-magnetic-tomographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter, rotate, zoom - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import 
gaussian_filter, rotate, zoom - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - num_tilts = len(tilt_angles_deg) - if initial_scan_positions is None: - initial_scan_positions = [None] * num_tilts - - if object_type != "potential": - raise NotImplementedError() - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._object_padding_px = object_padding_px - self._positions_mask = positions_mask - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._num_slices = num_slices - self._tilt_angles_deg = tuple(tilt_angles_deg) - self._num_tilts = num_tilts - - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. 
- - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. - - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - - def _project_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. 
- If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(input_z / output_z).astype("int") - pad_size = voxels_per_slice * output_z - input_z - - padded_array = xp.pad(array, ((0, pad_size), (0, 0), (0, 0))) - - return xp.sum( - padded_array.reshape( - ( - -1, - voxels_per_slice, - ) - + array.shape[1:] - ), - axis=1, - ) - - def _expand_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. - If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(output_z / input_z).astype("int") - remainder_size = voxels_per_slice - (voxels_per_slice * input_z - output_z) - - voxels_in_slice = xp.repeat(voxels_per_slice, input_z) - voxels_in_slice[-1] = remainder_size if remainder_size > 0 else voxels_per_slice - - normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] - return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] - - def _euler_angle_rotate_volume( - self, - volume_array, - alpha_deg, - beta_deg, - ): - """ - Rotate 3D volume using alpha, beta, gamma Euler angles according to convention: - - - \\-alpha tilt around first axis (z) - - \\beta tilt around second axis (x) - - \\alpha tilt around first axis (z) - - Note: since we store array as zxy, the x- and y-axis rotations flip sign below. 
- - """ - - rotate = self._rotate - volume = volume_array.copy() - - alpha_deg, beta_deg = np.mod(np.array([alpha_deg, beta_deg]) + 180, 360) - 180 - - if alpha_deg == -180: - # print(f"rotation of {-beta_deg} around x") - volume = rotate( - volume, - beta_deg, - axes=(0, 2), - reshape=False, - order=3, - ) - elif alpha_deg == -90: - # print(f"rotation of {beta_deg} around y") - volume = rotate( - volume, - -beta_deg, - axes=(0, 1), - reshape=False, - order=3, - ) - elif alpha_deg == 0: - # print(f"rotation of {beta_deg} around x") - volume = rotate( - volume, - -beta_deg, - axes=(0, 2), - reshape=False, - order=3, - ) - elif alpha_deg == 90: - # print(f"rotation of {-beta_deg} around y") - volume = rotate( - volume, - beta_deg, - axes=(0, 1), - reshape=False, - order=3, - ) - else: - # print(( - # f"rotation of {-alpha_deg} around z, " - # f"rotation of {beta_deg} around x, " - # f"rotation of {alpha_deg} around z." - # )) - - volume = rotate( - volume, - -alpha_deg, - axes=(1, 2), - reshape=False, - order=3, - ) - - volume = rotate( - volume, - -beta_deg, - axes=(0, 2), - reshape=False, - order=3, - ) - - volume = rotate( - volume, - alpha_deg, - axes=(1, 2), - reshape=False, - order=3, - ) - - return volume - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_probe_overlaps: bool = True, - rotation_real_space_degrees: float = None, - diffraction_patterns_rotate_degrees: float = None, - diffraction_patterns_transpose: bool = None, - force_com_shifts: Sequence[float] = None, - progress_bar: bool = True, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. 
- - Additionally, it initializes an (Px,Py, Py) array of 1.0 - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. - If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - rotation_real_space_degrees: float (degrees), optional - In plane rotation around z axis between x axis and tilt axis in - real space (forced to be in xy plane) - diffraction_patterns_rotate_degrees: float, optional - Relative rotation angle between real and reciprocal space - diffraction_patterns_transpose: bool, optional - Whether diffraction intensities need to be transposed. - force_com_shifts: list of tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. One tuple per tilt. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. 
Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: OverlapTomographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." - ) - ) - - if self._positions_mask is not None: - self._positions_mask = np.asarray(self._positions_mask) - - if self._positions_mask.ndim == 2: - warnings.warn( - "2D `positions_mask` assumed the same for all measurements.", - UserWarning, - ) - self._positions_mask = np.tile( - self._positions_mask, (self._num_tilts, 1, 1) - ) - - if self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array."), - UserWarning, - ) - self._positions_mask = self._positions_mask.astype("bool") - else: - self._positions_mask = [None] * self._num_tilts - - # Prepopulate various arrays - - if self._positions_mask[0] is None: - num_probes_per_tilt = [0] - for dc in self._datacube: - rx, ry = dc.Rshape - num_probes_per_tilt.append(rx * ry) - - num_probes_per_tilt = np.array(num_probes_per_tilt) - else: - num_probes_per_tilt = np.insert( - self._positions_mask.sum(axis=(-2, -1)), 0, 0 - ) - - self._num_diffraction_patterns = num_probes_per_tilt.sum() - self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) - - self._mean_diffraction_intensity = [] - self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) - - self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) - self._rotation_best_transpose = 
diffraction_patterns_transpose - - if force_com_shifts is None: - force_com_shifts = [None] * self._num_tilts - - for tilt_index in tqdmnd( - self._num_tilts, - desc="Preprocessing data", - unit="tilt", - disable=not progress_bar, - ): - if tilt_index == 0: - ( - self._datacube[tilt_index], - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts[tilt_index], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts[tilt_index], - ) - - self._amplitudes = xp.empty( - (self._num_diffraction_patterns,) + self._datacube[0].Qshape - ) - self._region_of_interest_shape = np.array( - self._amplitudes[0].shape[-2:] - ) - - else: - ( - self._datacube[tilt_index], - _, - _, - force_com_shifts[tilt_index], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=None, - dp_mask=None, - com_shifts=force_com_shifts[tilt_index], - ) - - intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube[tilt_index], - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x, - com_measured_y, - com_fitted_x, - com_fitted_y, - com_normalized_x, - com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[tilt_index], - ) - - ( - self._amplitudes[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - 
tilt_index + 1 - ] - ], - mean_diffraction_intensity_temp, - ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, - crop_patterns, - self._positions_mask[tilt_index], - ) - - self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) - - del ( - intensities, - com_measured_x, - com_measured_y, - com_fitted_x, - com_fitted_y, - com_normalized_x, - com_normalized_y, - ) - - self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index], self._positions_mask[tilt_index] - ) - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px_all, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - self._object = xp.zeros((4, q, p, q), dtype=xp.float32) - else: - self._object = xp.asarray(self._object, dtype=xp.float32) - - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape[-2:] - self._num_voxels = self._object.shape[1] - - # Center Probes - self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) - - for tilt_index in range(self._num_tilts): - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= ( - self._positions_px_com - xp.array(self._object_shape) / 2 - ) - - self._positions_px_all[ - 
self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._positions_px.copy() - - self._positions_px_initial_all = self._positions_px_all.copy() - self._positions_initial_all = self._positions_px_initial_all.copy() - self._positions_initial_all[:, 0] *= self.sampling[0] - self._positions_initial_all[:, 1] *= self.sampling[1] - - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, device=self._device - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity - ) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - 
sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # Precomputed propagator arrays - self._slice_thicknesses = np.tile( - self._object_shape[1] * self.sampling[1] / self._num_slices, - self._num_slices - 1, - ) - self._propagator_arrays = self._precompute_propagator_arrays( - self._region_of_interest_shape, - self.sampling, - self._energy, - self._slice_thicknesses, - ) - - # overlaps - if object_fov_mask is None: - probe_overlap_3D = xp.zeros_like(self._object[0]) - - for tilt_index in np.arange(self._num_tilts): - alpha_deg, beta_deg = self._tilt_angles_deg[tilt_index] - - probe_overlap_3D = self._euler_angle_rotate_volume( - probe_overlap_3D, - alpha_deg, - beta_deg, - ) - - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift( - self._probe, self._positions_px_fractional, xp - ) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts( - probe_intensities - ) - - probe_overlap_3D += probe_overlap[None] - - probe_overlap_3D = self._euler_angle_rotate_volume( - probe_overlap_3D, - alpha_deg, - -beta_deg, - ) - - probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) - self._object_fov_mask = asnumpy( - probe_overlap_3D > 0.25 * probe_overlap_3D.max() - ) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._positions_px = self._positions_px_all[: 
self._cum_probes_per_tilt[1]] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (13, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - - # propagated - propagated_probe = self._probe.copy() - - for s in range(self._num_slices - 1): - propagated_probe = self._propagate_array( - propagated_probe, self._propagator_arrays[s] - ) - complex_propagated_rgb = Complex2RGB( - asnumpy(self._return_centered_probe(propagated_probe)), - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, - chroma_boost=chroma_boost, - ) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe intensity") - - ax2.imshow( - complex_propagated_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax2) - cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax2, - chroma_boost=chroma_boost, - ) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_title("Propagated probe 
intensity") - - ax3.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - ax3.scatter( - self.positions[0, :, 1], - self.positions[0, :, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax3.set_ylabel("x [A]") - ax3.set_xlabel("y [A]") - ax3.set_xlim((extent[0], extent[1])) - ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection( - self, current_object_V, current_object_A_projected, current_probe - ): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - complex_object = xp.exp(1j * (current_object_V + current_object_A_projected)) - object_patches = complex_object[ - :, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - propagated_probes = xp.empty_like(object_patches) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = object_patches[s] * propagated_probes[s] - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. 
- - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - modified_exit_wave = xp.fft.ifft2( - amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves)) - ) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. - Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * 
exit_waves - ) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object_V, - current_object_A_projected, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes:np.ndarray - Prop[object^n*probe^n] - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection( - current_object_V, - current_object_A_projected, - current_probe, - ) - - if use_projection_scheme: - ( - exit_waves[self._active_tilt_index], - error, - ) = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves[self._active_tilt_index], - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - 
amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. - - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object_V: np.ndarray - Updated electrostatic object estimate - updated_object_A_projected: np.ndarray - Updated projected magnetic object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - object_update = step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - - current_object_V[s] += object_update - current_object_A_projected[s] += object_update - - # back-transmit - exit_waves *= xp.conj(obj) - - 
if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object_V, current_object_A_projected, current_probe - - def _projection_sets_adjoint( - self, - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. - - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object_V: np.ndarray - Updated electrostatic object estimate - updated_object_A_projected: np.ndarray - Updated projected magnetic object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = 
self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - object_update = ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - - current_object_V[s] = object_update - current_object_A_projected[s] = object_update - - # back-transmit - exit_waves_copy *= xp.conj(obj) - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization - ) - - return current_object_V, current_object_A_projected, current_probe - - def _adjoint( - self, - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object_V: np.ndarray - Current electrostatic object estimate - current_object_A_projected: np.ndarray - Current projected magnetic object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object_V: np.ndarray - Updated electrostatic object estimate - updated_object_A_projected: np.ndarray - Updated projected magnetic object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - ( - current_object_V, - current_object_A_projected, - current_probe, - ) = self._projection_sets_adjoint( - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves[self._active_tilt_index], - normalization_min, - fix_probe, - ) - else: - ( - current_object_V, - current_object_A_projected, - current_probe, - ) = self._gradient_descent_adjoint( - current_object_V, - current_object_A_projected, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object_V, current_object_A_projected, current_probe - - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes at each layer - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes[-1]) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes[-1].shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - complex_object = xp.exp(1j * current_object) - - # dx - propagated_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - transmitted_probes_perturbed = xp.empty_like(obj_rolled_patches) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed[s] = obj_rolled_patches[s] * propagated_probes - - # propagate - if s + 1 < self._num_slices: - propagated_probes = self._propagate_array( - transmitted_probes_perturbed[s], self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2( - transmitted_probes_perturbed[-1] - ) - - # dy - propagated_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - obj_rolled_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % 
self._object_shape[1], - ] - - transmitted_probes_perturbed = xp.empty_like(obj_rolled_patches) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed[s] = obj_rolled_patches[s] * propagated_probes - - # propagate - if s + 1 < self._num_slices: - propagated_probes = self._propagate_array( - transmitted_probes_perturbed[s], self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2( - transmitted_probes_perturbed[-1] - ) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > 
(xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - - def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): - """ - Ptychographic smoothness constraint. - Used for blurring object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - gaussian_filter = self._gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - current_object = gaussian_filter(current_object, gaussian_filter_sigma) - - return current_object - - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - Butterworth filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1]) - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qra = xp.sqrt(qza**2 + qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * env) - current_object += current_object_mean - - return xp.real(current_object) - - def _divergence_free_constraint(self, vector_field): - """ - Leray projection operator - - Parameters - -------- - vector_field: np.ndarray - Current object vector as Az, Ax, Ay - - Returns - -------- - projected_vector_field: np.ndarray - Divergence-less object vector as Az, Ax, Ay - """ - xp = self._xp - - spacings = (self.sampling[1],) + self.sampling - vector_field = project_vector_field_divergence( - vector_field, spacings=spacings, xp=xp - ) - - return vector_field - - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - iterations: float - Number of iterations to run in denoising algorithm. 
- `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv 
- - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma_e, - gaussian_filter_sigma_m, - butterworth_filter, - q_lowpass_e, - q_lowpass_m, - q_highpass_e, - q_highpass_m, - butterworth_order, - object_positivity, - shrinkage_rad, - object_mask, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - ): - """ - Ptychographic constraints operator. - Calls _threshold_object_constraint() and _probe_center_of_mass_constraint() - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. 
- constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma_e: float - Standard deviation of gaussian kernel for electrostatic object in A - gaussian_filter_sigma_m: float - Standard deviation of gaussian kernel for magnetic object in A - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass_e: float - Cut-off frequency in A^-1 for low-pass filtering electrostatic object - q_lowpass_m: float - Cut-off frequency in A^-1 for low-pass filtering magnetic object - q_highpass_e: float - Cut-off frequency in A^-1 for high-pass filtering electrostatic object - q_highpass_m: float - Cut-off frequency in A^-1 for high-pass filtering magnetic object - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object[0] = self._object_gaussian_constraint( - current_object[0], gaussian_filter_sigma_e - ) - current_object[1] = self._object_gaussian_constraint( - current_object[1], gaussian_filter_sigma_m - ) - current_object[2] = self._object_gaussian_constraint( - current_object[2], gaussian_filter_sigma_m - ) - current_object[3] = self._object_gaussian_constraint( - current_object[3], gaussian_filter_sigma_m - ) - - if butterworth_filter: - current_object[0] = self._object_butterworth_constraint( - current_object[0], - q_lowpass_e, - q_highpass_e, - butterworth_order, - ) - current_object[1] = self._object_butterworth_constraint( - current_object[1], - q_lowpass_m, - q_highpass_m, - butterworth_order, - ) - current_object[2] = self._object_butterworth_constraint( - current_object[2], - q_lowpass_m, - q_highpass_m, - butterworth_order, - ) - current_object[3] = self._object_butterworth_constraint( - current_object[3], - q_lowpass_m, - q_highpass_m, - butterworth_order, - ) - - elif tv_denoise: - current_object[0] = self._object_denoise_tv_pylops( - current_object[0], - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - current_object[1] = self._object_denoise_tv_pylops( - current_object[1], - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - current_object[2] = self._object_denoise_tv_pylops( - current_object[2], - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - current_object[3] = self._object_denoise_tv_pylops( - current_object[3], - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object[0] = self._object_shrinkage_constraint( - current_object[0], - 
shrinkage_rad, - object_mask, - ) - - if object_positivity: - current_object[0] = self._object_positivity_constraint(current_object[0]) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - fix_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - 
constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma_e: float = None, - gaussian_filter_sigma_m: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass_e: float = None, - q_lowpass_m: float = None, - q_highpass_e: float = None, - q_highpass_m: float = None, - butterworth_order: float = 2, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - tv_denoise_iter=np.inf, - tv_denoise_weights=None, - tv_denoise_inner_iter=40, - collective_tilt_updates: bool = False, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. 
- reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. - max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. 
- fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float, optional - Distance to constrain position correction within original - field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma_e: float - Standard deviation of gaussian kernel for electrostatic object in A - gaussian_filter_sigma_m: float - Standard deviation of gaussian kernel for magnetic object in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - object_positivity: bool, optional - If True, forces object to be positive - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - collective_tilt_updates: bool - if True perform collective tilt updates - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: OverlapMagneticTomographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - 
"'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Position Correction + Collective Updates not yet implemented - if fix_positions_iter < max_iter: - raise NotImplementedError( - "Position correction is currently incompatible with collective updates." 
- ) - - # Batching - - if max_batch_size is not None: - xp.random.seed(seed_random) - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = self._object_initial.copy() - self.error_iterations = [] - self._probe = self._probe_initial.copy() - self._positions_px_all = self._positions_px_initial_all.copy() - if hasattr(self, "_tf"): - del self._tf - - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - else: - self.error_iterations = [] - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None - - if gaussian_filter_sigma_m is None: - gaussian_filter_sigma_m = gaussian_filter_sigma_e - - if q_lowpass_m is None: - q_lowpass_m = q_lowpass_e - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if collective_tilt_updates: - collective_object = xp.zeros_like(self._object) - - tilt_indices = np.arange(self._num_tilts) - np.random.shuffle(tilt_indices) - - for tilt_index in tilt_indices: - tilt_error = 0.0 - self._active_tilt_index = tilt_index - - alpha_deg, beta_deg = self._tilt_angles_deg[self._active_tilt_index] - alpha, beta = np.deg2rad([alpha_deg, beta_deg]) - - # V - self._object[0] = self._euler_angle_rotate_volume( - self._object[0], - alpha_deg, - beta_deg, - ) - - # Az - self._object[1] = self._euler_angle_rotate_volume( - self._object[1], - alpha_deg, - beta_deg, - ) - - # Ax - self._object[2] = self._euler_angle_rotate_volume( - self._object[2], - alpha_deg, - beta_deg, - ) - - # Ay - self._object[3] = self._euler_angle_rotate_volume( - 
self._object[3], - alpha_deg, - beta_deg, - ) - - object_A = self._object[1] * np.cos(beta) + np.sin(beta) * ( - self._object[3] * np.cos(alpha) - self._object[2] * np.sin(alpha) - ) - - object_sliced_V = self._project_sliced_object( - self._object[0], self._num_slices - ) - - object_sliced_A = self._project_sliced_object( - object_A, self._num_slices - ) - - if not use_projection_scheme: - object_sliced_old_V = object_sliced_V.copy() - object_sliced_old_A = object_sliced_A.copy() - - start_tilt = self._cum_probes_per_tilt[self._active_tilt_index] - end_tilt = self._cum_probes_per_tilt[self._active_tilt_index + 1] - - num_diffraction_patterns = end_tilt - start_tilt - shuffled_indices = np.arange(num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is None: - current_max_batch_size = num_diffraction_patterns - else: - current_max_batch_size = max_batch_size - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - - unshuffled_indices[shuffled_indices] = np.arange( - num_diffraction_patterns - ) - - positions_px = self._positions_px_all[start_tilt:end_tilt].copy()[ - shuffled_indices - ] - initial_positions_px = self._positions_px_initial_all[ - start_tilt:end_tilt - ].copy()[shuffled_indices] - - for start, end in generate_batches( - num_diffraction_patterns, max_batch=current_max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_initial = initial_positions_px[start:end] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - amplitudes = self._amplitudes[start_tilt:end_tilt][ - shuffled_indices[start:end] - ] - - # forward operator - ( - propagated_probes, - object_patches, - transmitted_probes, - 
self._exit_waves, - batch_error, - ) = self._forward( - object_sliced_V, - object_sliced_A, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - object_sliced_V, object_sliced_A, self._probe = self._adjoint( - object_sliced_V, - object_sliced_A, - self._probe, - object_patches, - propagated_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - object_sliced_V, - self._probe, - transmitted_probes, - amplitudes, - self._positions_px, - positions_step_size, - constrain_position_distance, - ) - - tilt_error += batch_error - - if not use_projection_scheme: - object_sliced_V -= object_sliced_old_V - object_sliced_A -= object_sliced_old_A - - object_update_V = self._expand_sliced_object( - object_sliced_V, self._num_voxels - ) - object_update_A = self._expand_sliced_object( - object_sliced_A, self._num_voxels - ) - - if collective_tilt_updates: - collective_object[0] += self._euler_angle_rotate_volume( - object_update_V, - alpha_deg, - -beta_deg, - ) - collective_object[1] += self._euler_angle_rotate_volume( - object_update_A * np.cos(beta), - alpha_deg, - -beta_deg, - ) - collective_object[2] -= self._euler_angle_rotate_volume( - object_update_A * np.sin(alpha) * np.sin(beta), - alpha_deg, - -beta_deg, - ) - collective_object[3] += self._euler_angle_rotate_volume( - object_update_A * np.cos(alpha) * np.sin(beta), - alpha_deg, - -beta_deg, - ) - else: - self._object[0] += object_update_V - self._object[1] += object_update_A * np.cos(beta) - self._object[2] -= object_update_A * np.sin(alpha) * np.sin(beta) - self._object[3] += object_update_A * np.cos(alpha) * np.sin(beta) - - self._object[0] = self._euler_angle_rotate_volume( - self._object[0], - 
alpha_deg, - -beta_deg, - ) - - self._object[1] = self._euler_angle_rotate_volume( - self._object[1], - alpha_deg, - -beta_deg, - ) - - self._object[2] = self._euler_angle_rotate_volume( - self._object[2], - alpha_deg, - -beta_deg, - ) - - self._object[3] = self._euler_angle_rotate_volume( - self._object[3], - alpha_deg, - -beta_deg, - ) - - # Normalize Error - tilt_error /= ( - self._mean_diffraction_intensity[self._active_tilt_index] - * num_diffraction_patterns - ) - error += tilt_error - - # constraints - self._positions_px_all[start_tilt:end_tilt] = positions_px.copy()[ - unshuffled_indices - ] - - if not collective_tilt_updates: - ( - self._object, - self._probe, - self._positions_px_all[start_tilt:end_tilt], - ) = self._constraints( - self._object, - self._probe, - self._positions_px_all[start_tilt:end_tilt], - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma_m is not 
None, - gaussian_filter_sigma_e=gaussian_filter_sigma_e, - gaussian_filter_sigma_m=gaussian_filter_sigma_m, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass_m is not None or q_highpass_m is not None), - q_lowpass_e=q_lowpass_e, - q_lowpass_m=q_lowpass_m, - q_highpass_e=q_highpass_e, - q_highpass_m=q_highpass_m, - butterworth_order=butterworth_order, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline - and self._object_fov_mask_inverse.sum() > 0 - else None, - tv_denoise=a0 < tv_denoise_iter - and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - ) - - # Normalize Error Over Tilts - error /= self._num_tilts - - self._object[1:] = self._divergence_free_constraint(self._object[1:]) - - if collective_tilt_updates: - self._object += collective_object / self._num_tilts - - ( - self._object, - self._probe, - _, - ) = self._constraints( - self._object, - self._probe, - None, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - 
initial_probe_aperture=self._probe_initial_aperture, - fix_positions=True, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma_m is not None, - gaussian_filter_sigma_e=gaussian_filter_sigma_e, - gaussian_filter_sigma_m=gaussian_filter_sigma_m, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass_m is not None or q_highpass_m is not None), - q_lowpass_e=q_lowpass_e, - q_lowpass_m=q_lowpass_m, - q_highpass_e=q_highpass_e, - q_highpass_m=q_highpass_m, - butterworth_order=butterworth_order, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline - and self._object_fov_mask_inverse.sum() > 0 - else None, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _crop_rotate_object_manually( - self, - array, - angle, - x_lims, - y_lims, - ): - """ - Crops and rotates rotates object manually. - - Parameters - ---------- - array: np.ndarray - Object array to crop and rotate. Only operates on numpy arrays for comptatibility. 
- angle: float - In-plane angle in degrees to rotate by - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - - Returns - ------- - cropped_rotated_array: np.ndarray - Cropped and rotated object array - """ - - asnumpy = self._asnumpy - min_x, max_x = x_lims - min_y, max_y = y_lims - - if angle is not None: - rotated_array = rotate_np( - asnumpy(array), angle, reshape=False, axes=(-2, -1) - ) - else: - rotated_array = asnumpy(array) - - return rotated_array[..., min_x:max_x, min_y:max_y] - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. - - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - cmap = kwargs.pop("cmap", "magma") - - asnumpy = self._asnumpy - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - self._object[0], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = self.object[0] - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = 
object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - figsize = kwargs.pop("figsize", (14, 10) if cbar else (12, 10)) - cmap_e = kwargs.pop("cmap_e", "magma") - cmap_m = kwargs.pop("cmap_m", "PuOr") - - asnumpy = self._asnumpy - - if projection_angle_deg is not None: - rotated_3d_obj_V = self._rotate( - self._object[0], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_Az = self._rotate( - self._object[1], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_Ax = self._rotate( - self._object[2], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_Ay = self._rotate( - 
self._object[3], - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - - rotated_3d_obj_V = asnumpy(rotated_3d_obj_V) - rotated_3d_obj_Az = asnumpy(rotated_3d_obj_Az) - rotated_3d_obj_Ax = asnumpy(rotated_3d_obj_Ax) - rotated_3d_obj_Ay = asnumpy(rotated_3d_obj_Ay) - else: - ( - rotated_3d_obj_V, - rotated_3d_obj_Az, - rotated_3d_obj_Ax, - rotated_3d_obj_Ay, - ) = self.object - - rotated_object_Vx = self._crop_rotate_object_manually( - rotated_3d_obj_V.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Vy = self._crop_rotate_object_manually( - rotated_3d_obj_V.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Vz = self._crop_rotate_object_manually( - rotated_3d_obj_V.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - rotated_object_Azx = self._crop_rotate_object_manually( - rotated_3d_obj_Az.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Azy = self._crop_rotate_object_manually( - rotated_3d_obj_Az.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Azz = self._crop_rotate_object_manually( - rotated_3d_obj_Az.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - rotated_object_Axx = self._crop_rotate_object_manually( - rotated_3d_obj_Ax.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Axy = self._crop_rotate_object_manually( - rotated_3d_obj_Ax.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Axz = self._crop_rotate_object_manually( - rotated_3d_obj_Ax.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - rotated_object_Ayx = self._crop_rotate_object_manually( - rotated_3d_obj_Ay.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Ayy = self._crop_rotate_object_manually( - rotated_3d_obj_Ay.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_object_Ayz = self._crop_rotate_object_manually( - rotated_3d_obj_Ay.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - 
rotated_shape = rotated_object_Vx.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - arrays = [ - [ - rotated_object_Vx, - rotated_object_Axx, - rotated_object_Ayx, - rotated_object_Azx, - ], - [ - rotated_object_Vy, - rotated_object_Axy, - rotated_object_Ayy, - rotated_object_Azy, - ], - [ - rotated_object_Vz, - rotated_object_Axz, - rotated_object_Ayz, - rotated_object_Azz, - ], - ] - - titles = [ - [ - "V projected along x", - "Ax projected along x", - "Ay projected along x", - "Az projected along x", - ], - [ - "V projected along y", - "Ax projected along y", - "Ay projected along y", - "Az projected along y", - ], - [ - "V projected along z", - "Ax projected along z", - "Ay projected along z", - "Az projected along z", - ], - ] - - max_e = np.array( - [rotated_object_Vx.max(), rotated_object_Vy.max(), rotated_object_Vz.max()] - ).max() - max_m = np.array( - [ - [ - np.abs(rotated_object_Axx).max(), - np.abs(rotated_object_Ayx).max(), - np.abs(rotated_object_Azx).max(), - ], - [ - np.abs(rotated_object_Axy).max(), - np.abs(rotated_object_Ayy).max(), - np.abs(rotated_object_Azy).max(), - ], - [ - np.abs(rotated_object_Axz).max(), - np.abs(rotated_object_Ayz).max(), - np.abs(rotated_object_Azz).max(), - ], - ] - ).max() - - vmin_e = kwargs.pop("vmin_e", 0.0) - vmax_e = kwargs.pop("vmax_e", max_e) - vmin_m = kwargs.pop("vmin_m", -max_m) - vmax_m = kwargs.pop("vmax_m", max_m) - - if plot_convergence: - spec = GridSpec( - ncols=4, nrows=4, height_ratios=[4, 4, 4, 1], hspace=0.15, wspace=0.35 - ) - else: - spec = GridSpec(ncols=4, nrows=3, hspace=0.15, wspace=0.35) - - if fig is None: - fig = plt.figure(figsize=figsize) - - for sp in spec: - row, col = np.unravel_index(sp.num1, (4, 4)) - - if row < 3: - ax = fig.add_subplot(sp) - if sp.is_first_col(): - cmap = cmap_e - vmin = vmin_e - vmax = vmax_e - else: - cmap = cmap_m - vmin = vmin_m - vmax = vmax_m - - im = ax.imshow( - arrays[row][col], - 
cmap=cmap, - vmin=vmin, - vmax=vmax, - extent=extent, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - ax.set_title(titles[row][col]) - - if row < 2: - ax.set_xticks([]) - else: - ax.set_xlabel("y [A]") - - if col > 0: - ax.set_yticks([]) - else: - ax.set_ylabel("x [A]") - - if plot_convergence and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - - ax = fig.add_subplot(spec[-1, :]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - plot_convergence: bool, - iterations_grid: Tuple[int, int], - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - """ - raise NotImplementedError() - - def visualize( - self, - fig=None, - cbar: bool = True, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims=(None, None), - y_lims=(None, None), - **kwargs, - ): - """ - Displays reconstructed object and probe. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - - Returns - -------- - self: OverlapMagneticTomographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - cbar=cbar, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - cbar=cbar, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - - return self - - def _return_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - xp = self._xp - asnumpy = self._asnumpy - - if obj is None: - obj = self._object[0] - else: - obj = xp.asarray(obj[0], 
dtype=xp.float32) - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - obj, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = asnumpy(obj) - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(rotated_object)))) - - def show_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - **kwargs, - ): - """ - Plot FFT of reconstructed object - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - if obj is None: - object_fft = self._return_object_fft( - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - else: - object_fft = self._return_object_fft( - obj, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - show( - object_fft, - figsize=figsize, - cmap=cmap, - scalebar=True, - pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", - **kwargs, - ) - - @property - def positions(self): - """Probe positions [A]""" - - if self.angular_sampling is None: - return None - - asnumpy = self._asnumpy - positions_all = [] - for tilt_index in range(self._num_tilts): - positions = self._positions_px_all[ - 
self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ].copy() - positions[:, 0] *= self.sampling[0] - positions[:, 1] *= self.sampling[1] - positions_all.append(asnumpy(positions)) - - return np.asarray(positions_all) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - raise NotImplementedError() - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - raise NotImplementedError() - - def show_uncertainty_visualization( - self, - errors=None, - max_batch_size=None, - projected_cropped_potential=None, - kde_sigma=None, - plot_histogram=True, - plot_contours=False, - **kwargs, - ): - """Plot uncertainty visualization using self-consistency errors""" - raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py deleted file mode 100644 index 749028b83..000000000 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ /dev/null @@ -1,3286 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely overlap tomography. 
-""" - -import warnings -from typing import Mapping, Sequence, Tuple - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize import show -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg -from scipy.ndimage import rotate as rotate_np - -try: - import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = np - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception - -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, - spatial_frequencies, -) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar - -warnings.simplefilter(action="always", category=UserWarning) - - -class OverlapTomographicReconstruction(PtychographicReconstruction): - """ - Overlap Tomographic Reconstruction Class. - - List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (Sx,Sy) - Reconstructed object dimensions : (Px,Py,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our probe - and (Px,Py,Py) is the padded-object electrostatic potential volume, - where x-axis is the tilt. 
- - Parameters - ---------- - datacube: List of DataCubes - Input list of 4D diffraction pattern intensities - energy: float - The electron energy of the wave functions in eV - num_slices: int - Number of slices to use in the forward model - tilt_orientation_matrices: Sequence[np.ndarray] - List of orientation matrices for each tilt - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py,Py) - If None, initialized to 1.0 - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: list of np.ndarray, optional - Probe positions in Å for each diffraction intensity per tilt - If None, initialized to a grid scan centered along tilt axis - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. 
Must be 'cpu' or 'gpu' - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - positions_mask: np.ndarray, optional - Boolean real space mask to select positions to ignore in reconstruction - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. - """ - - # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") - _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) - - def __init__( - self, - energy: float, - num_slices: int, - tilt_orientation_matrices: Sequence[np.ndarray], - datacube: Sequence[DataCube] = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - object_type: str = "potential", - positions_mask: np.ndarray = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: Sequence[np.ndarray] = None, - verbose: bool = True, - device: str = "cpu", - name: str = "overlap-tomographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - self._affine_transform = affine_transform - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import ( - affine_transform, - gaussian_filter, - rotate, - zoom, - ) - - self._gaussian_filter = gaussian_filter - self._zoom = zoom - self._rotate = rotate - self._affine_transform = affine_transform - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise 
ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - num_tilts = len(tilt_orientation_matrices) - if initial_scan_positions is None: - initial_scan_positions = [None] * num_tilts - - if object_type != "potential": - raise NotImplementedError() - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._object_padding_px = object_padding_px - self._positions_mask = positions_mask - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._num_slices = num_slices - self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) - self._num_tilts = num_tilts - - def _precompute_propagator_arrays( - self, - gpts: Tuple[int, int], - sampling: Tuple[float, float], - energy: float, - slice_thicknesses: Sequence[float], - ): - """ - Precomputes propagator arrays complex wave-function will be convolved by, - for all slice thicknesses. 
- - Parameters - ---------- - gpts: Tuple[int,int] - Wavefunction pixel dimensions - sampling: Tuple[float,float] - Wavefunction sampling in A - energy: float - The electron energy of the wave functions in eV - slice_thicknesses: Sequence[float] - Array of slice thicknesses in A - - Returns - ------- - propagator_arrays: np.ndarray - (T,Sx,Sy) shape array storing propagator arrays - """ - xp = self._xp - - # Frequencies - kx, ky = spatial_frequencies(gpts, sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) - - # Propagators - wavelength = electron_wavelength_angstrom(energy) - num_slices = slice_thicknesses.shape[0] - propagators = xp.empty( - (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 - ) - for i, dz in enumerate(slice_thicknesses): - propagators[i] = xp.exp( - 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) - ) - propagators[i] *= xp.exp( - 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) - ) - - return propagators - - def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): - """ - Propagates array by Fourier convolving array with propagator_array. - - Parameters - ---------- - array: np.ndarray - Wavefunction array to be convolved - propagator_array: np.ndarray - Propagator array to convolve array with - - Returns - ------- - propagated_array: np.ndarray - Fourier-convolved array - """ - xp = self._xp - - return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - - def _project_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. 
- If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(input_z / output_z).astype("int") - pad_size = voxels_per_slice * output_z - input_z - - padded_array = xp.pad(array, ((0, pad_size), (0, 0), (0, 0))) - - return xp.sum( - padded_array.reshape( - ( - -1, - voxels_per_slice, - ) - + array.shape[1:] - ), - axis=1, - ) - - def _expand_sliced_object(self, array: np.ndarray, output_z): - """ - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. - If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - xp = self._xp - input_z = array.shape[0] - - voxels_per_slice = np.ceil(output_z / input_z).astype("int") - remainder_size = voxels_per_slice - (voxels_per_slice * input_z - output_z) - - voxels_in_slice = xp.repeat(voxels_per_slice, input_z) - voxels_in_slice[-1] = remainder_size if remainder_size > 0 else voxels_per_slice - - normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] - return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] - - def _rotate_zxy_volume( - self, - volume_array, - rot_matrix, - ): - """ """ - - xp = self._xp - affine_transform = self._affine_transform - swap_zxy_to_xyz = self._swap_zxy_to_xyz - - volume = volume_array.copy() - volume_shape = xp.asarray(volume.shape) - tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) - - in_center = (volume_shape - 1) / 2 - out_center = tf @ in_center - offset = in_center - out_center - - volume = affine_transform(volume, tf, offset=offset, order=3) - - return volume - - def preprocess( - self, - 
diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_probe_overlaps: bool = True, - rotation_real_space_degrees: float = None, - diffraction_patterns_rotate_degrees: float = None, - diffraction_patterns_transpose: bool = None, - force_com_shifts: Sequence[float] = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - progress_bar: bool = True, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. - - Additionally, it initializes an (Px,Py, Py) array of 1.0 - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. - If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - rotation_real_space_degrees: float (degrees), optional - In plane rotation around z axis between x axis and tilt axis in - real space (forced to be in xy plane) - diffraction_patterns_rotate_degrees: float, optional - Relative rotation angle between real and reciprocal space - diffraction_patterns_transpose: bool, optional - Whether diffraction intensities need to be transposed. 
- force_com_shifts: list of tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. One tuple per tilt. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: OverlapTomographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." 
- ) - ) - - if self._positions_mask is not None: - self._positions_mask = np.asarray(self._positions_mask) - - if self._positions_mask.ndim == 2: - warnings.warn( - "2D `positions_mask` assumed the same for all measurements.", - UserWarning, - ) - self._positions_mask = np.tile( - self._positions_mask, (self._num_tilts, 1, 1) - ) - - if self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array."), - UserWarning, - ) - self._positions_mask = self._positions_mask.astype("bool") - else: - self._positions_mask = [None] * self._num_tilts - - # Prepopulate various arrays - - if self._positions_mask[0] is None: - num_probes_per_tilt = [0] - for dc in self._datacube: - rx, ry = dc.Rshape - num_probes_per_tilt.append(rx * ry) - - num_probes_per_tilt = np.array(num_probes_per_tilt) - else: - num_probes_per_tilt = np.insert( - self._positions_mask.sum(axis=(-2, -1)), 0, 0 - ) - - self._num_diffraction_patterns = num_probes_per_tilt.sum() - self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) - - self._mean_diffraction_intensity = [] - self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) - - self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) - self._rotation_best_transpose = diffraction_patterns_transpose - - if force_com_shifts is None: - force_com_shifts = [None] * self._num_tilts - - for tilt_index in tqdmnd( - self._num_tilts, - desc="Preprocessing data", - unit="tilt", - disable=not progress_bar, - ): - if tilt_index == 0: - ( - self._datacube[tilt_index], - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts[tilt_index], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - 
com_shifts=force_com_shifts[tilt_index], - ) - - self._amplitudes = xp.empty( - (self._num_diffraction_patterns,) + self._datacube[0].Qshape - ) - self._region_of_interest_shape = np.array( - self._amplitudes[0].shape[-2:] - ) - - else: - ( - self._datacube[tilt_index], - _, - _, - force_com_shifts[tilt_index], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[tilt_index], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=None, - dp_mask=None, - com_shifts=force_com_shifts[tilt_index], - ) - - intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube[tilt_index], - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x, - com_measured_y, - com_fitted_x, - com_fitted_y, - com_normalized_x, - com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[tilt_index], - ) - - ( - self._amplitudes[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ], - mean_diffraction_intensity_temp, - ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, - crop_patterns, - self._positions_mask[tilt_index], - ) - - self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) - - del ( - intensities, - com_measured_x, - com_measured_y, - com_fitted_x, - com_fitted_y, - com_normalized_x, - com_normalized_y, - ) - - self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index], self._positions_mask[tilt_index] - ) - - # handle semiangle specified 
in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px_all, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - self._object = xp.zeros((q, p, q), dtype=xp.float32) - else: - self._object = xp.asarray(self._object, dtype=xp.float32) - - self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape[-2:] - self._num_voxels = self._object.shape[0] - - # Center Probes - self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32) - - for tilt_index in range(self._num_tilts): - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= ( - self._positions_px_com - xp.array(self._object_shape) / 2 - ) - - self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] = self._positions_px.copy() - - self._positions_px_initial_all = self._positions_px_all.copy() - self._positions_initial_all = self._positions_px_initial_all.copy() - self._positions_initial_all[:, 0] *= self.sampling[0] - self._positions_initial_all[:, 1] *= self.sampling[1] - - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, device=self._device - ) - 
self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity - ) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - sum(self._mean_diffraction_intensity) - / self._num_tilts - / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # Precomputed propagator arrays - self._slice_thicknesses = np.tile( - self._object_shape[1] * self.sampling[1] / self._num_slices, - self._num_slices - 1, - ) - self._propagator_arrays = 
self._precompute_propagator_arrays( - self._region_of_interest_shape, - self.sampling, - self._energy, - self._slice_thicknesses, - ) - - # overlaps - if object_fov_mask is None: - probe_overlap_3D = xp.zeros_like(self._object) - old_rot_matrix = np.eye(3) # identity - - for tilt_index in np.arange(self._num_tilts): - rot_matrix = self._tilt_orientation_matrices[tilt_index] - - probe_overlap_3D = self._rotate_zxy_volume( - probe_overlap_3D, - rot_matrix @ old_rot_matrix.T, - ) - - self._positions_px = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift( - self._probe, self._positions_px_fractional, xp - ) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts( - probe_intensities - ) - - probe_overlap_3D += probe_overlap[None] - old_rot_matrix = rot_matrix - - probe_overlap_3D = self._rotate_zxy_volume( - probe_overlap_3D, - old_rot_matrix.T, - ) - - probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) - self._object_fov_mask = asnumpy( - probe_overlap_3D > 0.25 * probe_overlap_3D.max() - ) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._positions_px = self._positions_px_all[: self._cum_probes_per_tilt[1]] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (13, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - 
self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - - # propagated - propagated_probe = self._probe.copy() - - for s in range(self._num_slices - 1): - propagated_probe = self._propagate_array( - propagated_probe, self._propagator_arrays[s] - ) - complex_propagated_rgb = Complex2RGB( - asnumpy(self._return_centered_probe(propagated_probe)), - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, - chroma_boost=chroma_boost, - ) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe intensity") - - ax2.imshow( - complex_propagated_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax2) - cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax2, - chroma_boost=chroma_boost, - ) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_title("Propagated probe intensity") - - ax3.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - ax3.scatter( - self.positions[0, :, 1], - self.positions[0, :, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax3.set_ylabel("x [A]") - ax3.set_xlabel("y [A]") - ax3.set_xlim((extent[0], extent[1])) - ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection 
method. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - propagated_probes: np.ndarray - Shifted probes at each layer - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - """ - - xp = self._xp - - complex_object = xp.exp(1j * current_object) - object_patches = complex_object[ - :, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - propagated_probes = xp.empty_like(object_patches) - propagated_probes[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes = object_patches[s] * propagated_probes[s] - - # propagate - if s + 1 < self._num_slices: - propagated_probes[s + 1] = self._propagate_array( - transmitted_probes, self._propagator_arrays[s] - ) - - return propagated_probes, object_patches, transmitted_probes - - def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): - """ - Ptychographic fourier projection method for GD method. 
- - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - modified_exit_wave = xp.fft.ifft2( - amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves)) - ) - - exit_waves = modified_exit_wave - transmitted_probes - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, - amplitudes, - transmitted_probes, - exit_waves, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. - Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit wave difference - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = transmitted_probes.copy() - - fourier_exit_waves = xp.fft.fft2(transmitted_probes) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2) - - factor_to_be_projected = ( - projection_c * transmitted_probes + projection_y * 
exit_waves - ) - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * transmitted_probes - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - propagated_probes:np.ndarray - Prop[object^n*probe^n] - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - ( - propagated_probes, - object_patches, - transmitted_probes, - ) = self._overlap_projection(current_object, current_probe) - - if use_projection_scheme: - ( - exit_waves[self._active_tilt_index], - error, - ) = self._projection_sets_fourier_projection( - amplitudes, - transmitted_probes, - exit_waves[self._active_tilt_index], - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, transmitted_probes - ) - - return propagated_probes, object_patches, transmitted_probes, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - 
current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves) - ) - * probe_normalization - ) - - # back-transmit - exit_waves *= xp.conj(obj) - - if s > 0: - # back-propagate - exit_waves = self._propagate_array( - exit_waves, xp.conj(self._propagator_arrays[s - 1]) - ) - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += ( - step_size - * xp.sum( - exit_waves, - axis=0, - ) - * object_normalization - 
) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - propagated_probes: np.ndarray - Shifted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - # careful not to modify exit_waves in-place for projection set methods - exit_waves_copy = exit_waves.copy() - for s in reversed(range(self._num_slices)): - probe = propagated_probes[s] - obj = object_patches[s] - - # object-update - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(probe) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - current_object[s] = ( - self._sum_overlapping_patches_bincounts( - xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy) - ) - * probe_normalization - ) - - # back-transmit - exit_waves_copy *= xp.conj(obj) - - if s > 0: - # back-propagate - exit_waves_copy = self._propagate_array( - exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) - ) - - elif not fix_probe: - # probe-update - object_normalization = xp.sum( - (xp.abs(obj) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * 
object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - exit_waves_copy, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - transmitted_probes: np.ndarray - Transmitted probes at each layer - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves[self._active_tilt_index], - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - propagated_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _position_correction( - self, - current_object, - current_probe, - transmitted_probes, - amplitudes, - current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using 
estimated intensity gradient. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe:np.ndarray - fractionally-shifted probes - transmitted_probes: np.ndarray - Transmitted probes after N-1 propagations and N transmissions - amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - # Intensity gradient - exit_waves_fft = xp.fft.fft2(transmitted_probes) - exit_waves_fft_conj = xp.conj(exit_waves_fft) - estimated_intensity = xp.abs(exit_waves_fft) ** 2 - measured_intensity = amplitudes**2 - - flat_shape = (transmitted_probes.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - # Computing perturbed exit waves one at a time to save on memory - - complex_object = xp.exp(1j * current_object) - - # dx - obj_rolled_patches = complex_object[ - :, - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - # dy - obj_rolled_patches = complex_object[ - :, - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % 
self._object_shape[1], - ] - - propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) - propagated_probes_perturbed[0] = fft_shift( - current_probe, self._positions_px_fractional, xp - ) - - for s in range(self._num_slices): - # transmit - transmitted_probes_perturbed = ( - obj_rolled_patches[s] * propagated_probes_perturbed[s] - ) - - # propagate - if s + 1 < self._num_slices: - propagated_probes_perturbed[s + 1] = self._propagate_array( - transmitted_probes_perturbed, self._propagator_arrays[s] - ) - - exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * exit_waves_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * exit_waves_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * 
xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - current_positions -= positions_step_size * positions_update[..., 0] - - return current_positions - - def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): - """ - Ptychographic smoothness constraint. - Used for blurring object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - gaussian_filter = self._gaussian_filter - - gaussian_filter_sigma /= self.sampling[0] - current_object = gaussian_filter(current_object, gaussian_filter_sigma) - - return current_object - - def _object_butterworth_constraint( - self, current_object, q_lowpass, q_highpass, butterworth_order - ): - """ - Butterworth filter - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1]) - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) - qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") - qra = xp.sqrt(qza**2 + qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * env) - current_object += current_object_mean - return xp.real(current_object) - - def _object_denoise_tv_pylops(self, current_object, weights, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weights : [float, float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - iterations: float - Number of iterations to run in denoising algorithm. 
- `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - # zero pad at top and bottom slice - pad_width = ((1, 1), (0, 0), (0, 0)) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - # run tv denoising - nz, nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny * nz) - - if weights[0] == 0: - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[1]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - elif weights[1] == 0: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - l1_regs = [z_gradient] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weights[0]], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - else: - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] - - # remove padding - current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - - return current_object_tv 
- - def _constraints( - self, - current_object, - current_probe, - current_positions, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - object_positivity, - shrinkage_rad, - object_mask, - tv_denoise, - tv_denoise_weights, - tv_denoise_inner_iter, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. 
- constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies fourier-space butterworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - object_positivity: bool - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weights, - tv_denoise_inner_iter, - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = 
self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - fix_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass: float = None, - q_highpass: float = None, - butterworth_order: float = 2, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - tv_denoise_iter=np.inf, - tv_denoise_weights=None, - tv_denoise_inner_iter=40, - collective_tilt_updates: bool = False, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. 
- - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. - max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. 
- constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float, optional - Distance to constrain position correction within original - field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma: float, optional - Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - object_positivity: bool, optional - If True, forces object to be positive - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weights: [float,float] - Denoising weights[z weight, r weight]. The greater `weight`, - the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - collective_tilt_updates: bool - if True perform collective tilt updates - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: OverlapTomographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - 
"'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Position Correction + Collective Updates not yet implemented - if fix_positions_iter < max_iter: - raise NotImplementedError( - "Position correction is currently incompatible with collective updates." 
- ) - - # Batching - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = self._object_initial.copy() - self.error_iterations = [] - self._probe = self._probe_initial.copy() - self._positions_px_all = self._positions_px_initial_all.copy() - if hasattr(self, "_tf"): - del self._tf - - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - else: - self.error_iterations = [] - if use_projection_scheme: - self._exit_waves = [None] * self._num_tilts - else: - self._exit_waves = None - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if collective_tilt_updates: - collective_object = xp.zeros_like(self._object) - - tilt_indices = np.arange(self._num_tilts) - np.random.shuffle(tilt_indices) - - old_rot_matrix = np.eye(3) # identity - - for tilt_index in tilt_indices: - self._active_tilt_index = tilt_index - - tilt_error = 0.0 - - rot_matrix = self._tilt_orientation_matrices[self._active_tilt_index] - self._object = self._rotate_zxy_volume( - self._object, - rot_matrix @ old_rot_matrix.T, - ) - - object_sliced = self._project_sliced_object( - self._object, self._num_slices - ) - if not use_projection_scheme: - object_sliced_old = object_sliced.copy() - - start_tilt = self._cum_probes_per_tilt[self._active_tilt_index] - end_tilt = self._cum_probes_per_tilt[self._active_tilt_index + 1] - - num_diffraction_patterns = end_tilt - start_tilt - shuffled_indices = 
np.arange(num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - - unshuffled_indices[shuffled_indices] = np.arange( - num_diffraction_patterns - ) - - positions_px = self._positions_px_all[start_tilt:end_tilt].copy()[ - shuffled_indices - ] - initial_positions_px = self._positions_px_initial_all[ - start_tilt:end_tilt - ].copy()[shuffled_indices] - - for start, end in generate_batches( - num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_initial = initial_positions_px[start:end] - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - amplitudes = self._amplitudes[start_tilt:end_tilt][ - shuffled_indices[start:end] - ] - - # forward operator - ( - propagated_probes, - object_patches, - transmitted_probes, - self._exit_waves, - batch_error, - ) = self._forward( - object_sliced, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - object_sliced, self._probe = self._adjoint( - object_sliced, - self._probe, - object_patches, - propagated_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - object_sliced, - self._probe, - transmitted_probes, - amplitudes, - self._positions_px, - positions_step_size, - constrain_position_distance, - ) - - tilt_error += batch_error - - if not use_projection_scheme: - object_sliced -= 
object_sliced_old - - object_update = self._expand_sliced_object( - object_sliced, self._num_voxels - ) - - if collective_tilt_updates: - collective_object += self._rotate_zxy_volume( - object_update, rot_matrix.T - ) - else: - self._object += object_update - - old_rot_matrix = rot_matrix - - # Normalize Error - tilt_error /= ( - self._mean_diffraction_intensity[self._active_tilt_index] - * num_diffraction_patterns - ) - error += tilt_error - - # constraints - self._positions_px_all[start_tilt:end_tilt] = positions_px.copy()[ - unshuffled_indices - ] - - if not collective_tilt_updates: - ( - self._object, - self._probe, - self._positions_px_all[start_tilt:end_tilt], - ) = self._constraints( - self._object, - self._probe, - self._positions_px_all[start_tilt:end_tilt], - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - 
butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline - and self._object_fov_mask_inverse.sum() > 0 - else None, - tv_denoise=a0 < tv_denoise_iter - and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - ) - - self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T) - - # Normalize Error Over Tilts - error /= self._num_tilts - - if collective_tilt_updates: - self._object += collective_object / self._num_tilts - - ( - self._object, - self._probe, - _, - ) = self._constraints( - self._object, - self._probe, - None, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=True, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - 
and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline - and self._object_fov_mask_inverse.sum() > 0 - else None, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, - tv_denoise_weights=tv_denoise_weights, - tv_denoise_inner_iter=tv_denoise_inner_iter, - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _crop_rotate_object_manually( - self, - array, - angle, - x_lims, - y_lims, - ): - """ - Crops and rotates rotates object manually. - - Parameters - ---------- - array: np.ndarray - Object array to crop and rotate. Only operates on numpy arrays for comptatibility. 
- angle: float - In-plane angle in degrees to rotate by - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - - Returns - ------- - cropped_rotated_array: np.ndarray - Cropped and rotated object array - """ - - asnumpy = self._asnumpy - min_x, max_x = x_lims - min_y, max_y = y_lims - - if angle is not None: - rotated_array = rotate_np( - asnumpy(array), angle, reshape=False, axes=(-2, -1) - ) - else: - rotated_array = asnumpy(array) - - return rotated_array[..., min_x:max_x, min_y:max_y] - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax, - cbar: bool, - projection_angle_deg: float, - projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. - - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - cmap = kwargs.pop("cmap", "magma") - - asnumpy = self._asnumpy - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - self._object, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = self.object - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = 
    def _visualize_last_iteration(
        self,
        fig,
        cbar: bool,
        plot_convergence: bool,
        plot_probe: bool,
        plot_fourier_probe: bool,
        remove_initial_probe_aberrations: bool,
        projection_angle_deg: float,
        projection_axes: Tuple[int, int],
        x_lims: Tuple[int, int],
        y_lims: Tuple[int, int],
        **kwargs,
    ):
        """
        Displays last reconstructed object and probe iterations.

        Parameters
        --------
        fig: Figure
            Matplotlib figure to place Gridspec in
        plot_convergence: bool, optional
            If true, the normalized mean squared error (NMSE) plot is displayed
        cbar: bool, optional
            If true, displays a colorbar
        plot_probe: bool
            If true, the reconstructed probe intensity is also displayed
        plot_fourier_probe: bool, optional
            If true, the reconstructed complex Fourier probe is displayed
        remove_initial_probe_aberrations: bool, optional
            If true, when plotting fourier probe, removes initial probe
        projection_angle_deg: float
            Angle in degrees to rotate 3D array around prior to projection
        projection_axes: tuple(int,int)
            Axes defining projection plane
        x_lims: tuple(float,float)
            min/max x indices
        y_lims: tuple(float,float)
            min/max y indices
        """
        asnumpy = self._asnumpy

        figsize = kwargs.pop("figsize", (8, 5))
        cmap = kwargs.pop("cmap", "magma")

        chroma_boost = kwargs.pop("chroma_boost", 1)

        # NOTE(review): duplicate of the assignment above — harmless but redundant
        asnumpy = self._asnumpy

        # optionally rotate the 3D volume before projecting along axis 0
        if projection_angle_deg is not None:
            rotated_3d_obj = self._rotate(
                self._object,
                projection_angle_deg,
                axes=projection_axes,
                reshape=False,
                order=2,
            )
            rotated_3d_obj = asnumpy(rotated_3d_obj)
        else:
            rotated_3d_obj = self.object

        # project (sum along z) and crop to the requested field of view
        rotated_object = self._crop_rotate_object_manually(
            rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims
        )
        rotated_shape = rotated_object.shape

        # physical extent in Angstroms (y increases downward in imshow)
        extent = [
            0,
            self.sampling[1] * rotated_shape[1],
            self.sampling[0] * rotated_shape[0],
            0,
        ]

        # probe panel extent: mrad for Fourier-space, Angstroms for real-space
        if plot_fourier_probe:
            probe_extent = [
                -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
                self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
                self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
                -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
            ]
        elif plot_probe:
            probe_extent = [
                0,
                self.sampling[1] * self._region_of_interest_shape[1],
                self.sampling[0] * self._region_of_interest_shape[0],
                0,
            ]

        # grid layout: optional second column for the probe, optional bottom
        # row for the convergence plot; width ratios keep pixels square
        if plot_convergence:
            if plot_probe or plot_fourier_probe:
                spec = GridSpec(
                    ncols=2,
                    nrows=2,
                    height_ratios=[4, 1],
                    hspace=0.15,
                    width_ratios=[
                        (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
                        1,
                    ],
                    wspace=0.35,
                )
            else:
                spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15)
        else:
            if plot_probe or plot_fourier_probe:
                spec = GridSpec(
                    ncols=2,
                    nrows=1,
                    width_ratios=[
                        (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
                        1,
                    ],
                    wspace=0.35,
                )
            else:
                spec = GridSpec(ncols=1, nrows=1)

        if fig is None:
            fig = plt.figure(figsize=figsize)

        if plot_probe or plot_fourier_probe:
            # Object
            ax = fig.add_subplot(spec[0, 0])
            im = ax.imshow(
                rotated_object,
                extent=extent,
                cmap=cmap,
                **kwargs,
            )

            ax.set_ylabel("x [A]")
            ax.set_xlabel("y [A]")
            ax.set_title("Reconstructed object projection")

            if cbar:
                divider = make_axes_locatable(ax)
                ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
                fig.add_axes(ax_cb)
                fig.colorbar(im, cax=ax_cb)

            # Probe
            # vmin/vmax only apply to the scalar object image, not the RGB probe
            kwargs.pop("vmin", None)
            kwargs.pop("vmax", None)

            ax = fig.add_subplot(spec[0, 1])
            if plot_fourier_probe:
                if remove_initial_probe_aberrations:
                    probe_array = self.probe_fourier_residual
                else:
                    probe_array = self.probe_fourier

                probe_array = Complex2RGB(
                    probe_array,
                    chroma_boost=chroma_boost,
                )

                ax.set_title("Reconstructed Fourier probe")
                ax.set_ylabel("kx [mrad]")
                ax.set_xlabel("ky [mrad]")
            else:
                # power=2 displays probe intensity rather than amplitude
                probe_array = Complex2RGB(
                    self.probe,
                    power=2,
                    chroma_boost=chroma_boost,
                )
                ax.set_title("Reconstructed probe intensity")
                ax.set_ylabel("x [A]")
                ax.set_xlabel("y [A]")

            im = ax.imshow(
                probe_array,
                extent=probe_extent,
                **kwargs,
            )

            if cbar:
                divider = make_axes_locatable(ax)
                ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
                add_colorbar_arg(
                    ax_cb,
                    chroma_boost=chroma_boost,
                )
        else:
            # object-only layout
            ax = fig.add_subplot(spec[0])
            im = ax.imshow(
                rotated_object,
                extent=extent,
                cmap=cmap,
                **kwargs,
            )
            ax.set_ylabel("x [A]")
            ax.set_xlabel("y [A]")
            ax.set_title("Reconstructed object projection")

            if cbar:
                divider = make_axes_locatable(ax)
                ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
                fig.add_axes(ax_cb)
                fig.colorbar(im, cax=ax_cb)

        if plot_convergence and hasattr(self, "error_iterations"):
            kwargs.pop("vmin", None)
            kwargs.pop("vmax", None)
            errors = np.array(self.error_iterations)
            # NOTE(review): when plot_fourier_probe is True but plot_probe is
            # False, spec has two columns yet spec[1] is used here — this looks
            # like it should be `plot_probe or plot_fourier_probe`; confirm.
            if plot_probe:
                ax = fig.add_subplot(spec[1, :])
            else:
                ax = fig.add_subplot(spec[1])
            ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
            ax.set_ylabel("NMSE")
            ax.set_xlabel("Iteration number")
            ax.yaxis.tick_right()

        fig.suptitle(f"Normalized mean squared error: {self.error:.3e}")
        spec.tight_layout(fig)
projection_axes: Tuple[int, int], - x_lims: Tuple[int, int], - y_lims: Tuple[int, int], - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." 
- ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - if projection_angle_deg is not None: - objects = [ - self._crop_rotate_object_manually( - rotate_np( - obj, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ).sum(0), - angle=None, - x_lims=x_lims, - y_lims=y_lims, - ) - for obj in self.object_iterations - ] - else: - objects = [ - self._crop_rotate_object_manually( - obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - for obj in self.object_iterations - ] - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - 
self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) if plot_probe else iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} Object") - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - - ax.set_title(f"Iter: {grid_range[n]} Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - 
probe_array = Complex2RGB( - probes[grid_range[n]], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims=(None, None), - y_lims=(None, None), - **kwargs, - ): - """ - Displays reconstructed object and probe. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - - Returns - -------- - self: OverlapTomographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - **kwargs, - ) - - return self - - def _return_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: 
Tuple[int, int] = (None, None), - ): - """ - Returns obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - - xp = self._xp - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - else: - obj = xp.asarray(obj, dtype=xp.float32) - - if projection_angle_deg is not None: - rotated_3d_obj = self._rotate( - obj, - projection_angle_deg, - axes=projection_axes, - reshape=False, - order=2, - ) - rotated_3d_obj = asnumpy(rotated_3d_obj) - else: - rotated_3d_obj = asnumpy(obj) - - rotated_object = self._crop_rotate_object_manually( - rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims - ) - - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(rotated_object)))) - - def show_object_fft( - self, - obj=None, - projection_angle_deg: float = None, - projection_axes: Tuple[int, int] = (0, 2), - x_lims: Tuple[int, int] = (None, None), - y_lims: Tuple[int, int] = (None, None), - **kwargs, - ): - """ - Plot FFT of reconstructed object - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - projection_angle_deg: float - Angle in degrees to rotate 3D array around prior to projection - projection_axes: tuple(int,int) - Axes defining projection plane - x_lims: tuple(float,float) - min/max x indices - y_lims: tuple(float,float) - min/max y indices - """ - if obj is None: - object_fft = self._return_object_fft( - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - y_lims=y_lims, - ) - else: - object_fft = self._return_object_fft( - obj, - projection_angle_deg=projection_angle_deg, - projection_axes=projection_axes, - x_lims=x_lims, - 
y_lims=y_lims, - ) - - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - show( - object_fft, - figsize=figsize, - cmap=cmap, - scalebar=True, - pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", - **kwargs, - ) - - @property - def positions(self): - """Probe positions [A]""" - - if self.angular_sampling is None: - return None - - asnumpy = self._asnumpy - positions_all = [] - for tilt_index in range(self._num_tilts): - positions = self._positions_px_all[ - self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[ - tilt_index + 1 - ] - ].copy() - positions[:, 0] *= self.sampling[0] - positions[:, 1] *= self.sampling[1] - positions_all.append(asnumpy(positions)) - - return np.asarray(positions_all) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - raise NotImplementedError() - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - raise NotImplementedError() - - def show_uncertainty_visualization( - self, - errors=None, - max_batch_size=None, - projected_cropped_potential=None, - kde_sigma=None, - plot_histogram=True, - plot_contours=False, - **kwargs, - ): - """Plot uncertainty visualization using self-consistency errors""" - raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py deleted file mode 100644 index 59bf61da2..000000000 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ /dev/null @@ -1,647 +0,0 @@ -import warnings - -import numpy as np -from py4DSTEM.process.phase.utils import ( - array_slice, - estimate_global_transformation_ransac, - fft_shift, - fit_aberration_surface, - regularize_probe_amplitude, -) -from py4DSTEM.process.utils import get_CoM - -try: - 
import cupy as cp -except (ModuleNotFoundError, ImportError): - cp = np - import os - - # make sure pylops doesn't try to use cupy - os.environ["CUPY_PYLOPS"] = "0" -import pylops # this must follow the exception - - -class PtychographicConstraints: - """ - Container class for PtychographicReconstruction methods. - """ - - def _object_threshold_constraint(self, current_object, pure_phase_object): - """ - Ptychographic threshold constraint. - Used for avoiding the scaling ambiguity between probe and object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - pure_phase_object: bool - If True, object amplitude is set to unity - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - phase = xp.angle(current_object) - - if pure_phase_object: - amplitude = 1.0 - else: - amplitude = xp.minimum(xp.abs(current_object), 1.0) - - return amplitude * xp.exp(1.0j * phase) - - def _object_shrinkage_constraint(self, current_object, shrinkage_rad, object_mask): - """ - Ptychographic shrinkage constraint. - Used to ensure electrostatic potential is positive. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - if self._object_type == "complex": - phase = xp.angle(current_object) - amp = xp.abs(current_object) - - if object_mask is not None: - shrinkage_rad += phase[..., object_mask].mean() - - phase -= shrinkage_rad - - current_object = amp * xp.exp(1.0j * phase) - else: - if object_mask is not None: - shrinkage_rad += current_object[..., object_mask].mean() - - current_object -= shrinkage_rad - - return current_object - - def _object_positivity_constraint(self, current_object): - """ - Ptychographic positivity constraint. - Used to ensure potential is positive. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - - return xp.maximum(current_object, 0.0) - - def _object_gaussian_constraint( - self, current_object, gaussian_filter_sigma, pure_phase_object - ): - """ - Ptychographic smoothness constraint. - Used for blurring object. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - pure_phase_object: bool - If True, gaussian blur performed on phase only - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - gaussian_filter = self._gaussian_filter - gaussian_filter_sigma /= self.sampling[0] - - if pure_phase_object: - phase = xp.angle(current_object) - phase = gaussian_filter(phase, gaussian_filter_sigma) - current_object = xp.exp(1.0j * phase) - else: - current_object = gaussian_filter(current_object, gaussian_filter_sigma) - - return current_object - - def _object_butterworth_constraint( - self, - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ): - """ - Ptychographic butterworth filter. - Used for low/high-pass filtering object. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. 
Smaller gives a smoother filter - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - """ - xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[0], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[1], self.sampling[1]) - - qya, qxa = xp.meshgrid(qy, qx) - qra = xp.sqrt(qxa**2 + qya**2) - - env = xp.ones_like(qra) - if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) - if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) - - current_object_mean = xp.mean(current_object) - current_object -= current_object_mean - current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env) - current_object += current_object_mean - - if self._object_type == "potential": - current_object = xp.real(current_object) - - return current_object - - def _object_denoise_tv_pylops(self, current_object, weight, iterations): - """ - Performs second order TV denoising along x and y - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weight : float - Denoising weight. The greater `weight`, the more denoising (at - the expense of fidelity to `input`). - iterations: float - Number of iterations to run in denoising algorithm. 
- `niter_out` in pylops - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - """ - xp = self._xp - - if xp.iscomplexobj(current_object): - current_object_tv = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - - else: - nx, ny = current_object.shape - niter_out = iterations - niter_in = 1 - Iop = pylops.Identity(nx * ny) - xy_laplacian = pylops.Laplacian( - (nx, ny), axes=(0, 1), edge=False, kind="backward" - ) - - l1_regs = [xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=[weight], - tol=1e-4, - tau=1.0, - show=False, - )[0] - - current_object_tv = current_object_tv.reshape(current_object.shape) - - return current_object_tv - - def _object_denoise_tv_chambolle( - self, - current_object, - weight, - axis, - pad_object, - eps=2.0e-4, - max_num_iter=200, - scaling=None, - ): - """ - Perform total-variation denoising on n-dimensional images. - - Parameters - ---------- - current_object: np.ndarray - Current object estimate - weight : float, optional - Denoising weight. The greater `weight`, the more denoising (at - the expense of fidelity to `input`). - axis: int or tuple - Axis for denoising, if None uses all axes - pad_object: bool - if True, pads object with zeros along axes of blurring - eps : float, optional - Relative difference of the value of the cost function that determines - the stop criterion. The algorithm stops when: - - (E_(n-1) - E_n) < eps * E_0 - - max_num_iter : int, optional - Maximal number of iterations used for the optimization. - scaling : tuple, optional - Scale weight of tv denoise on different axes - - Returns - ------- - constrained_object: np.ndarray - Constrained object estimate - - Notes - ----- - Rudin, Osher and Fatemi algorithm. - Adapted skimage.restoration.denoise_tv_chambolle. 
- """ - xp = self._xp - if xp.iscomplexobj(current_object): - updated_object = current_object - warnings.warn( - ("TV denoising is currently only supported for potential objects."), - UserWarning, - ) - else: - current_object_sum = xp.sum(current_object) - if axis is None: - ndim = xp.arange(current_object.ndim).tolist() - elif isinstance(axis, tuple): - ndim = list(axis) - else: - ndim = [axis] - - if pad_object: - pad_width = ((0, 0),) * current_object.ndim - pad_width = list(pad_width) - for ax in range(len(ndim)): - pad_width[ndim[ax]] = (1, 1) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" - ) - - p = xp.zeros( - (current_object.ndim,) + current_object.shape, - dtype=current_object.dtype, - ) - g = xp.zeros_like(p) - d = xp.zeros_like(current_object) - - i = 0 - while i < max_num_iter: - if i > 0: - # d will be the (negative) divergence of p - d = -p.sum(0) - slices_d = [ - slice(None), - ] * current_object.ndim - slices_p = [ - slice(None), - ] * (current_object.ndim + 1) - for ax in range(len(ndim)): - slices_d[ndim[ax]] = slice(1, None) - slices_p[ndim[ax] + 1] = slice(0, -1) - slices_p[0] = ndim[ax] - d[tuple(slices_d)] += p[tuple(slices_p)] - slices_d[ndim[ax]] = slice(None) - slices_p[ndim[ax] + 1] = slice(None) - updated_object = current_object + d - else: - updated_object = current_object - E = (d**2).sum() - - # g stores the gradients of updated_object along each axis - # e.g. g[0] is the first order finite difference along axis 0 - slices_g = [ - slice(None), - ] * (current_object.ndim + 1) - for ax in range(len(ndim)): - slices_g[ndim[ax] + 1] = slice(0, -1) - slices_g[0] = ndim[ax] - g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) - slices_g[ndim[ax] + 1] = slice(None) - if scaling is not None: - scaling /= xp.max(scaling) - g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] - norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] 
- E += weight * norm.sum() - tau = 1.0 / (2.0 * len(ndim)) - norm *= tau / weight - norm += 1.0 - p -= tau * g - p /= norm - E /= float(current_object.size) - if i == 0: - E_init = E - E_previous = E - else: - if xp.abs(E_previous - E) < eps * E_init: - break - else: - E_previous = E - i += 1 - - if pad_object: - for ax in range(len(ndim)): - slices = array_slice(ndim[ax], current_object.ndim, 1, -1) - updated_object = updated_object[slices] - updated_object = ( - updated_object / xp.sum(updated_object) * current_object_sum - ) - - return updated_object - - def _probe_center_of_mass_constraint(self, current_probe): - """ - Ptychographic center of mass constraint. - Used for centering corner-centered probe intensity. - - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - - probe_intensity = xp.abs(current_probe) ** 2 - - probe_x0, probe_y0 = get_CoM( - probe_intensity, device=self._device, corner_centered=True - ) - shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) - - return shifted_probe - - def _probe_amplitude_constraint( - self, current_probe, relative_radius, relative_width - ): - """ - Ptychographic top-hat filtering of probe. 
- - Parameters - ---------- - current_probe: np.ndarray - Current positions estimate - relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - erf = self._erf - - probe_intensity = xp.abs(current_probe) ** 2 - current_probe_sum = xp.sum(probe_intensity) - - X = xp.fft.fftfreq(current_probe.shape[0])[:, None] - Y = xp.fft.fftfreq(current_probe.shape[1])[None] - r = xp.hypot(X, Y) - relative_radius - - sigma = np.sqrt(np.pi) / relative_width - tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2))) - - updated_probe = current_probe * tophat_mask - updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - - return updated_probe * normalization - - def _probe_fourier_amplitude_constraint( - self, - current_probe, - width_max_pixels, - enforce_constant_intensity, - ): - """ - Ptychographic top-hat filtering of Fourier probe. - - Parameters - ---------- - current_probe: np.ndarray - Current positions estimate - threshold: np.ndarray - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where 1 uses the maximum amplitude to threshold. 
- relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - asnumpy = self._asnumpy - - current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) - current_probe_fft = xp.fft.fft2(current_probe) - - updated_probe_fft, _, _, _ = regularize_probe_amplitude( - asnumpy(current_probe_fft), - width_max_pixels=width_max_pixels, - nearest_angular_neighbor_averaging=5, - enforce_constant_intensity=enforce_constant_intensity, - corner_centered=True, - ) - - updated_probe_fft = xp.asarray(updated_probe_fft) - updated_probe = xp.fft.ifft2(updated_probe_fft) - updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - - return updated_probe * normalization - - def _probe_aperture_constraint( - self, - current_probe, - initial_probe_aperture, - ): - """ - Ptychographic constraint to fix Fourier amplitude to initial aperture. - - Parameters - ---------- - current_probe: np.ndarray - Current positions estimate - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - - current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) - current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe)) - - updated_probe = xp.fft.ifft2( - xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture - ) - updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - - return updated_probe * normalization - - def _probe_aberration_fitting_constraint( - self, - current_probe, - max_angular_order, - max_radial_order, - ): - """ - Ptychographic probe smoothing constraint. - Removes/adds known (initialization) aberrations before/after smoothing. 
- - Parameters - ---------- - current_probe: np.ndarray - Current positions estimate - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - fix_amplitude: bool - If True, only the phase is smoothed - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - - xp = self._xp - - fourier_probe = xp.fft.fft2(current_probe) - fourier_probe_abs = xp.abs(fourier_probe) - sampling = self.sampling - energy = self._energy - - fitted_angle, _ = fit_aberration_surface( - fourier_probe, - sampling, - energy, - max_angular_order, - max_radial_order, - xp=xp, - ) - - fourier_probe = fourier_probe_abs * xp.exp(-1.0j * fitted_angle) - current_probe = xp.fft.ifft2(fourier_probe) - - return current_probe - - def _positions_center_of_mass_constraint(self, current_positions): - """ - Ptychographic position center of mass constraint. - Additionally updates vectorized indices used in _overlap_projection. - - Parameters - ---------- - current_positions: np.ndarray - Current positions estimate - - Returns - -------- - constrained_positions: np.ndarray - CoM constrained positions estimate - """ - xp = self._xp - - current_positions -= xp.mean(current_positions, axis=0) - self._positions_px_com - self._positions_px_fractional = current_positions - xp.round(current_positions) - - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - return current_positions - - def _positions_affine_transformation_constraint( - self, initial_positions, current_positions - ): - """ - Constrains the updated positions to be an affine transformation of the initial scan positions, - composing of two scale factors, a shear, and a rotation angle. - - Uses RANSAC to estimate the global transformation robustly. - Stores the AffineTransformation in self._tf. 
- - Parameters - ---------- - initial_positions: np.ndarray - Initial scan positions - current_positions: np.ndarray - Current positions estimate - - Returns - ------- - constrained_positions: np.ndarray - Affine-transform constrained positions estimate - """ - - xp = self._xp - - tf, _ = estimate_global_transformation_ransac( - positions0=initial_positions, - positions1=current_positions, - origin=self._positions_px_com, - translation_allowed=True, - min_sample=self._num_diffraction_patterns // 10, - xp=xp, - ) - - self._tf = tf - current_positions = tf(initial_positions, origin=self._positions_px_com, xp=xp) - - return current_positions diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py deleted file mode 100644 index c8cc5ee3e..000000000 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ /dev/null @@ -1,3510 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely joint ptychography. -""" - -import warnings -from typing import Mapping, Sequence, Tuple - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg - -try: - import cupy as cp -except (ImportError, ModuleNotFoundError): - cp = np - -from emdfile import Custom, tqdmnd -from py4DSTEM import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, -) -from py4DSTEM.process.utils import get_CoM, get_shifted_ar - -warnings.simplefilter(action="always", category=UserWarning) - - -class SimultaneousPtychographicReconstruction(PtychographicReconstruction): - """ - Iterative Simultaneous Ptychographic Reconstruction Class. 
- - Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) (for each measurement) - Reconstructed probe dimensions : (Sx,Sy) - Reconstructed electrostatic dimensions : (Px,Py) - Reconstructed magnetic dimensions : (Px,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our probe - and (Px,Py) is the padded-object size we position our ROI around in. - - Parameters - ---------- - datacube: Sequence[DataCube] - Tuple of input 4D diffraction pattern intensities - energy: float - The electron energy of the wave functions in eV - simultaneous_measurements_mode: str, optional - One of '-+', '-0+', '0+', where -/0/+ refer to the sign of the magnetic potential - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad objects with - If None, the padding is set to half the probe ROI dimensions - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py) - If None, initialized to 1.0j - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: np.ndarray, optional - Probe positions in Å for each diffraction intensity - If None, initialized to a grid scan - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. - """ - - # Class-specific Metadata - _class_specific_metadata = ("_simultaneous_measurements_mode",) - - def __init__( - self, - energy: float, - datacube: Sequence[DataCube] = None, - simultaneous_measurements_mode: str = "-+", - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - object_padding_px: Tuple[int, int] = None, - positions_mask: np.ndarray = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: np.ndarray = None, - object_type: str = "complex", - verbose: bool = True, - device: str = "cpu", - name: str = "simultaneous_ptychographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key 
in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - if object_type != "potential" and object_type != "complex": - raise ValueError( - f"object_type must be either 'potential' or 'complex', not {object_type}" - ) - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._object_padding_px = object_padding_px - self._positions_mask = positions_mask - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - self._simultaneous_measurements_mode = simultaneous_measurements_mode - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_rotation: bool = True, - maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), - plot_probe_overlaps: bool = True, - force_com_rotation: float = None, - force_com_transpose: float = None, - force_com_shifts: float = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic 
preprocessing step. - Calls the base class methods: - - _extract_intensities_and_calibrations_from_datacube, - _compute_center_of_mass(), - _solve_CoM_rotation(), - _normalize_diffraction_intensities() - _calculate_scan_positions_in_px() - - Additionally, it initializes an (Px,Py) array of 1.0j - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. - If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' - plot_rotation: bool, optional - If True, the CoM curl minimization search result will be displayed - maximize_divergence: bool, optional - If True, the divergence of the CoM gradient vector field is maximized - rotation_angles_deg: np.darray, optional - Array of angles in degrees to perform curl minimization over - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - force_com_rotation: float (degrees), optional - Force relative rotation angle between real and reciprocal space - force_com_transpose: bool, optional - Force whether diffraction intensities need to be transposed. - force_com_shifts: sequence of tuples of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. 
- force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." - ) - ) - - if self._simultaneous_measurements_mode == "-+": - self._sim_recon_mode = 0 - self._num_sim_measurements = 2 - if self._verbose: - print( - ( - "Magnetic vector potential sign in first meaurement assumed to be negative.\n" - "Magnetic vector potential sign in second meaurement assumed to be positive." - ) - ) - if len(self._datacube) != 2: - raise ValueError( - f"datacube must be a set of two measurements, not length {len(self._datacube)}." 
- ) - if self._datacube[0].shape != self._datacube[1].shape: - raise ValueError("datacube intensities must be the same size.") - elif self._simultaneous_measurements_mode == "-0+": - self._sim_recon_mode = 1 - self._num_sim_measurements = 3 - if self._verbose: - print( - ( - "Magnetic vector potential sign in first meaurement assumed to be negative.\n" - "Magnetic vector potential assumed to be zero in second meaurement.\n" - "Magnetic vector potential sign in third meaurement assumed to be positive." - ) - ) - if len(self._datacube) != 3: - raise ValueError( - f"datacube must be a set of three measurements, not length {len(self._datacube)}." - ) - if ( - self._datacube[0].shape != self._datacube[1].shape - or self._datacube[0].shape != self._datacube[2].shape - ): - raise ValueError("datacube intensities must be the same size.") - elif self._simultaneous_measurements_mode == "0+": - self._sim_recon_mode = 2 - self._num_sim_measurements = 2 - if self._verbose: - print( - ( - "Magnetic vector potential assumed to be zero in first meaurement.\n" - "Magnetic vector potential sign in second meaurement assumed to be positive." - ) - ) - if len(self._datacube) != 2: - raise ValueError( - f"datacube must be a set of two measurements, not length {len(self._datacube)}." 
- ) - if self._datacube[0].shape != self._datacube[1].shape: - raise ValueError("datacube intensities must be the same size.") - else: - raise ValueError( - f"simultaneous_measurements_mode must be either '-+', '-0+', or '0+', not {self._simultaneous_measurements_mode}" - ) - - if self._positions_mask is not None: - self._positions_mask = np.asarray(self._positions_mask) - - if self._positions_mask.ndim == 2: - warnings.warn( - "2D `positions_mask` assumed the same for all measurements.", - UserWarning, - ) - self._positions_mask = np.tile( - self._positions_mask, (self._num_sim_measurements, 1, 1) - ) - - if self._positions_mask.dtype != "bool": - warnings.warn( - "`positions_mask` converted to `bool` array.", - UserWarning, - ) - self._positions_mask = self._positions_mask.astype("bool") - else: - self._positions_mask = [None] * self._num_sim_measurements - - if force_com_shifts is None: - force_com_shifts = [None, None, None] - elif len(force_com_shifts) == self._num_sim_measurements: - force_com_shifts = list(force_com_shifts) - else: - raise ValueError( - ( - "force_com_shifts must be a sequence of tuples " - "with the same length as the datasets." 
- ) - ) - - # Ensure plot_center_of_mass is not in kwargs - kwargs.pop("plot_center_of_mass", None) - - # 1st measurement sets rotation angle and transposition - ( - measurement_0, - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts[0], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[0], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts[0], - ) - - intensities_0 = self._extract_intensities_and_calibrations_from_datacube( - measurement_0, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x_0, - com_measured_y_0, - com_fitted_x_0, - com_fitted_y_0, - com_normalized_x_0, - com_normalized_y_0, - ) = self._calculate_intensities_center_of_mass( - intensities_0, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[0], - ) - - ( - self._rotation_best_rad, - self._rotation_best_transpose, - _com_x_0, - _com_y_0, - com_x_0, - com_y_0, - ) = self._solve_for_center_of_mass_relative_rotation( - com_measured_x_0, - com_measured_y_0, - com_normalized_x_0, - com_normalized_y_0, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=False, - maximize_divergence=maximize_divergence, - force_com_rotation=force_com_rotation, - force_com_transpose=force_com_transpose, - **kwargs, - ) - - ( - amplitudes_0, - mean_diffraction_intensity_0, - ) = self._normalize_diffraction_intensities( - intensities_0, - com_fitted_x_0, - com_fitted_y_0, - crop_patterns, - self._positions_mask[0], - ) - - # explicitly delete namescapes - del ( - intensities_0, - com_measured_x_0, - com_measured_y_0, - com_fitted_x_0, - com_fitted_y_0, - 
com_normalized_x_0, - com_normalized_y_0, - _com_x_0, - _com_y_0, - com_x_0, - com_y_0, - ) - - # 2nd measurement - ( - measurement_1, - _, - _, - force_com_shifts[1], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[1], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=None, - dp_mask=None, - com_shifts=force_com_shifts[1], - ) - - intensities_1 = self._extract_intensities_and_calibrations_from_datacube( - measurement_1, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x_1, - com_measured_y_1, - com_fitted_x_1, - com_fitted_y_1, - com_normalized_x_1, - com_normalized_y_1, - ) = self._calculate_intensities_center_of_mass( - intensities_1, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[1], - ) - - ( - _, - _, - _com_x_1, - _com_y_1, - com_x_1, - com_y_1, - ) = self._solve_for_center_of_mass_relative_rotation( - com_measured_x_1, - com_measured_y_1, - com_normalized_x_1, - com_normalized_y_1, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=False, - maximize_divergence=maximize_divergence, - force_com_rotation=np.rad2deg(self._rotation_best_rad), - force_com_transpose=self._rotation_best_transpose, - **kwargs, - ) - - ( - amplitudes_1, - mean_diffraction_intensity_1, - ) = self._normalize_diffraction_intensities( - intensities_1, - com_fitted_x_1, - com_fitted_y_1, - crop_patterns, - self._positions_mask[1], - ) - - # explicitly delete namescapes - del ( - intensities_1, - com_measured_x_1, - com_measured_y_1, - com_fitted_x_1, - com_fitted_y_1, - com_normalized_x_1, - com_normalized_y_1, - _com_x_1, - _com_y_1, - com_x_1, - com_y_1, - ) - - # Optionally, 3rd measurement - if 
self._num_sim_measurements == 3: - ( - measurement_2, - _, - _, - force_com_shifts[2], - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube[2], - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=None, - dp_mask=None, - com_shifts=force_com_shifts[2], - ) - - intensities_2 = self._extract_intensities_and_calibrations_from_datacube( - measurement_2, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - com_measured_x_2, - com_measured_y_2, - com_fitted_x_2, - com_fitted_y_2, - com_normalized_x_2, - com_normalized_y_2, - ) = self._calculate_intensities_center_of_mass( - intensities_2, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts[2], - ) - - ( - _, - _, - _com_x_2, - _com_y_2, - com_x_2, - com_y_2, - ) = self._solve_for_center_of_mass_relative_rotation( - com_measured_x_2, - com_measured_y_2, - com_normalized_x_2, - com_normalized_y_2, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=False, - maximize_divergence=maximize_divergence, - force_com_rotation=np.rad2deg(self._rotation_best_rad), - force_com_transpose=self._rotation_best_transpose, - **kwargs, - ) - - ( - amplitudes_2, - mean_diffraction_intensity_2, - ) = self._normalize_diffraction_intensities( - intensities_2, - com_fitted_x_2, - com_fitted_y_2, - crop_patterns, - self._positions_mask[2], - ) - - # explicitly delete namescapes - del ( - intensities_2, - com_measured_x_2, - com_measured_y_2, - com_fitted_x_2, - com_fitted_y_2, - com_normalized_x_2, - com_normalized_y_2, - _com_x_2, - _com_y_2, - com_x_2, - com_y_2, - ) - - self._amplitudes = (amplitudes_0, amplitudes_1, amplitudes_2) - self._mean_diffraction_intensity = ( - mean_diffraction_intensity_0 
- + mean_diffraction_intensity_1 - + mean_diffraction_intensity_2 - ) / 3 - - del amplitudes_0, amplitudes_1, amplitudes_2 - - else: - self._amplitudes = (amplitudes_0, amplitudes_1) - self._mean_diffraction_intensity = ( - mean_diffraction_intensity_0 + mean_diffraction_intensity_1 - ) / 2 - - del amplitudes_0, amplitudes_1 - - # explicitly delete namespace - self._num_diffraction_patterns = self._amplitudes[0].shape[0] - self._region_of_interest_shape = np.array(self._amplitudes[0].shape[-2:]) - - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask[0] - ) # TO-DO: generaltize to per-dataset probe positions - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - object_e = xp.zeros((p, q), dtype=xp.float32) - elif self._object_type == "complex": - object_e = xp.ones((p, q), dtype=xp.complex64) - object_m = xp.zeros((p, q), dtype=xp.float32) - else: - if self._object_type == "potential": - object_e = xp.asarray(self._object[0], dtype=xp.float32) - elif self._object_type == "complex": - object_e = xp.asarray(self._object[0], dtype=xp.complex64) - object_m = xp.asarray(self._object[1], dtype=xp.float32) - - self._object = (object_e, object_m) - self._object_initial = (object_e.copy(), object_m.copy()) - self._object_type_initial = self._object_type - self._object_shape = self._object[0].shape - - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = 
xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - self._positions_px_initial = self._positions_px.copy() - self._positions_initial = self._positions_px_initial.copy() - self._positions_initial[:, 0] *= self.sampling[0] - self._positions_initial[:, 1] *= self.sampling[1] - - # Vectorized Patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, device=self._device - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = 
self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - self._mean_diffraction_intensity / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # overlaps - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (9, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, - chroma_boost=chroma_boost, - 
) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe intensity") - - ax2.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="Greys_r", - ) - ax2.scatter( - self.positions[:, 1], - self.positions[:, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_xlim((extent[0], extent[1])) - ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _warmup_overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - electrostatic_obj, _ = current_object - - if self._object_type == "potential": - complex_object = xp.exp(1j * electrostatic_obj) - else: - complex_object = electrostatic_obj - - electrostatic_obj_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - object_patches = (electrostatic_obj_patches, None) - overlap = (shifted_probes * electrostatic_obj_patches, None) - - return shifted_probes, object_patches, overlap - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - electrostatic_obj, magnetic_obj = current_object - - if self._object_type == "potential": - complex_object_e = xp.exp(1j * electrostatic_obj) - else: - complex_object_e = electrostatic_obj - - complex_object_m = xp.exp(1j * magnetic_obj) - - electrostatic_obj_patches = complex_object_e[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - magnetic_obj_patches = complex_object_m[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - object_patches = (electrostatic_obj_patches, magnetic_obj_patches) - - if self._sim_recon_mode == 0: - overlap_reverse = ( - shifted_probes - * electrostatic_obj_patches - * xp.conj(magnetic_obj_patches) - ) - overlap_forward = ( - shifted_probes * electrostatic_obj_patches * magnetic_obj_patches - ) - overlap = (overlap_reverse, overlap_forward) - elif self._sim_recon_mode == 1: - overlap_reverse = ( - shifted_probes - * electrostatic_obj_patches - * xp.conj(magnetic_obj_patches) - ) - overlap_neutral = shifted_probes * electrostatic_obj_patches - overlap_forward = ( - shifted_probes * electrostatic_obj_patches * magnetic_obj_patches - ) - overlap = (overlap_reverse, overlap_neutral, overlap_forward) - else: - overlap_neutral = shifted_probes * electrostatic_obj_patches - overlap_forward = ( - shifted_probes * electrostatic_obj_patches * magnetic_obj_patches - ) - overlap = (overlap_neutral, overlap_forward) - - return shifted_probes, object_patches, overlap - - def _warmup_gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic 
fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - - fourier_overlap = xp.fft.fft2(overlap[0]) - error = xp.sum(xp.abs(amplitudes[0] - xp.abs(fourier_overlap)) ** 2) - - fourier_modified_overlap = amplitudes[0] * xp.exp( - 1j * xp.angle(fourier_overlap) - ) - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = (modified_overlap - overlap[0],) + (None,) * ( - self._num_sim_measurements - 1 - ) - - return exit_waves, error - - def _gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - - error = 0.0 - exit_waves = [] - for amp, overl in zip(amplitudes, overlap): - fourier_overl = xp.fft.fft2(overl) - error += xp.sum(xp.abs(amp - xp.abs(fourier_overl)) ** 2) - - fourier_modified_overl = amp * xp.exp(1j * xp.angle(fourier_overl)) - modified_overl = xp.fft.ifft2(fourier_modified_overl) - - exit_waves.append(modified_overl - overl) - - error /= len(exit_waves) - exit_waves = tuple(exit_waves) - - return exit_waves, error - - def _warmup_projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - exit_wave = exit_waves[0] - - if exit_wave is None: - exit_wave = overlap[0].copy() - - fourier_overlap = xp.fft.fft2(overlap[0]) - error = xp.sum(xp.abs(amplitudes[0] - xp.abs(fourier_overlap)) ** 2) - - factor_to_be_projected = projection_c * overlap[0] + projection_y * exit_wave - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes[0] * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_wave = ( - projection_x * exit_wave - + projection_a * overlap[0] - + projection_b * projected_factor - ) - - exit_waves = (exit_wave,) + (None,) * (self._num_sim_measurements - 1) - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. 
- Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - error = 0.0 - _exit_waves = [] - for amp, overl, exit_wave in zip(amplitudes, overlap, exit_waves): - if exit_wave is None: - exit_wave = overl.copy() - - fourier_overl = xp.fft.fft2(overl) - error += xp.sum(xp.abs(amp - xp.abs(fourier_overl)) ** 2) - - factor_to_be_projected = projection_c * overl + projection_y * exit_wave - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amp * xp.exp( - 1j * xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - _exit_waves.append( - projection_x * exit_wave - + projection_a * overl - + projection_b * projected_factor - ) - - error /= len(_exit_waves) - exit_waves = tuple(_exit_waves) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - warmup_iteration, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - object * probe overlap - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - if warmup_iteration: - shifted_probes, object_patches, overlap = self._warmup_overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._warmup_projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._warmup_gradient_descent_fourier_projection( - amplitudes, overlap - ) - - else: - shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, overlap - ) - - return shifted_probes, object_patches, overlap, exit_waves, error - - def _warmup_gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - electrostatic_obj, _ = current_object - electrostatic_obj_patches, _ = object_patches - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(electrostatic_obj_patches) - * xp.conj(shifted_probes) - * exit_waves[0] - ) - ) - * probe_normalization - ) - elif self._object_type == "complex": - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves[0] - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(electrostatic_obj_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += step_size * ( - xp.sum( - xp.conj(electrostatic_obj_patches) * exit_waves[0], - axis=0, - ) - * object_normalization - ) - - return (electrostatic_obj, None), current_probe - - def _gradient_descent_adjoint( - self, - 
current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - electrostatic_obj, magnetic_obj = current_object - probe_conj = xp.conj(shifted_probes) - - electrostatic_obj_patches, magnetic_obj_patches = object_patches - electrostatic_conj = xp.conj(electrostatic_obj_patches) - magnetic_conj = xp.conj(magnetic_obj_patches) - - probe_electrostatic_abs = xp.abs(shifted_probes * electrostatic_obj_patches) - probe_magnetic_abs = xp.abs(shifted_probes * magnetic_obj_patches) - - probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( - probe_electrostatic_abs**2 - ) - probe_electrostatic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2 - + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 - ) - - probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( - probe_magnetic_abs**2 - ) - probe_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 - + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 - ) - - if self._sim_recon_mode > 0: - probe_abs = 
xp.abs(shifted_probes) - probe_normalization = self._sum_overlapping_patches_bincounts( - probe_abs**2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._sim_recon_mode == 0: - exit_waves_reverse, exit_waves_forward = exit_waves - - if self._object_type == "potential": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_magnetic_normalization - ) - / 2 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_magnetic_normalization - ) - / 2 - ) - - elif self._object_type == "complex": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - ) - / 2 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 2 - ) - - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - 1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_electrostatic_normalization - ) - / 2 - ) - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_electrostatic_normalization - ) - / 2 - ) - - elif self._sim_recon_mode == 1: - exit_waves_reverse, exit_waves_neutral, exit_waves_forward = exit_waves - - if self._object_type == "potential": - electrostatic_obj 
+= ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_magnetic_normalization - ) - / 3 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_neutral - ) - ) - * probe_normalization - ) - / 3 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_magnetic_normalization - ) - / 3 - ) - - elif self._object_type == "complex": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - ) - / 3 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * exit_waves_neutral - ) - * probe_normalization - / 3 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 3 - ) - - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - 1j - * magnetic_obj_patches - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_reverse - ) - ) - * probe_electrostatic_normalization - ) - / 2 - ) - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_electrostatic_normalization - ) - / 2 - ) - - else: - exit_waves_neutral, exit_waves_forward = exit_waves - - if self._object_type == "potential": - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * 
electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_neutral - ) - ) - * probe_normalization - ) - / 2 - ) - electrostatic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_magnetic_normalization - ) - / 2 - ) - - elif self._object_type == "complex": - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * exit_waves_neutral - ) - * probe_normalization - / 2 - ) - electrostatic_obj += step_size * ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 2 - ) - - magnetic_obj += ( - step_size - * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * magnetic_conj - * electrostatic_conj - * xp.conj(shifted_probes) - * exit_waves_forward - ) - ) - * probe_electrostatic_normalization - ) - / 3 - ) - - if not fix_probe: - electrostatic_magnetic_abs = xp.abs( - electrostatic_obj_patches * magnetic_obj_patches - ) - electrostatic_magnetic_normalization = xp.sum( - electrostatic_magnetic_abs**2, - axis=0, - ) - electrostatic_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_magnetic_normalization)) - ** 2 - ) - - if self._sim_recon_mode > 0: - electrostatic_abs = xp.abs(electrostatic_obj_patches) - electrostatic_normalization = xp.sum( - electrostatic_abs**2, - axis=0, - ) - electrostatic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_normalization)) ** 2 - ) - - if self._sim_recon_mode == 0: - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - current_probe 
+= step_size * ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - elif self._sim_recon_mode == 1: - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, - ) - * electrostatic_magnetic_normalization - / 3 - ) - - current_probe += step_size * ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, - ) - * electrostatic_normalization - / 3 - ) - - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 3 - ) - else: - current_probe += step_size * ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, - ) - * electrostatic_normalization - / 2 - ) - - current_probe += step_size * ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - current_object = (electrostatic_obj, magnetic_obj) - - return current_object, current_probe - - def _warmup_projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - electrostatic_obj, _ = current_object - electrostatic_obj_patches, _ = object_patches - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - electrostatic_obj = ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves[0] - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(electrostatic_obj_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - xp.conj(electrostatic_obj_patches) * exit_waves[0], - axis=0, - ) - * object_normalization - ) - - return (electrostatic_obj, None), current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for DM_AP and RAAR methods. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - xp = self._xp - - electrostatic_obj, magnetic_obj = current_object - probe_conj = xp.conj(shifted_probes) - - electrostatic_obj_patches, magnetic_obj_patches = object_patches - electrostatic_conj = xp.conj(electrostatic_obj_patches) - magnetic_conj = xp.conj(magnetic_obj_patches) - - probe_electrostatic_abs = xp.abs(shifted_probes * electrostatic_obj_patches) - probe_magnetic_abs = xp.abs(shifted_probes * magnetic_obj_patches) - - probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( - probe_electrostatic_abs**2 - ) - probe_electrostatic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2 - + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 - ) - - probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( - probe_magnetic_abs**2 - ) - probe_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 - + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 - ) - - if self._sim_recon_mode > 0: - probe_abs = xp.abs(shifted_probes) - - probe_normalization = self._sum_overlapping_patches_bincounts( - probe_abs**2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if 
self._sim_recon_mode == 0: - exit_waves_reverse, exit_waves_forward = exit_waves - - electrostatic_obj = ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - / 2 - ) - - electrostatic_obj += ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 2 - ) - - magnetic_obj = xp.conj( - self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_reverse - ) - * probe_electrostatic_normalization - / 2 - ) - - magnetic_obj += ( - self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_forward - ) - * probe_electrostatic_normalization - / 2 - ) - - elif self._sim_recon_mode == 1: - exit_waves_reverse, exit_waves_neutral, exit_waves_forward = exit_waves - - electrostatic_obj = ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_obj_patches * exit_waves_reverse - ) - * probe_magnetic_normalization - / 3 - ) - - electrostatic_obj += ( - self._sum_overlapping_patches_bincounts(probe_conj * exit_waves_neutral) - * probe_normalization - / 3 - ) - - electrostatic_obj += ( - self._sum_overlapping_patches_bincounts( - probe_conj * magnetic_conj * exit_waves_forward - ) - * probe_magnetic_normalization - / 3 - ) - - magnetic_obj = xp.conj( - self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_reverse - ) - * probe_electrostatic_normalization - / 2 - ) - - magnetic_obj += ( - self._sum_overlapping_patches_bincounts( - probe_conj * electrostatic_conj * exit_waves_forward - ) - * probe_electrostatic_normalization - / 2 - ) - - else: - raise NotImplementedError() - - if not fix_probe: - electrostatic_magnetic_abs = xp.abs( - electrostatic_obj_patches * magnetic_obj_patches - ) - - electrostatic_magnetic_normalization = xp.sum( - (electrostatic_magnetic_abs**2), - axis=0, - ) - 
electrostatic_magnetic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_magnetic_normalization)) - ** 2 - ) - - if self._sim_recon_mode > 0: - electrostatic_abs = xp.abs(electrostatic_obj_patches) - electrostatic_normalization = xp.sum( - (electrostatic_abs**2), - axis=0, - ) - electrostatic_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * electrostatic_normalization) ** 2 - + (normalization_min * xp.max(electrostatic_normalization)) ** 2 - ) - - if self._sim_recon_mode == 0: - current_probe = ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - current_probe += ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - elif self._sim_recon_mode == 1: - current_probe = ( - xp.sum( - electrostatic_conj * magnetic_obj_patches * exit_waves_reverse, - axis=0, - ) - * electrostatic_magnetic_normalization - / 3 - ) - - current_probe += ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, - ) - * electrostatic_normalization - / 3 - ) - - current_probe += ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 3 - ) - else: - current_probe = ( - xp.sum( - electrostatic_conj * exit_waves_neutral, - axis=0, - ) - * electrostatic_normalization - / 2 - ) - - current_probe += ( - xp.sum( - electrostatic_conj * magnetic_conj * exit_waves_forward, - axis=0, - ) - * electrostatic_magnetic_normalization - / 2 - ) - - current_object = (electrostatic_obj, magnetic_obj) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - warmup_iteration: bool, - use_projection_scheme: bool, - step_size: float, - 
normalization_min: float, - fix_probe: bool, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if warmup_iteration: - if use_projection_scheme: - current_object, current_probe = self._warmup_projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._warmup_gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - else: - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _constraints( - self, - current_object, - current_probe, - current_positions, - pure_phase_object, - fix_com, - fit_probe_aberrations, - 
fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma_e, - gaussian_filter_sigma_m, - butterworth_filter, - q_lowpass_e, - q_lowpass_m, - q_highpass_e, - q_highpass_m, - butterworth_order, - tv_denoise, - tv_denoise_weight, - tv_denoise_inner_iter, - warmup_iteration, - object_positivity, - shrinkage_rad, - object_mask, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. 
- constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. - fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma_e: float - Standard deviation of gaussian kernel for electrostatic object in A - gaussian_filter_sigma_m: float - Standard deviation of gaussian kernel for magnetic object in A - probe_gaussian_filter: bool - If True, applies reciprocal-space gaussian filtering on residual aberrations - probe_gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass_e: float - Cut-off frequency in A^-1 for low-pass filtering electrostatic object - q_lowpass_m: float - Cut-off frequency in A^-1 for low-pass filtering magnetic object - q_highpass_e: float - Cut-off frequency in A^-1 for high-pass filtering electrostatic object - q_highpass_m: float - Cut-off frequency in A^-1 for high-pass filtering magnetic object - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - warmup_iteration: bool - If True, constraints electrostatic object only - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - electrostatic_obj, magnetic_obj = current_object - - if gaussian_filter: - electrostatic_obj = self._object_gaussian_constraint( - electrostatic_obj, gaussian_filter_sigma_e, pure_phase_object - ) - if not warmup_iteration: - magnetic_obj = self._object_gaussian_constraint( - magnetic_obj, - gaussian_filter_sigma_m, - pure_phase_object, - ) - - if butterworth_filter: - electrostatic_obj = self._object_butterworth_constraint( - electrostatic_obj, - q_lowpass_e, - q_highpass_e, - butterworth_order, - ) - if not warmup_iteration: - magnetic_obj = self._object_butterworth_constraint( - magnetic_obj, - q_lowpass_m, - q_highpass_m, - butterworth_order, - ) - - if self._object_type == "complex": - magnetic_obj = magnetic_obj.real - if tv_denoise: - electrostatic_obj = self._object_denoise_tv_pylops( - electrostatic_obj, tv_denoise_weight, tv_denoise_inner_iter - ) - - if not warmup_iteration: - magnetic_obj = self._object_denoise_tv_pylops( - magnetic_obj, tv_denoise_weight, tv_denoise_inner_iter - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - electrostatic_obj = self._object_shrinkage_constraint( - electrostatic_obj, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - electrostatic_obj = self._object_threshold_constraint( - electrostatic_obj, pure_phase_object - ) - 
elif object_positivity: - electrostatic_obj = self._object_positivity_constraint(electrostatic_obj) - - current_object = (electrostatic_obj, magnetic_obj) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - pure_phase_object_iter: int = 0, - fix_com: bool = True, - fix_probe_iter: int = 0, - warmup_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 
0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma_e: float = None, - gaussian_filter_sigma_m: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass_e: float = None, - q_lowpass_m: float = None, - q_highpass_e: float = None, - q_highpass_m: float = None, - butterworth_order: float = 2, - tv_denoise_iter: int = np.inf, - tv_denoise_weight: float = None, - tv_denoise_inner_iter: float = 40, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - switch_object_iter: int = np.inf, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. 
- reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. - max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - pure_phase_object_iter: float, optional - Number of iterations where object amplitude is set to unity - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. 
- fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma_e: float - Standard deviation of gaussian kernel for electrostatic object in A - gaussian_filter_sigma_m: float - Standard deviation of gaussian kernel for magnetic object in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass_e: float - Cut-off frequency in A^-1 for low-pass filtering electrostatic object - q_lowpass_m: float - Cut-off frequency in A^-1 for low-pass filtering magnetic object - q_highpass_e: float - Cut-off frequency in A^-1 for high-pass filtering electrostatic object - q_highpass_m: float - Cut-off frequency in A^-1 for high-pass filtering magnetic object - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise_iter: int, optional - Number of iterations to run using tv denoise filter on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool, optional - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - fix_potential_baseline: bool - If true, the potential mean outside the FOV is forced to zero at each iteration - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - 
"'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if use_projection_scheme and self._sim_recon_mode == 2: - raise NotImplementedError( - "simultaneous_measurements_mode == '0+' and projection set algorithms are currently incompatible." - ) - - if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." 
- ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self._object = ( - self._object_initial[0].copy(), - self._object_initial[1].copy(), - ) - self._probe = self._probe_initial.copy() - self.error_iterations = [] - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = (None,) * self._num_sim_measurements - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = (None,) * self._num_sim_measurements - - if gaussian_filter_sigma_m is None: - gaussian_filter_sigma_m = gaussian_filter_sigma_e - - if q_lowpass_m is None: - q_lowpass_m = q_lowpass_e - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = (xp.exp(1j * self._object[0]), self._object[1]) - elif self._object_type == "complex": - self._object_type = "potential" - self._object = (xp.angle(self._object[0]), self._object[1]) - - if a0 == warmup_iter: - self._object = (self._object[0], self._object_initial[1].copy()) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - positions_px = self._positions_px.copy()[shuffled_indices] - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - amps = [] - for amplitudes in self._amplitudes: - amps.append(amplitudes[shuffled_indices[start:end]]) - amplitudes = tuple(amps) - - # forward operator - ( - shifted_probes, - object_patches, - overlap, - self._exit_waves, - batch_error, - ) = self._forward( - self._object, - self._probe, - amplitudes, - self._exit_waves, - warmup_iteration=a0 < warmup_iter, - use_projection_scheme=use_projection_scheme, - projection_a=projection_a, - projection_b=projection_b, - projection_c=projection_c, - ) - - # adjoint operator - self._object, self._probe = 
self._adjoint( - self._object, - self._probe, - object_patches, - shifted_probes, - self._exit_waves, - warmup_iteration=a0 < warmup_iter, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - self._object[0], - shifted_probes, - overlap[0], - amplitudes[0], - self._positions_px, - positions_step_size, - constrain_position_distance, - ) - - error += batch_error - - # Normalize Error - error /= self._mean_diffraction_intensity * self._num_diffraction_patterns - - # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] - self._object, self._probe, self._positions_px = self._constraints( - self._object, - self._probe, - self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - warmup_iteration=a0 < warmup_iter, - 
gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma_m is not None, - gaussian_filter_sigma_e=gaussian_filter_sigma_e, - gaussian_filter_sigma_m=gaussian_filter_sigma_m, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass_m is not None or q_highpass_m is not None), - q_lowpass_e=q_lowpass_e, - q_lowpass_m=q_lowpass_m, - q_highpass_e=q_highpass_e, - q_highpass_m=q_highpass_m, - butterworth_order=butterworth_order, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_inner_iter=tv_denoise_inner_iter, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, - object_mask=self._object_fov_mask_inverse - if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 - else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - ) - - self.error_iterations.append(error.item()) - if store_iterations: - if a0 < warmup_iter: - self.object_iterations.append( - (asnumpy(self._object[0].copy()), None) - ) - else: - self.object_iterations.append( - ( - asnumpy(self._object[0].copy()), - asnumpy(self._object[1].copy()), - ) - ) - self.probe_iterations.append(self.probe_centered) - - # store result - if a0 < warmup_iter: - self.object = (asnumpy(self._object[0]), None) - else: - self.object = (asnumpy(self._object[0]), asnumpy(self._object[1])) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax: None, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. 
- - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object[0]) - else: - obj = self.object[0] - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - figsize = kwargs.pop("figsize", (12, 5)) - cmap_e = kwargs.pop("cmap_e", "magma") - cmap_m = kwargs.pop("cmap_m", "PuOr") - - if self._object_type == "complex": - obj_e = np.angle(self.object[0]) - obj_m = self.object[1] - else: - obj_e, obj_m = self.object - - rotated_electrostatic = self._crop_rotate_object_fov(obj_e, padding=padding) - rotated_magnetic = self._crop_rotate_object_fov(obj_m, padding=padding) - rotated_shape = rotated_electrostatic.shape - - min_e = rotated_electrostatic.min() - max_e = rotated_electrostatic.max() - max_m = np.abs(rotated_magnetic).max() - min_m = -max_m - - vmin_e = kwargs.pop("vmin_e", min_e) - vmax_e = kwargs.pop("vmax_e", max_e) - vmin_m = kwargs.pop("vmin_m", min_m) - vmax_m = kwargs.pop("vmax_m", max_m) - - chroma_boost = kwargs.pop("chroma_boost", 1) - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * 
self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=3, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - 1, - 1, - (probe_extent[1] / probe_extent[2]) / (extent[1] / extent[2]), - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=2, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=3, - nrows=1, - width_ratios=[ - 1, - 1, - (probe_extent[1] / probe_extent[2]) / (extent[1] / extent[2]), - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=2, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Electrostatic Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_electrostatic, - extent=extent, - cmap=cmap_e, - vmin=vmin_e, - vmax=vmax_e, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed electrostatic potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed electrostatic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Magnetic Object - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - rotated_magnetic, - extent=extent, - cmap=cmap_m, - vmin=vmin_m, - vmax=vmax_m, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Reconstructed magnetic potential") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - ax = fig.add_subplot(spec[0, 2]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual - else: - probe_array = self.probe_fourier - - probe_array = Complex2RGB( - probe_array, - 
chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe, power=2, chroma_boost=chroma_boost - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - # Electrostatic Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_electrostatic, - extent=extent, - cmap=cmap_e, - vmin=vmin_e, - vmax=vmax_e, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed electrostatic potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed electrostatic phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Magnetic Object - ax = fig.add_subplot(spec[0, 1]) - im = ax.imshow( - rotated_magnetic, - extent=extent, - cmap=cmap_m, - vmin=vmin_m, - vmax=vmax_m, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Reconstructed magnetic potential") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - errors = np.array(self.error_iterations) - ax = fig.add_subplot(spec[1, :]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - 
fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - raise NotImplementedError() - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - padding: int = 0, - **kwargs, - ): - """ - Displays reconstructed object and probe. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - - return self - - @property - def self_consistency_errors(self): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Re-initialize fractional positions and vector patches, max_batch_size = None - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Overlaps - _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) - 
fourier_overlap = xp.fft.fft2(overlap[0]) - - # Normalized mean-squared errors - error = xp.sum( - xp.abs(self._amplitudes[0] - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) - ) - error /= self._mean_diffraction_intensity - - return asnumpy(error) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[0][start:end] - - # Overlaps - _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap[0]) - - # Normalized mean-squared errors - batch_errors = xp.sum( - xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped[0]) - else: - projected_cropped_potential = self.object_cropped[0] - - return projected_cropped_potential - - @property - def object_cropped(self): - """Cropped and rotated object""" - - obj_e, obj_m = self._object - obj_e = self._crop_rotate_object_fov(obj_e) - 
obj_m = self._crop_rotate_object_fov(obj_m) - return (obj_e, obj_m) diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py deleted file mode 100644 index 36baac21e..000000000 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ /dev/null @@ -1,2226 +0,0 @@ -""" -Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, -namely (single-slice) ptychography. -""" - -import warnings -from typing import Mapping, Tuple - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable -from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg - -try: - import cupy as cp -except (ImportError, ModuleNotFoundError): - cp = np - -from emdfile import Custom, tqdmnd -from py4DSTEM.datacube import DataCube -from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction -from py4DSTEM.process.phase.utils import ( - ComplexProbe, - fft_shift, - generate_batches, - polar_aliases, - polar_symbols, -) -from py4DSTEM.process.utils import get_CoM, get_shifted_ar - -warnings.simplefilter(action="always", category=UserWarning) - - -class SingleslicePtychographicReconstruction(PtychographicReconstruction): - """ - Iterative Ptychographic Reconstruction Class. - - Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) - Reconstructed probe dimensions : (Sx,Sy) - Reconstructed object dimensions : (Px,Py) - - such that (Sx,Sy) is the region-of-interest (ROI) size of our probe - and (Px,Py) is the padded-object size we position our ROI around in. 
- - Parameters - ---------- - energy: float - The electron energy of the wave functions in eV - datacube: DataCube - Input 4D diffraction pattern intensities - semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess in mrad - semiangle_cutoff_pixels: float, optional - Semiangle cutoff for the initial probe guess in pixels - rolloff: float, optional - Semiangle rolloff for the initial probe guess - vacuum_probe_intensity: np.ndarray, optional - Vacuum probe to use as intensity aperture for initial probe guess - polar_parameters: dict, optional - Mapping from aberration symbols to their corresponding values. All aberration - magnitudes should be given in Å and angles should be given in radians. - object_padding_px: Tuple[int,int], optional - Pixel dimensions to pad object with - If None, the padding is set to half the probe ROI dimensions - initial_object_guess: np.ndarray, optional - Initial guess for complex-valued object of dimensions (Px,Py) - If None, initialized to 1.0j - initial_probe_guess: np.ndarray, optional - Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, - initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations - initial_scan_positions: np.ndarray, optional - Probe positions in Å for each diffraction intensity - If None, initialized to a grid scan - verbose: bool, optional - If True, class methods will inherit this and print additional information - device: str, optional - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - object_type: str, optional - The object can be reconstructed as a real potential ('potential') or a complex - object ('complex') - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction - name: str, optional - Class name - kwargs: - Provide the aberration coefficients as keyword arguments. 
- """ - - # Class-specific Metadata - _class_specific_metadata = () - - def __init__( - self, - energy: float, - datacube: DataCube = None, - semiangle_cutoff: float = None, - semiangle_cutoff_pixels: float = None, - rolloff: float = 2.0, - vacuum_probe_intensity: np.ndarray = None, - polar_parameters: Mapping[str, float] = None, - initial_object_guess: np.ndarray = None, - initial_probe_guess: np.ndarray = None, - initial_scan_positions: np.ndarray = None, - object_padding_px: Tuple[int, int] = None, - object_type: str = "complex", - positions_mask: np.ndarray = None, - verbose: bool = True, - device: str = "cpu", - name: str = "ptychographic_reconstruction", - **kwargs, - ): - Custom.__init__(self, name=name) - - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - for key in kwargs.keys(): - if (key not in polar_symbols) and (key not in polar_aliases.keys()): - raise ValueError("{} not a recognized parameter".format(key)) - - self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) - - if polar_parameters is None: - polar_parameters = {} - - polar_parameters.update(kwargs) - self._set_polar_parameters(polar_parameters) - - if object_type != "potential" and object_type != "complex": - raise ValueError( - f"object_type must be either 'potential' or 'complex', not {object_type}" - ) - - self.set_save_defaults() - - # Data - self._datacube = datacube - self._object = initial_object_guess - self._probe = initial_probe_guess - - # Common Metadata - self._vacuum_probe_intensity = 
vacuum_probe_intensity - self._scan_positions = initial_scan_positions - self._energy = energy - self._semiangle_cutoff = semiangle_cutoff - self._semiangle_cutoff_pixels = semiangle_cutoff_pixels - self._rolloff = rolloff - self._object_type = object_type - self._object_padding_px = object_padding_px - self._positions_mask = positions_mask - self._verbose = verbose - self._device = device - self._preprocessed = False - - # Class-specific Metadata - - def preprocess( - self, - diffraction_intensities_shape: Tuple[int, int] = None, - reshaping_method: str = "fourier", - probe_roi_shape: Tuple[int, int] = None, - dp_mask: np.ndarray = None, - fit_function: str = "plane", - plot_center_of_mass: str = "default", - plot_rotation: bool = True, - maximize_divergence: bool = False, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), - plot_probe_overlaps: bool = True, - force_com_rotation: float = None, - force_com_transpose: float = None, - force_com_shifts: float = None, - force_scan_sampling: float = None, - force_angular_sampling: float = None, - force_reciprocal_sampling: float = None, - object_fov_mask: np.ndarray = None, - crop_patterns: bool = False, - **kwargs, - ): - """ - Ptychographic preprocessing step. - Calls the base class methods: - - _extract_intensities_and_calibrations_from_datacube, - _compute_center_of_mass(), - _solve_CoM_rotation(), - _normalize_diffraction_intensities() - _calculate_scan_positions_in_px() - - Additionally, it initializes an (Px,Py) array of 1.0j - and a complex probe using the specified polar parameters. - - Parameters - ---------- - diffraction_intensities_shape: Tuple[int,int], optional - Pixel dimensions (Qx',Qy') of the resampled diffraction intensities - If None, no resampling of diffraction intenstities is performed - reshaping_method: str, optional - Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) - probe_roi_shape, (int,int), optional - Padded diffraction intensities shape. 
- If None, no padding is performed - dp_mask: ndarray, optional - Mask for datacube intensities (Qx,Qy) - fit_function: str, optional - 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' - plot_center_of_mass: str, optional - If 'default', the corrected CoM arrays will be displayed - If 'all', the computed and fitted CoM arrays will be displayed - plot_rotation: bool, optional - If True, the CoM curl minimization search result will be displayed - maximize_divergence: bool, optional - If True, the divergence of the CoM gradient vector field is maximized - rotation_angles_deg: np.darray, optional - Array of angles in degrees to perform curl minimization over - plot_probe_overlaps: bool, optional - If True, initial probe overlaps scanned over the object will be displayed - force_com_rotation: float (degrees), optional - Force relative rotation angle between real and reciprocal space - force_com_transpose: bool, optional - Force whether diffraction intensities need to be transposed. - force_com_shifts: tuple of ndarrays (CoMx, CoMy) - Amplitudes come from diffraction patterns shifted with - the CoM in the upper left corner for each probe unless - shift is overwritten. - force_scan_sampling: float, optional - Override DataCube real space scan pixel size calibrations, in Angstrom - force_angular_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in mrad - force_reciprocal_sampling: float, optional - Override DataCube reciprocal pixel size calibration, in A^-1 - object_fov_mask: np.ndarray (boolean) - Boolean mask of FOV. 
Used to calculate additional shrinkage of object - If None, probe_overlap intensity is thresholded - crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns when centering - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - xp = self._xp - asnumpy = self._asnumpy - - # set additional metadata - self._diffraction_intensities_shape = diffraction_intensities_shape - self._reshaping_method = reshaping_method - self._probe_roi_shape = probe_roi_shape - self._dp_mask = dp_mask - - if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run ptycho.attach_datacube(DataCube) first." - ) - ) - - if self._positions_mask is not None and self._positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - self._positions_mask = np.asarray(self._positions_mask, dtype="bool") - - ( - self._datacube, - self._vacuum_probe_intensity, - self._dp_mask, - force_com_shifts, - ) = self._preprocess_datacube_and_vacuum_probe( - self._datacube, - diffraction_intensities_shape=self._diffraction_intensities_shape, - reshaping_method=self._reshaping_method, - probe_roi_shape=self._probe_roi_shape, - vacuum_probe_intensity=self._vacuum_probe_intensity, - dp_mask=self._dp_mask, - com_shifts=force_com_shifts, - ) - - self._intensities = self._extract_intensities_and_calibrations_from_datacube( - self._datacube, - require_calibrations=True, - force_scan_sampling=force_scan_sampling, - force_angular_sampling=force_angular_sampling, - force_reciprocal_sampling=force_reciprocal_sampling, - ) - - ( - self._com_measured_x, - self._com_measured_y, - self._com_fitted_x, - self._com_fitted_y, - self._com_normalized_x, - self._com_normalized_y, - ) = self._calculate_intensities_center_of_mass( - self._intensities, - dp_mask=self._dp_mask, - fit_function=fit_function, - com_shifts=force_com_shifts, - ) - - ( - 
self._rotation_best_rad, - self._rotation_best_transpose, - self._com_x, - self._com_y, - self.com_x, - self.com_y, - ) = self._solve_for_center_of_mass_relative_rotation( - self._com_measured_x, - self._com_measured_y, - self._com_normalized_x, - self._com_normalized_y, - rotation_angles_deg=rotation_angles_deg, - plot_rotation=plot_rotation, - plot_center_of_mass=plot_center_of_mass, - maximize_divergence=maximize_divergence, - force_com_rotation=force_com_rotation, - force_com_transpose=force_com_transpose, - **kwargs, - ) - - ( - self._amplitudes, - self._mean_diffraction_intensity, - ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, - crop_patterns, - self._positions_mask, - ) - - # explicitly delete namespace - self._num_diffraction_patterns = self._amplitudes.shape[0] - self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) - del self._intensities - - self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask - ) - - # handle semiangle specified in pixels - if self._semiangle_cutoff_pixels: - self._semiangle_cutoff = ( - self._semiangle_cutoff_pixels * self._angular_sampling[0] - ) - - # Object Initialization - if self._object is None: - pad_x = self._object_padding_px[0][1] - pad_y = self._object_padding_px[1][1] - p, q = np.round(np.max(self._positions_px, axis=0)) - p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - "int" - ) - q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - "int" - ) - if self._object_type == "potential": - self._object = xp.zeros((p, q), dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.ones((p, q), dtype=xp.complex64) - else: - if self._object_type == "potential": - self._object = xp.asarray(self._object, dtype=xp.float32) - elif self._object_type == "complex": - self._object = xp.asarray(self._object, dtype=xp.complex64) - - 
self._object_initial = self._object.copy() - self._object_type_initial = self._object_type - self._object_shape = self._object.shape - - self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 - self._positions_px_com = xp.mean(self._positions_px, axis=0) - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - self._positions_px_initial = self._positions_px.copy() - self._positions_initial = self._positions_px_initial.copy() - self._positions_initial[:, 0] *= self.sampling[0] - self._positions_initial[:, 1] *= self.sampling[1] - - # Vectorized Patches - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Probe Initialization - if self._probe is None: - if self._vacuum_probe_intensity is not None: - self._semiangle_cutoff = np.inf - self._vacuum_probe_intensity = xp.asarray( - self._vacuum_probe_intensity, dtype=xp.float32 - ) - probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, - device=self._device, - ) - self._vacuum_probe_intensity = get_shifted_ar( - self._vacuum_probe_intensity, - -probe_x0, - -probe_y0, - bilinear=True, - device=self._device, - ) - if crop_patterns: - self._vacuum_probe_intensity = self._vacuum_probe_intensity[ - self._crop_mask - ].reshape(self._region_of_interest_shape) - - self._probe = ( - ComplexProbe( - gpts=self._region_of_interest_shape, - sampling=self.sampling, - energy=self._energy, - semiangle_cutoff=self._semiangle_cutoff, - rolloff=self._rolloff, - vacuum_probe_intensity=self._vacuum_probe_intensity, - parameters=self._polar_parameters, - device=self._device, - ) - .build() - ._array - ) - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= 
xp.sqrt(self._mean_diffraction_intensity / probe_intensity) - - else: - if isinstance(self._probe, ComplexProbe): - if self._probe._gpts != self._region_of_interest_shape: - raise ValueError() - if hasattr(self._probe, "_array"): - self._probe = self._probe._array - else: - self._probe._xp = xp - self._probe = self._probe.build()._array - - # Normalize probe to match mean diffraction intensity - probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2) - self._probe *= xp.sqrt( - self._mean_diffraction_intensity / probe_intensity - ) - else: - self._probe = xp.asarray(self._probe, dtype=xp.complex64) - - self._probe_initial = self._probe.copy() - self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - self._known_aberrations_array = ComplexProbe( - energy=self._energy, - gpts=self._region_of_interest_shape, - sampling=self.sampling, - parameters=self._polar_parameters, - device=self._device, - )._evaluate_ctf() - - # overlaps - shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) - probe_intensities = xp.abs(shifted_probes) ** 2 - probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) - probe_overlap = self._gaussian_filter(probe_overlap, 1.0) - - if object_fov_mask is None: - self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) - else: - self._object_fov_mask = np.asarray(object_fov_mask) - self._object_fov_mask_inverse = np.invert(self._object_fov_mask) - - if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (9, 4)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - # initial probe - complex_probe_rgb = Complex2RGB( - self.probe_centered, - power=2, - chroma_boost=chroma_boost, - ) - - extent = [ - 0, - self.sampling[1] * self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - fig, (ax1, ax2) = plt.subplots(1, 
2, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - ) - - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax1, chroma_boost=chroma_boost) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial probe intensity") - - ax2.imshow( - asnumpy(probe_overlap), - extent=extent, - cmap="gray", - ) - ax2.scatter( - self.positions[:, 1], - self.positions[:, 0], - s=2.5, - color=(1, 0, 0, 1), - ) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_xlim((extent[0], extent[1])) - ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object field of view") - - fig.tight_layout() - - self._preprocessed = True - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _overlap_projection(self, current_object, current_probe): - """ - Ptychographic overlap projection method. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - shifted_probes * object_patches - """ - - xp = self._xp - - shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp) - - if self._object_type == "potential": - complex_object = xp.exp(1j * current_object) - else: - complex_object = current_object - - object_patches = complex_object[ - self._vectorized_patch_indices_row, self._vectorized_patch_indices_col - ] - - overlap = shifted_probes * object_patches - - return shifted_probes, object_patches, overlap - - def _gradient_descent_fourier_projection(self, amplitudes, overlap): - """ - Ptychographic fourier projection method for GD method. 
- - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - - Returns - -------- - exit_waves:np.ndarray - Difference between modified and estimated exit waves - error: float - Reconstruction error - """ - - xp = self._xp - fourier_overlap = xp.fft.fft2(overlap) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) - - fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) - modified_overlap = xp.fft.ifft2(fourier_modified_overlap) - - exit_waves = modified_overlap - overlap - - return exit_waves, error - - def _projection_sets_fourier_projection( - self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c - ): - """ - Ptychographic fourier projection method for DM_AP and RAAR methods. - Generalized projection using three parameters: a,b,c - - DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha - DM: DM_AP(1.0), AP: DM_AP(0.0) - - RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 - DM : RAAR(1.0) - - RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 - DM: RRR(1.0) - - SUPERFLIP : a = 0, b = 1, c = 2 - - Parameters - -------- - amplitudes: np.ndarray - Normalized measured amplitudes - overlap: np.ndarray - object * probe overlap - exit_waves: np.ndarray - previously estimated exit waves - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - xp = self._xp - projection_x = 1 - projection_a - projection_b - projection_y = 1 - projection_c - - if exit_waves is None: - exit_waves = overlap.copy() - - fourier_overlap = xp.fft.fft2(overlap) - error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2) - - factor_to_be_projected = projection_c * overlap + projection_y * exit_waves - fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) - - fourier_projected_factor = amplitudes * xp.exp( - 1j * 
xp.angle(fourier_projected_factor) - ) - projected_factor = xp.fft.ifft2(fourier_projected_factor) - - exit_waves = ( - projection_x * exit_waves - + projection_a * overlap - + projection_b * projected_factor - ) - - return exit_waves, error - - def _forward( - self, - current_object, - current_probe, - amplitudes, - exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ): - """ - Ptychographic forward operator. - Calls _overlap_projection() and the appropriate _fourier_projection(). - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - amplitudes: np.ndarray - Normalized measured amplitudes - exit_waves: np.ndarray - previously estimated exit waves - use_projection_scheme: bool, - If True, use generalized projection update - projection_a: float - projection_b: float - projection_c: float - - Returns - -------- - shifted_probes:np.ndarray - fractionally-shifted probes - object_patches: np.ndarray - Patched object view - overlap: np.ndarray - object * probe overlap - exit_waves:np.ndarray - Updated exit_waves - error: float - Reconstruction error - """ - - shifted_probes, object_patches, overlap = self._overlap_projection( - current_object, current_probe - ) - if use_projection_scheme: - exit_waves, error = self._projection_sets_fourier_projection( - amplitudes, - overlap, - exit_waves, - projection_a, - projection_b, - projection_c, - ) - - else: - exit_waves, error = self._gradient_descent_fourier_projection( - amplitudes, overlap - ) - - return shifted_probes, object_patches, overlap, exit_waves, error - - def _gradient_descent_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint operator for GD method. - Computes object and probe update steps. 
- - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes) - * exit_waves - ) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object += step_size * ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe += step_size * ( - xp.sum( - xp.conj(object_patches) * exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _projection_sets_adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ): - """ - Ptychographic adjoint 
operator for DM_AP and RAAR methods. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - xp = self._xp - - probe_normalization = self._sum_overlapping_patches_bincounts( - xp.abs(shifted_probes) ** 2 - ) - probe_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * probe_normalization) ** 2 - + (normalization_min * xp.max(probe_normalization)) ** 2 - ) - - if self._object_type == "potential": - current_object = ( - self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(object_patches) - * xp.conj(shifted_probes) - * exit_waves - ) - ) - * probe_normalization - ) - elif self._object_type == "complex": - current_object = ( - self._sum_overlapping_patches_bincounts( - xp.conj(shifted_probes) * exit_waves - ) - * probe_normalization - ) - - if not fix_probe: - object_normalization = xp.sum( - (xp.abs(object_patches) ** 2), - axis=0, - ) - object_normalization = 1 / xp.sqrt( - 1e-16 - + ((1 - normalization_min) * object_normalization) ** 2 - + (normalization_min * xp.max(object_normalization)) ** 2 - ) - - current_probe = ( - xp.sum( - xp.conj(object_patches) * exit_waves, - axis=0, - ) - * object_normalization - ) - - return current_object, current_probe - - def _adjoint( - self, - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - use_projection_scheme: bool, - step_size: float, - normalization_min: float, - fix_probe: 
bool, - ): - """ - Ptychographic adjoint operator. - Computes object and probe update steps. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - object_patches: np.ndarray - Patched object view - shifted_probes:np.ndarray - fractionally-shifted probes - exit_waves:np.ndarray - Updated exit_waves - use_projection_scheme: bool, - If True, use generalized projection update - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - fix_probe: bool, optional - If True, probe will not be updated - - Returns - -------- - updated_object: np.ndarray - Updated object estimate - updated_probe: np.ndarray - Updated probe estimate - """ - - if use_projection_scheme: - current_object, current_probe = self._projection_sets_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - normalization_min, - fix_probe, - ) - else: - current_object, current_probe = self._gradient_descent_adjoint( - current_object, - current_probe, - object_patches, - shifted_probes, - exit_waves, - step_size, - normalization_min, - fix_probe, - ) - - return current_object, current_probe - - def _constraints( - self, - current_object, - current_probe, - current_positions, - pure_phase_object, - fix_com, - fit_probe_aberrations, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - constrain_probe_amplitude, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - fix_probe_aperture, - initial_probe_aperture, - fix_positions, - global_affine_transformation, - gaussian_filter, - gaussian_filter_sigma, - butterworth_filter, - q_lowpass, - q_highpass, - butterworth_order, - 
tv_denoise, - tv_denoise_weight, - tv_denoise_inner_iter, - object_positivity, - shrinkage_rad, - object_mask, - ): - """ - Ptychographic constraints operator. - - Parameters - -------- - current_object: np.ndarray - Current object estimate - current_probe: np.ndarray - Current probe estimate - current_positions: np.ndarray - Current positions estimate - pure_phase_object: bool - If True, object amplitude is set to unity - fix_com: bool - If True, probe CoM is fixed to the center - fit_probe_aberrations: bool - If True, fits the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - constrain_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude: bool - If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. - fix_probe_aperture: bool, - If True, probe Fourier amplitude is replaced by initial probe aperture. - initial_probe_aperture: np.ndarray, - Initial probe aperture to use in replacing probe Fourier amplitude. 
- fix_positions: bool - If True, positions are not updated - gaussian_filter: bool - If True, applies real-space gaussian filter - gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A - butterworth_filter: bool - If True, applies high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise: bool - If True, applies TV denoising on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. - tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool - If True, clips negative potential values - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - object_mask: np.ndarray (boolean) - If not None, used to calculate additional shrinkage using masked-mean of object - - Returns - -------- - constrained_object: np.ndarray - Constrained object estimate - constrained_probe: np.ndarray - Constrained probe estimate - constrained_positions: np.ndarray - Constrained positions estimate - """ - - if gaussian_filter: - current_object = self._object_gaussian_constraint( - current_object, gaussian_filter_sigma, pure_phase_object - ) - - if butterworth_filter: - current_object = self._object_butterworth_constraint( - current_object, - q_lowpass, - q_highpass, - butterworth_order, - ) - - if tv_denoise: - current_object = self._object_denoise_tv_pylops( - current_object, tv_denoise_weight, tv_denoise_inner_iter - ) - - if shrinkage_rad > 0.0 or object_mask is not None: - current_object = self._object_shrinkage_constraint( - current_object, - shrinkage_rad, - object_mask, - ) - - if self._object_type == "complex": - current_object = self._object_threshold_constraint( - current_object, pure_phase_object - ) 
- elif object_positivity: - current_object = self._object_positivity_constraint(current_object) - - if fix_com: - current_probe = self._probe_center_of_mass_constraint(current_probe) - - if fix_probe_aperture: - current_probe = self._probe_aperture_constraint( - current_probe, - initial_probe_aperture, - ) - elif constrain_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( - current_probe, - constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity, - ) - - if fit_probe_aberrations: - current_probe = self._probe_aberration_fitting_constraint( - current_probe, - fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order, - ) - - if constrain_probe_amplitude: - current_probe = self._probe_amplitude_constraint( - current_probe, - constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width, - ) - - if not fix_positions: - current_positions = self._positions_center_of_mass_constraint( - current_positions - ) - - if global_affine_transformation: - current_positions = self._positions_affine_transformation_constraint( - self._positions_px_initial, current_positions - ) - - return current_object, current_probe, current_positions - - def reconstruct( - self, - max_iter: int = 64, - reconstruction_method: str = "gradient-descent", - reconstruction_parameter: float = 1.0, - reconstruction_parameter_a: float = None, - reconstruction_parameter_b: float = None, - reconstruction_parameter_c: float = None, - max_batch_size: int = None, - seed_random: int = None, - step_size: float = 0.5, - normalization_min: float = 1, - positions_step_size: float = 0.9, - pure_phase_object_iter: int = 0, - fix_com: bool = True, - fix_probe_iter: int = 0, - fix_probe_aperture_iter: int = 0, - constrain_probe_amplitude_iter: int = 0, - constrain_probe_amplitude_relative_radius: float = 0.5, - constrain_probe_amplitude_relative_width: float = 0.05, - 
constrain_probe_fourier_amplitude_iter: int = 0, - constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, - constrain_probe_fourier_amplitude_constant_intensity: bool = False, - fix_positions_iter: int = np.inf, - constrain_position_distance: float = None, - global_affine_transformation: bool = True, - gaussian_filter_sigma: float = None, - gaussian_filter_iter: int = np.inf, - fit_probe_aberrations_iter: int = 0, - fit_probe_aberrations_max_angular_order: int = 4, - fit_probe_aberrations_max_radial_order: int = 4, - butterworth_filter_iter: int = np.inf, - q_lowpass: float = None, - q_highpass: float = None, - butterworth_order: float = 2, - tv_denoise_iter: int = np.inf, - tv_denoise_weight: float = None, - tv_denoise_inner_iter: float = 40, - object_positivity: bool = True, - shrinkage_rad: float = 0.0, - fix_potential_baseline: bool = True, - switch_object_iter: int = np.inf, - store_iterations: bool = False, - progress_bar: bool = True, - reset: bool = None, - ): - """ - Ptychographic reconstruction main method. - - Parameters - -------- - max_iter: int, optional - Maximum number of iterations to run - reconstruction_method: str, optional - Specifies which reconstruction algorithm to use, one of: - "generalized-projections", - "DM_AP" (or "difference-map_alternating-projections"), - "RAAR" (or "relaxed-averaged-alternating-reflections"), - "RRR" (or "relax-reflect-reflect"), - "SUPERFLIP" (or "charge-flipping"), or - "GD" (or "gradient_descent") - reconstruction_parameter: float, optional - Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter_a: float, optional - Reconstruction parameter a for reconstruction_method='generalized-projections'. - reconstruction_parameter_b: float, optional - Reconstruction parameter b for reconstruction_method='generalized-projections'. - reconstruction_parameter_c: float, optional - Reconstruction parameter c for reconstruction_method='generalized-projections'. 
- max_batch_size: int, optional - Max number of probes to update at once - seed_random: int, optional - Seeds the random number generator, only applicable when max_batch_size is not None - step_size: float, optional - Update step size - normalization_min: float, optional - Probe normalization minimum as a fraction of the maximum overlap intensity - positions_step_size: float, optional - Positions update step size - pure_phase_object_iter: int, optional - Number of iterations where object amplitude is set to unity - fix_com: bool, optional - If True, fixes center of mass of probe - fix_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_aperture_iter: int, optional - Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate - constrain_probe_amplitude_iter: int, optional - Number of iterations to run while constraining the real-space probe with a top-hat support. - constrain_probe_amplitude_relative_radius: float - Relative location of top-hat inflection point, between 0 and 0.5 - constrain_probe_amplitude_relative_width: float - Relative width of top-hat sigmoid, between 0 and 0.5 - constrain_probe_fourier_amplitude_iter: int, optional - Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. - constrain_probe_fourier_amplitude_max_width_pixels: float - Maximum pixel width of fitted sigmoid functions. - constrain_probe_fourier_amplitude_constant_intensity: bool - If True, the probe aperture is additionally constrained to a constant intensity. 
- fix_positions_iter: int, optional - Number of iterations to run with fixed positions before updating positions estimate - constrain_position_distance: float, optional - Distance to constrain position correction within original - field of view in A - global_affine_transformation: bool, optional - If True, positions are assumed to be a global affine transform from initial scan - gaussian_filter_sigma: float, optional - Standard deviation of gaussian kernel in A - gaussian_filter_iter: int, optional - Number of iterations to run using object smoothness constraint - fit_probe_aberrations_iter: int, optional - Number of iterations to run while fitting the probe aberrations to a low-order expansion - fit_probe_aberrations_max_angular_order: bool - Max angular order of probe aberrations basis functions - fit_probe_aberrations_max_radial_order: bool - Max radial order of probe aberrations basis functions - butterworth_filter_iter: int, optional - Number of iterations to run using high-pass butteworth filter - q_lowpass: float - Cut-off frequency in A^-1 for low-pass butterworth filter - q_highpass: float - Cut-off frequency in A^-1 for high-pass butterworth filter - butterworth_order: float - Butterworth filter order. Smaller gives a smoother filter - tv_denoise_iter: int, optional - Number of iterations to run using tv denoise filter on object - tv_denoise_weight: float - Denoising weight. The greater `weight`, the more denoising. 
- tv_denoise_inner_iter: float - Number of iterations to run in inner loop of TV denoising - object_positivity: bool, optional - If True, forces object to be positive - shrinkage_rad: float - Phase shift in radians to be subtracted from the potential at each iteration - fix_potential_baseline: bool - If true, the potential mean outside the FOV is forced to zero at each iteration - switch_object_iter: int, optional - Iteration to switch object type between 'complex' and 'potential' or between - 'potential' and 'complex' - store_iterations: bool, optional - If True, reconstructed objects and probes are stored at each iteration - progress_bar: bool, optional - If True, reconstruction progress is displayed - reset: bool, optional - If True, previous reconstructions are ignored - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - asnumpy = self._asnumpy - xp = self._xp - - # Reconstruction method - - if reconstruction_method == "generalized-projections": - if ( - reconstruction_parameter_a is None - or reconstruction_parameter_b is None - or reconstruction_parameter_c is None - ): - raise ValueError( - ( - "reconstruction_parameter_a/b/c must all be specified " - "when using reconstruction_method='generalized-projections'." 
- ) - ) - - use_projection_scheme = True - projection_a = reconstruction_parameter_a - projection_b = reconstruction_parameter_b - projection_c = reconstruction_parameter_c - step_size = None - elif ( - reconstruction_method == "DM_AP" - or reconstruction_method == "difference-map_alternating-projections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = 1 - projection_c = 1 + reconstruction_parameter - step_size = None - elif ( - reconstruction_method == "RAAR" - or reconstruction_method == "relaxed-averaged-alternating-reflections" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: - raise ValueError("reconstruction_parameter must be between 0-1.") - - use_projection_scheme = True - projection_a = 1 - 2 * reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "RRR" - or reconstruction_method == "relax-reflect-reflect" - ): - if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: - raise ValueError("reconstruction_parameter must be between 0-2.") - - use_projection_scheme = True - projection_a = -reconstruction_parameter - projection_b = reconstruction_parameter - projection_c = 2 - step_size = None - elif ( - reconstruction_method == "SUPERFLIP" - or reconstruction_method == "charge-flipping" - ): - use_projection_scheme = True - projection_a = 0 - projection_b = 1 - projection_c = 2 - reconstruction_parameter = None - step_size = None - elif ( - reconstruction_method == "GD" or reconstruction_method == "gradient-descent" - ): - use_projection_scheme = False - projection_a = None - projection_b = None - projection_c = None - reconstruction_parameter = None - else: - raise ValueError( - ( - "reconstruction_method must be one of 'generalized-projections', " - 
"'DM_AP' (or 'difference-map_alternating-projections'), " - "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " - "'RRR' (or 'relax-reflect-reflect'), " - "'SUPERFLIP' (or 'charge-flipping'), " - f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." - ) - ) - - if self._verbose: - if switch_object_iter > max_iter: - first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - else: - switch_object_type = ( - "complex" if self._object_type == "potential" else "potential" - ) - first_line = ( - f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " - f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " - ) - if max_batch_size is not None: - if use_projection_scheme: - raise ValueError( - ( - "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " - "Use reconstruction_method='GD' or set max_batch_size=None." - ) - ) - else: - print( - ( - first_line + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}, " - f"in batches of max {max_batch_size} measurements." - ) - ) - - else: - if reconstruction_parameter is not None: - if np.array(reconstruction_parameter).shape == (3,): - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." - ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." - ) - ) - else: - if step_size is not None: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min}." 
- ) - ) - else: - print( - ( - first_line - + f"with the {reconstruction_method} algorithm, " - f"with normalization_min: {normalization_min} and step _size: {step_size}." - ) - ) - - # Batching - shuffled_indices = np.arange(self._num_diffraction_patterns) - unshuffled_indices = np.zeros_like(shuffled_indices) - - if max_batch_size is not None: - xp.random.seed(seed_random) - else: - max_batch_size = self._num_diffraction_patterns - - # initialization - if store_iterations and (not hasattr(self, "object_iterations") or reset): - self.object_iterations = [] - self.probe_iterations = [] - - if reset: - self.error_iterations = [] - self._object = self._object_initial.copy() - self._probe = self._probe_initial.copy() - self._positions_px = self._positions_px_initial.copy() - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - self._exit_waves = None - self._object_type = self._object_type_initial - if hasattr(self, "_tf"): - del self._tf - elif reset is None: - if hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." 
- ), - UserWarning, - ) - else: - self.error_iterations = [] - self._exit_waves = None - - # main loop - for a0 in tqdmnd( - max_iter, - desc="Reconstructing object and probe", - unit=" iter", - disable=not progress_bar, - ): - error = 0.0 - - if a0 == switch_object_iter: - if self._object_type == "potential": - self._object_type = "complex" - self._object = xp.exp(1j * self._object) - elif self._object_type == "complex": - self._object_type = "potential" - self._object = xp.angle(self._object) - - # randomize - if not use_projection_scheme: - np.random.shuffle(shuffled_indices) - unshuffled_indices[shuffled_indices] = np.arange( - self._num_diffraction_patterns - ) - positions_px = self._positions_px.copy()[shuffled_indices] - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[shuffled_indices[start:end]] - - # forward operator - ( - shifted_probes, - object_patches, - overlap, - self._exit_waves, - batch_error, - ) = self._forward( - self._object, - self._probe, - amplitudes, - self._exit_waves, - use_projection_scheme, - projection_a, - projection_b, - projection_c, - ) - - # adjoint operator - self._object, self._probe = self._adjoint( - self._object, - self._probe, - object_patches, - shifted_probes, - self._exit_waves, - use_projection_scheme=use_projection_scheme, - step_size=step_size, - normalization_min=normalization_min, - fix_probe=a0 < fix_probe_iter, - ) - - # position correction - if a0 >= fix_positions_iter: - positions_px[start:end] = self._position_correction( - self._object, - shifted_probes, - overlap, - amplitudes, - self._positions_px, - positions_step_size, - 
constrain_position_distance, - ) - - error += batch_error - - # Normalize Error - error /= self._mean_diffraction_intensity * self._num_diffraction_patterns - - # constraints - self._positions_px = positions_px.copy()[unshuffled_indices] - self._object, self._probe, self._positions_px = self._constraints( - self._object, - self._probe, - self._positions_px, - fix_com=fix_com and a0 >= fix_probe_iter, - constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, - constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, - constrain_probe_fourier_amplitude=a0 - < constrain_probe_fourier_amplitude_iter - and a0 >= fix_probe_iter, - constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, - constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, - fit_probe_aberrations=a0 < fit_probe_aberrations_iter - and a0 >= fix_probe_iter, - fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, - fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, - fix_probe_aperture=a0 < fix_probe_aperture_iter, - initial_probe_aperture=self._probe_initial_aperture, - fix_positions=a0 < fix_positions_iter, - global_affine_transformation=global_affine_transformation, - gaussian_filter=a0 < gaussian_filter_iter - and gaussian_filter_sigma is not None, - gaussian_filter_sigma=gaussian_filter_sigma, - butterworth_filter=a0 < butterworth_filter_iter - and (q_lowpass is not None or q_highpass is not None), - q_lowpass=q_lowpass, - q_highpass=q_highpass, - butterworth_order=butterworth_order, - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_inner_iter=tv_denoise_inner_iter, - object_positivity=object_positivity, - shrinkage_rad=shrinkage_rad, 
- object_mask=self._object_fov_mask_inverse - if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 - else None, - pure_phase_object=a0 < pure_phase_object_iter - and self._object_type == "complex", - ) - - self.error_iterations.append(error.item()) - if store_iterations: - self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(self.probe_centered) - - # store result - self.object = asnumpy(self._object) - self.probe = self.probe_centered - self.error = error.item() - - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - return self - - def _visualize_last_iteration_figax( - self, - fig, - object_ax, - convergence_ax: None, - cbar: bool, - padding: int = 0, - **kwargs, - ): - """ - Displays last reconstructed object on a given fig/ax. - - Parameters - -------- - fig: Figure - Matplotlib figure object_ax lives in - object_ax: Axes - Matplotlib axes to plot reconstructed object in - convergence_ax: Axes, optional - Matplotlib axes to plot convergence plot in - cbar: bool, optional - If true, displays a colorbar - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - cmap = kwargs.pop("cmap", "magma") - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - im = object_ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - if cbar: - divider = make_axes_locatable(object_ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if convergence_ax is not None and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = self.error_iterations - 
convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) - - def _visualize_last_iteration( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - padding: int, - **kwargs, - ): - """ - Displays last reconstructed object and probe iterations. - - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool, optional - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - figsize = kwargs.pop("figsize", (8, 5)) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - if self._object_type == "complex": - obj = np.angle(self.object) - else: - obj = self.object - - rotated_object = self._crop_rotate_object_fov(obj, padding=padding) - rotated_shape = rotated_object.shape - - extent = [ - 0, - self.sampling[1] * rotated_shape[1], - self.sampling[0] * rotated_shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = 
GridSpec( - ncols=2, - nrows=2, - height_ratios=[4, 1], - hspace=0.15, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec( - ncols=2, - nrows=1, - width_ratios=[ - (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), - 1, - ], - wspace=0.35, - ) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - if plot_probe or plot_fourier_probe: - # Object - ax = fig.add_subplot(spec[0, 0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - # Probe - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - - ax = fig.add_subplot(spec[0, 1]) - if plot_fourier_probe: - if remove_initial_probe_aberrations: - probe_array = self.probe_fourier_residual - else: - probe_array = self.probe_fourier - - probe_array = Complex2RGB( - probe_array, - chroma_boost=chroma_boost, - ) - - ax.set_title("Reconstructed Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - else: - probe_array = Complex2RGB( - self.probe, - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title("Reconstructed probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) - - else: - ax = 
fig.add_subplot(spec[0]) - im = ax.imshow( - rotated_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if self._object_type == "potential": - ax.set_title("Reconstructed object potential") - elif self._object_type == "complex": - ax.set_title("Reconstructed object phase") - - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) - - if plot_convergence and hasattr(self, "error_iterations"): - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - errors = np.array(self.error_iterations) - if plot_probe: - ax = fig.add_subplot(spec[1, :]) - else: - ax = fig.add_subplot(spec[1]) - ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration number") - ax.yaxis.tick_right() - - fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") - spec.tight_layout(fig) - - def _visualize_all_iterations( - self, - fig, - cbar: bool, - plot_convergence: bool, - plot_probe: bool, - plot_fourier_probe: bool, - remove_initial_probe_aberrations: bool, - iterations_grid: Tuple[int, int], - padding: int, - **kwargs, - ): - """ - Displays all reconstructed object and probe iterations. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed complex probe is displayed - plot_fourier_probe: bool - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - """ - asnumpy = self._asnumpy - - if not hasattr(self, "object_iterations"): - raise ValueError( - ( - "Object and probe iterations were not saved during reconstruction. " - "Please re-run using store_iterations=True." - ) - ) - - if iterations_grid == "auto": - num_iter = len(self.error_iterations) - - if num_iter == 1: - return self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - padding=padding, - **kwargs, - ) - elif plot_probe or plot_fourier_probe: - iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) - else: - iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) - else: - if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: - raise ValueError() - - auto_figsize = ( - (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) - if plot_convergence - else (3 * iterations_grid[1], 3 * iterations_grid[0]) - ) - figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "magma") - - chroma_boost = kwargs.pop("chroma_boost", 1) - - errors = np.array(self.error_iterations) - - objects = [] - object_type = [] - - for obj in self.object_iterations: - if np.iscomplexobj(obj): - obj = np.angle(obj) - object_type.append("phase") - 
else: - object_type.append("potential") - objects.append(self._crop_rotate_object_fov(obj, padding=padding)) - - if plot_probe or plot_fourier_probe: - total_grids = (np.prod(iterations_grid) / 2).astype("int") - probes = self.probe_iterations - else: - total_grids = np.prod(iterations_grid) - max_iter = len(objects) - 1 - grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) - - extent = [ - 0, - self.sampling[1] * objects[0].shape[1], - self.sampling[0] * objects[0].shape[0], - 0, - ] - - if plot_fourier_probe: - probe_extent = [ - -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, - self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, - ] - elif plot_probe: - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] - - if plot_convergence: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) - else: - spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) - else: - if plot_probe or plot_fourier_probe: - spec = GridSpec(ncols=1, nrows=2) - else: - spec = GridSpec(ncols=1, nrows=1) - - if fig is None: - fig = plt.figure(figsize=figsize) - - grid = ImageGrid( - fig, - spec[0], - nrows_ncols=(1, iterations_grid[1]) - if (plot_probe or plot_fourier_probe) - else iterations_grid, - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - im = ax.imshow( - objects[grid_range[n]], - extent=extent, - cmap=cmap, - **kwargs, - ) - ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: - grid.cbar_axes[n].colorbar(im) - - if plot_probe or plot_fourier_probe: - kwargs.pop("vmin", 
None) - kwargs.pop("vmax", None) - - grid = ImageGrid( - fig, - spec[1], - nrows_ncols=(1, iterations_grid[1]), - axes_pad=(0.75, 0.5) if cbar else 0.5, - cbar_mode="each" if cbar else None, - cbar_pad="2.5%" if cbar else None, - ) - - for n, ax in enumerate(grid): - if plot_fourier_probe: - probe_array = asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]], - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - ) - - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) - ax.set_title(f"Iter: {grid_range[n]} Fourier probe") - ax.set_ylabel("kx [mrad]") - ax.set_xlabel("ky [mrad]") - - else: - probe_array = Complex2RGB( - probes[grid_range[n]], - power=2, - chroma_boost=chroma_boost, - ) - ax.set_title(f"Iter: {grid_range[n]} probe intensity") - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - - im = ax.imshow( - probe_array, - extent=probe_extent, - ) - - if cbar: - add_colorbar_arg( - grid.cbar_axes[n], - chroma_boost=chroma_boost, - ) - - if plot_convergence: - kwargs.pop("vmin", None) - kwargs.pop("vmax", None) - if plot_probe: - ax2 = fig.add_subplot(spec[2]) - else: - ax2 = fig.add_subplot(spec[1]) - ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) - ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration number") - ax2.yaxis.tick_right() - - spec.tight_layout(fig) - - def visualize( - self, - fig=None, - iterations_grid: Tuple[int, int] = None, - plot_convergence: bool = True, - plot_probe: bool = True, - plot_fourier_probe: bool = False, - remove_initial_probe_aberrations: bool = False, - cbar: bool = True, - padding: int = 0, - **kwargs, - ): - """ - Displays reconstructed object and probe. 
- - Parameters - -------- - fig: Figure - Matplotlib figure to place Gridspec in - plot_convergence: bool, optional - If true, the normalized mean squared error (NMSE) plot is displayed - iterations_grid: Tuple[int,int] - Grid dimensions to plot reconstruction iterations - cbar: bool, optional - If true, displays a colorbar - plot_probe: bool - If true, the reconstructed probe intensity is also displayed - plot_fourier_probe: bool, optional - If true, the reconstructed complex Fourier probe is displayed - remove_initial_probe_aberrations: bool, optional - If true, when plotting fourier probe, removes initial probe - to visualize changes - padding : int, optional - Pixels to pad by post rotating-cropping object - - Returns - -------- - self: PtychographicReconstruction - Self to accommodate chaining - """ - - if iterations_grid is None: - self._visualize_last_iteration( - fig=fig, - plot_convergence=plot_convergence, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - cbar=cbar, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - padding=padding, - **kwargs, - ) - else: - self._visualize_all_iterations( - fig=fig, - plot_convergence=plot_convergence, - iterations_grid=iterations_grid, - plot_probe=plot_probe, - plot_fourier_probe=plot_fourier_probe, - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - cbar=cbar, - padding=padding, - **kwargs, - ) - - return self diff --git a/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py new file mode 100644 index 000000000..c9efae806 --- /dev/null +++ b/py4DSTEM/process/phase/magnetic_ptychographic_tomography.py @@ -0,0 +1,1619 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely magnetic ptychographic tomography. 
+""" + +import warnings +from typing import Mapping, Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + Object2p5DConstraintsMixin, + Object3DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, + Object2p5DMethodsMixin, + Object2p5DProbeMethodsMixin, + Object3DMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, + project_vector_field_divergence_periodic_3D, +) + + +class MagneticPtychographicTomography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + Object3DConstraintsMixin, + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, + Object2p5DProbeMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + Object3DMethodsMixin, + Object2p5DMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Magnetic Ptychographic Tomography Reconstruction Class. 
+ + List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed object dimensions : (Px,Py,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py,Py) is the padded-object electrostatic potential volume, + where x-axis is the tilt. + + Parameters + ---------- + datacube: List of DataCubes + Input list of 4D diffraction pattern intensities for different tilts + energy: float + The electron energy of the wave functions in eV + num_slices: int + Number of super-slices to use in the forward model + tilt_orientation_matrices: Sequence[np.ndarray] + List of orientation matrices for each tilt + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py,Py) + If None, initialized to 1.0 + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: list of np.ndarray, optional + Probe positions in Å for each diffraction intensity per tilt + If None, initialized to a grid scan centered along tilt axis + positions_offset_ang: list of np.ndarray, optional + Offset of positions in A + verbose: bool, optional + If True, class methods will inherit this and print additional information + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + name: str, optional + Class name + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + kwargs: + Provide the aberration coefficients as keyword arguments. 
+ """ + + # Class-specific Metadata + _class_specific_metadata = ( + "_num_slices", + "_tilt_orientation_matrices", + "_num_measurements", + ) + + def __init__( + self, + energy: float, + num_slices: int, + tilt_orientation_matrices: Sequence[np.ndarray], + datacube: Sequence[DataCube] = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + object_type: str = "potential", + positions_mask: np.ndarray = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: Sequence[np.ndarray] = None, + positions_offset_ang: Sequence[np.ndarray] = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "magnetic-ptychographic-tomography_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + num_tilts = len(tilt_orientation_matrices) + if initial_scan_positions is None: + initial_scan_positions = [None] * num_tilts + + if object_type != "potential": + raise NotImplementedError() + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe_init = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + 
self._positions_offset_ang = positions_offset_ang + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._num_slices = num_slices + self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) + self._num_measurements = num_tilts + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "bilinear", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_probe_overlaps: bool = True, + rotation_real_space_degrees: float = None, + diffraction_patterns_rotate_degrees: float = None, + diffraction_patterns_transpose: bool = None, + force_com_shifts: Sequence[float] = None, + force_com_measured: Sequence[np.ndarray] = None, + vectorized_com_calculation: bool = True, + progress_bar: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = True, + crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + + Additionally, it initializes an (Px,Py, Py) array of 1.0 + and a complex probe using the specified polar parameters. 
+ + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + rotation_real_space_degrees: float (degrees), optional + In plane rotation around z axis between x axis and tilt axis in + real space (forced to be in xy plane) + diffraction_patterns_rotate_degrees: float, optional + Relative rotation angle between real and reciprocal space + diffraction_patterns_transpose: bool, optional + Whether diffraction intensities need to be transposed. + force_com_shifts: list of tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. One tuple per tilt. 
+ force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + Force CoM measured shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: OverlapTomographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." 
+ ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_measurements, 1, 1) + ) + + num_probes_per_measurement = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) + + else: + self._positions_mask = [None] * self._num_measurements + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) + + # prepopulate relevant arrays + self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + roi_shape = diffraction_intensities_shape + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) + + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) + + self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # TO-DO: generalize this + if force_com_shifts is None: + force_com_shifts = [None] * self._num_measurements + + if force_com_measured is None: + force_com_measured = [None] * self._num_measurements + + if self._positions_offset_ang is None: + self._positions_offset_ang = [None] * self._num_measurements + + 
self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) + self._rotation_best_transpose = diffraction_patterns_transpose + + if progress_bar: + # turn off verbosity to play nice with tqdm + verbose = self._verbose + self._verbose = False + + # loop over DPs for preprocessing + for index in tqdmnd( + self._num_measurements, + desc="Preprocessing data", + unit="tilt", + disable=not progress_bar, + ): + # preprocess datacube, vacuum and masks only for first tilt + if index == 0: + ( + self._datacube[index], + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts[index], + com_measured=force_com_measured[index], + ) + + else: + ( + self._datacube[index], + _, + _, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=None, + dp_mask=None, + com_shifts=force_com_shifts[index], + com_measured=force_com_measured[index], + ) + + # calibrations + intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube[index], + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # calculate CoM + ( + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + 
com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, + com_measured=force_com_measured[index], + ) + + # corner-center amplitudes + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + ( + amplitudes, + mean_diffraction_intensity_temp, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + intensities, + com_fitted_x, + com_fitted_y, + self._positions_mask[index], + crop_patterns, + ) + + self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + + del ( + intensities, + amplitudes, + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) + + # initialize probe positions + ( + self._positions_px_all[idx_start:idx_end], + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions[index], + self._positions_mask[index], + self._object_padding_px, + self._positions_offset_ang[index], + ) + + if progress_bar: + # reset verbosity + self._verbose = verbose + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # initialize object + obj = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + main_tilt_axis=None, + ) + + if self._object is None: + self._object = xp.full((4,) + obj.shape, obj) + else: + self._object = obj + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + self._num_voxels = self._object.shape[1] + + # center probe positions + self._positions_px_all 
= xp_storage.asarray( + self._positions_px_all, dtype=xp_storage.float32 + ) + + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 + self._positions_px_all[idx_start:idx_end] = positions_px.copy() + + self._positions_px_initial_all = self._positions_px_all.copy() + self._positions_initial_all = self._positions_px_initial_all.copy() + self._positions_initial_all[:, 0] *= self.sampling[0] + self._positions_initial_all[:, 1] *= self.sampling[1] + + self._positions_initial = self._return_average_positions() + if self._positions_initial is not None: + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probes_all = [] + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + list_Q = isinstance(self._probe_init, (list, tuple)) + + for index in range(self._num_measurements): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[index], + self._semiangle_cutoff, + crop_patterns, + ) + + self._probes_all.append(_probe) + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + + del self._probe_init + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + # Precomputed propagator arrays + thickness_h = self._object_shape[1] * self.sampling[1] + thickness_v = self._object_shape[0] * self.sampling[0] + thickness = 
max(thickness_h, thickness_v) + + self._slice_thicknesses = np.tile( + thickness / self._num_slices, self._num_slices - 1 + ) + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + ) + + if object_fov_mask is not True: + raise NotImplementedError() + else: + self._object_fov_mask = np.full(self._object_shape, True) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + # plot probe overlaps + if plot_probe_overlaps: + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift( + self._probes_all[0], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + probe_overlap = asnumpy(probe_overlap) + + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=power, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probes_all[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], 
+ self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, + chroma_boost=chroma_boost, + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax2, + chroma_boost=chroma_boost, + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe intensity") + + ax3.imshow( + probe_overlap, + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[0, :, 1], + self.positions[0, :, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _object_constraints_vector( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma_e, + gaussian_filter_sigma_m, + butterworth_filter, + butterworth_order, + q_lowpass_e, + q_lowpass_m, + q_highpass_e, + q_highpass_m, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """Calls Object3DConstraints _object_constraints for each object.""" + xp = self._xp + + # electrostatic + current_object[0] = self._object_constraints( + current_object[0], + gaussian_filter, + gaussian_filter_sigma_e, + butterworth_filter, + butterworth_order, + q_lowpass_e, + q_highpass_e, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + 
object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ) + + # magnetic + for index in range(1, 4): + current_object[index] = self._object_constraints( + current_object[index], + gaussian_filter, + gaussian_filter_sigma_m, + butterworth_filter, + butterworth_order, + q_lowpass_m, + q_highpass_m, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + False, + 0.0, + None, + **kwargs, + ) + + # divergence-free + current_object[1:] = project_vector_field_divergence_periodic_3D( + current_object[1:], xp=xp + ) + + return current_object + + def _constraints(self, current_object, current_probe, current_positions, **kwargs): + """Wrapper function to bypass _object_constraints""" + + current_object = self._object_constraints_vector(current_object, **kwargs) + current_probe = self._probe_constraints(current_probe, **kwargs) + current_positions = self._positions_constraints(current_positions, **kwargs) + + return current_object, current_probe, current_positions + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = 
None, + global_affine_transformation: bool = False, + gaussian_filter_sigma_e: float = None, + gaussian_filter_sigma_m: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass_e: float = None, + q_lowpass_m: float = None, + q_highpass_e: float = None, + q_highpass_m: float = None, + butterworth_order: float = 2, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + detector_fourier_mask: np.ndarray = None, + tv_denoise: bool = True, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + collective_measurement_updates: bool = True, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. 
+ reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. 
+ fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma_e: float + Standard deviation of gaussian kernel for electrostatic object in A + gaussian_filter_sigma_m: float + Standard deviation of gaussian kernel for magnetic object in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. 
Smaller gives a smoother filter + object_positivity: bool, optional + If True, forces object to be positive + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_measurement_updates: bool + if True perform collective tilt updates + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + detector_fourier_mask: np.ndarray + Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels). + Useful when detector has artifacts such as dead-pixels. Usually binary. + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + + Returns + -------- + self: OverlapMagneticTomographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probes_all", + "_probes_all_initial", + "_probes_all_initial_aperture", + "_propagator_arrays", + ] + self.copy_attributes_to_device(attrs, device) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + if not collective_measurement_updates and self._verbose: + warnings.warn( + "Magnetic ptychography is much more robust with `collective_measurement_updates=True`.", + UserWarning, + ) + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if use_projection_scheme: + raise NotImplementedError( + "Magnetic ptychographic tomography is currently only implemented for gradient descent." 
+ ) + + # initialization + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + if detector_fourier_mask is None: + detector_fourier_mask = xp.ones(self._amplitudes[0].shape) + else: + detector_fourier_mask = xp.asarray(detector_fourier_mask) + + if gaussian_filter_sigma_m is None: + gaussian_filter_sigma_m = gaussian_filter_sigma_e + + if q_lowpass_m is None: + q_lowpass_m = q_lowpass_e + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if collective_measurement_updates: + collective_object = xp.zeros_like(self._object) + + indices = np.arange(self._num_measurements) + np.random.shuffle(indices) + + old_rot_matrix = np.eye(3) # identity + + for index in indices: + self._active_measurement_index = index + + measurement_error = 0.0 + + rot_matrix = self._tilt_orientation_matrices[ + self._active_measurement_index + ] + self._object = self._rotate_zxy_volume_util( + self._object, + rot_matrix @ old_rot_matrix.T, + ) + object_V = self._object[0] + + # last transformation matrix row + weight_x, weight_y, weight_z = rot_matrix[-1] + object_A = ( + weight_x * self._object[2] + + weight_y * self._object[3] + + weight_z * self._object[1] + ) + + object_sliced = self._project_sliced_object( + object_V + object_A, self._num_slices + ) + + _probe = self._probes_all[self._active_measurement_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_measurement_index + ] + + if not use_projection_scheme: + object_sliced_old = object_sliced.copy() + + 
start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] + + num_diffraction_patterns = end_idx - start_idx + shuffled_indices = np.arange(start_idx, end_idx) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + _probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + detector_fourier_mask, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + object_sliced, _probe = self._adjoint( + object_sliced, + _probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions and a0 > 0: + self._positions_px_all[batch_indices] = ( + self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + 
max_position_total_distance, + ) + ) + + measurement_error += batch_error + + if not use_projection_scheme: + object_sliced -= object_sliced_old + + object_update = self._expand_sliced_object( + object_sliced, self._num_voxels + ) + + weights = (1, weight_z, weight_x, weight_y) + for index, weight in zip(range(4), weights): + if collective_measurement_updates: + collective_object[index] += self._rotate_zxy_volume( + object_update * weight, + rot_matrix.T, + ) + else: + self._object[index] += object_update * weight + + old_rot_matrix = rot_matrix + + # Normalize Error + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] + * num_diffraction_patterns + ) + error += measurement_error + + # constraints + + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + ) + + 
self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions + ( + self._object, + _probe, + self._positions_px_all[batch_indices], + ) = self._constraints( + self._object, + _probe, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + 
gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=( + self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None + ), + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self._object = self._rotate_zxy_volume_util(self._object, old_rot_matrix.T) + + # Normalize Error Over Tilts + error /= self._num_measurements + + if collective_measurement_updates: + self._object += collective_object / self._num_measurements + + # object only + self._object = self._object_constraints_vector( + self._object, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=( + self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None + ), + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + 
self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _visualize_all_iterations(self, **kwargs): + raise NotImplementedError() + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + orientation_matrix=None, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool, optional + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + + asnumpy = self._asnumpy + + # get scaled arrays + + if orientation_matrix is not None: + ordered_obj = self._rotate_zxy_volume_vector( + self._object, + orientation_matrix, + ) + + # V(z,x,y), Ax(z,x,y), Ay(z,x,y), Az(z,x,y) + ordered_obj = asnumpy(ordered_obj) + ordered_obj[1:] = np.roll(ordered_obj[1:], -1, axis=0) + + else: + # V(z,x,y), Ax(z,x,y), Ay(z,x,y), Az(z,x,y) + ordered_obj = self.object.copy() + ordered_obj[1:] = np.roll(ordered_obj[1:], -1, axis=0) + + _, nz, nx, ny = ordered_obj.shape + img_array = np.zeros((nx + nx + nz, ny * 4), dtype=ordered_obj.dtype) + + axes = [1, 2, 0] + transposes = [False, True, False] + labels = [("z [A]", "y [A]"), ("x [A]", "z [A]"), ("x [A]", "y [A]")] + limits_v = [(0, nz), (nz, nz + nx), (nz + nx, nz + nx + nx)] + limits_h = [(0, ny), (0, nz), (0, ny)] + + titles = [ + [ + r"$V$ projected along $\hat{x}$", 
+ r"$A_x$ projected along $\hat{x}$", + r"$A_y$ projected along $\hat{x}$", + r"$A_z$ projected along $\hat{x}$", + ], + [ + r"$V$ projected along $\hat{y}$", + r"$A_x$ projected along $\hat{y}$", + r"$A_y$ projected along $\hat{y}$", + r"$A_z$ projected along $\hat{y}$", + ], + [ + r"$V$ projected along $\hat{z}$", + r"$A_x$ projected along $\hat{z}$", + r"$A_y$ projected along $\hat{z}$", + r"$A_z$ projected along $\hat{z}$", + ], + ] + + for index in range(4): + for axis, transpose, limit_v, limit_h in zip( + axes, transposes, limits_v, limits_h + ): + start_v, end_v = limit_v + start_h, end_h = np.array(limit_h) + index * ny + + subarray = ordered_obj[index].sum(axis) + if transpose: + subarray = subarray.T + + img_array[start_v:end_v, start_h:end_h] = subarray + + if plot_convergence: + auto_figsize = (ny * 4 * 4 / nx, (nx + nx + nz) * 3.5 / nx + 1) + else: + auto_figsize = (ny * 4 * 4 / nx, (nx + nx + nz) * 3.5 / nx) + + figsize = kwargs.pop("figsize", auto_figsize) + cmap_e = kwargs.pop("cmap_e", "magma") + cmap_m = kwargs.pop("cmap_m", "PuOr") + vmin_e = kwargs.pop("vmin_e", None) + vmax_e = kwargs.pop("vmax_e", None) + + # remove common unused kwargs + kwargs.pop("plot_probe", None) + kwargs.pop("plot_fourier_probe", None) + kwargs.pop("remove_initial_probe_aberrations", None) + kwargs.pop("vertical_lims", None) + kwargs.pop("horizontal_lims", None) + + _, vmin_e, vmax_e = return_scaled_histogram_ordering( + img_array[:, :ny], vmin_e, vmax_e + ) + + _, _, _vmax_m = return_scaled_histogram_ordering(np.abs(img_array[:, ny:])) + vmin_m = kwargs.pop("vmin_m", -_vmax_m) + vmax_m = kwargs.pop("vmax_m", _vmax_m) + + if plot_convergence: + spec = GridSpec( + ncols=4, + nrows=4, + height_ratios=[nx, nz, nx, nx / 4], + hspace=0.15, + wspace=0.35, + ) + else: + spec = GridSpec( + ncols=4, nrows=3, height_ratios=[nx, nz, nx], hspace=0.15, wspace=0.35 + ) + + if fig is None: + fig = plt.figure(figsize=figsize) + + for sp in spec: + row, col = np.unravel_index(sp.num1, 
(4, 4)) + + if row < 3: + ax = fig.add_subplot(sp) + + start_v, end_v = limits_v[row] + start_h, end_h = np.array(limits_h[row]) + col * ny + subarray = img_array[start_v:end_v, start_h:end_h] + + extent = [ + 0, + self.sampling[1] * subarray.shape[1], + self.sampling[0] * subarray.shape[0], + 0, + ] + + im = ax.imshow( + subarray, + cmap=cmap_e if sp.is_first_col() else cmap_m, + vmin=vmin_e if sp.is_first_col() else vmin_m, + vmax=vmax_e if sp.is_first_col() else vmax_m, + extent=extent, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + ax.set_title(titles[row][col]) + + y_label, x_label = labels[row] + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + + if plot_convergence and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + + ax = fig.add_subplot(spec[-1, :]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + def _rotate_zxy_volume_util( + self, + current_object, + rot_matrix, + ): + """ """ + for index in range(4): + current_object[index] = self._rotate_zxy_volume( + current_object[index], rot_matrix + ) + + return current_object + + def _rotate_zxy_volume_vector(self, current_object, rot_matrix): + """Rotates vector field consistently. 
Note this is very expensive""" + + xp = self._xp + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + if xp is np: + from scipy.interpolate import RegularGridInterpolator + + current_object = self._asnumpy(current_object) + else: + try: + from cupyx.scipy.interpolate import RegularGridInterpolator + except ModuleNotFoundError: + from scipy.interpolate import RegularGridInterpolator + + xp = np # force xp to np for cupy <12.0 + current_object = self._asnumpy(current_object) + + _, nz, nx, ny = current_object.shape + + z, x, y = [xp.linspace(-1, 1, s, endpoint=False) for s in (nx, ny, nz)] + Z, X, Y = xp.meshgrid(z, x, y, indexing="ij") + coords = xp.array([Z.ravel(), X.ravel(), Y.ravel()]) + + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix @ swap_zxy_to_xyz) + rotated_vecs = tf.T.dot(coords).T + + Az = RegularGridInterpolator( + (z, x, y), current_object[1], bounds_error=False, fill_value=0 + ) + Ax = RegularGridInterpolator( + (z, x, y), current_object[2], bounds_error=False, fill_value=0 + ) + Ay = RegularGridInterpolator( + (z, x, y), current_object[3], bounds_error=False, fill_value=0 + ) + + xp = self._xp # switch back to device + obj = xp.zeros_like(current_object) + obj[0] = self._rotate_zxy_volume(xp.asarray(current_object[0]), rot_matrix) + + obj[1] = xp.asarray(Az(rotated_vecs).reshape(nz, nx, ny)) + obj[2] = xp.asarray(Ax(rotated_vecs).reshape(nz, nx, ny)) + obj[3] = xp.asarray(Ay(rotated_vecs).reshape(nz, nx, ny)) + + return obj diff --git a/py4DSTEM/process/phase/magnetic_ptychography.py b/py4DSTEM/process/phase/magnetic_ptychography.py new file mode 100644 index 000000000..2e887739f --- /dev/null +++ b/py4DSTEM/process/phase/magnetic_ptychography.py @@ -0,0 +1,1936 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely magnetic ptychography. 
+""" + +import warnings +from typing import Mapping, Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) + +try: + import cupy as cp +except (ImportError, ModuleNotFoundError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class MagneticPtychography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Iterative Magnetic Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) (for each measurement) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed electrostatic dimensions : (Px,Py) + Reconstructed magnetic dimensions : (Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py) is the padded-object size we position our ROI around in. 
+ + Parameters + ---------- + datacube: Sequence[DataCube] + Tuple of input 4D diffraction pattern intensities + energy: float + The electron energy of the wave functions in eV + magnetic_contribution_sign: str, optional + One of '-+', '-0+', '0+' + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad objects with + If None, the padding is set to half the probe ROI dimensions + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (2,Px,Py) + If None, initialized to 1.0j for complex objects and 0.0 for potential objects + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + positions_offset_ang: np.ndarray, optional + Offset of positions in A + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. 
Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ("_magnetic_contribution_sign",) + + def __init__( + self, + energy: float, + datacube: Sequence[DataCube] = None, + magnetic_contribution_sign: str = "-+", + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + positions_mask: np.ndarray = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + positions_offset_ang: np.ndarray = None, + object_type: str = "complex", + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "magnetic_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + 
self._datacube = datacube + self._object = initial_object_guess + self._probe_init = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._positions_offset_ang = positions_offset_ang + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._magnetic_contribution_sign = magnetic_contribution_sign + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "bilinear", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: Sequence[np.ndarray] = None, + force_com_measured: Sequence[np.ndarray] = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + progress_bar: bool = True, + object_fov_mask: np.ndarray = True, + crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. 
+ Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. 
+ force_com_shifts: sequence of tuples of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + Force CoM measured shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. 
" + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + if self._magnetic_contribution_sign == "-+": + self._recon_mode = 0 + self._num_measurements = 2 + magnetic_contribution_msg = ( + "Magnetic vector potential sign in first meaurement assumed to be negative.\n" + "Magnetic vector potential sign in second meaurement assumed to be positive." + ) + + elif self._magnetic_contribution_sign == "-0+": + self._recon_mode = 1 + self._num_measurements = 3 + magnetic_contribution_msg = ( + "Magnetic vector potential sign in first meaurement assumed to be negative.\n" + "Magnetic vector potential assumed to be zero in second meaurement.\n" + "Magnetic vector potential sign in third meaurement assumed to be positive." + ) + + elif self._magnetic_contribution_sign == "0+": + self._recon_mode = 2 + self._num_measurements = 2 + magnetic_contribution_msg = ( + "Magnetic vector potential assumed to be zero in first meaurement.\n" + "Magnetic vector potential sign in second meaurement assumed to be positive." + ) + else: + raise ValueError( + f"magnetic_contribution_sign must be either '-+', '-0+', or '0+', not {self._magnetic_contribution_sign}" + ) + + if self._verbose: + warnings.warn( + magnetic_contribution_msg, + UserWarning, + ) + + if len(self._datacube) != self._num_measurements: + raise ValueError( + f"datacube must be the same length as magnetic_contribution_sign, not length {len(self._datacube)}." 
+ ) + + dc_shapes = [dc.shape for dc in self._datacube] + if dc_shapes.count(dc_shapes[0]) != self._num_measurements: + raise ValueError("datacube intensities must be the same size.") + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_measurements, 1, 1) + ) + + num_probes_per_measurement = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) + + else: + self._positions_mask = [None] * self._num_measurements + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) + + # prepopulate relevant arrays + self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + roi_shape = diffraction_intensities_shape + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) + + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) + + self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # TO-DO: generalize this + if force_com_shifts is None: + force_com_shifts = [None] * self._num_measurements + + if 
force_com_measured is None: + force_com_measured = [None] * self._num_measurements + + if self._scan_positions is None: + self._scan_positions = [None] * self._num_measurements + + if self._positions_offset_ang is None: + self._positions_offset_ang = [None] * self._num_measurements + + # Ensure plot_center_of_mass is not in kwargs + kwargs.pop("plot_center_of_mass", None) + + if progress_bar: + # turn off verbosity to play nice with tqdm + verbose = self._verbose + self._verbose = False + + # loop over DPs for preprocessing + for index in tqdmnd( + self._num_measurements, + desc="Preprocessing data", + unit="measurement", + disable=not progress_bar, + ): + # preprocess datacube, vacuum and masks only for first measurement + if index == 0: + ( + self._datacube[index], + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts[index], + com_measured=force_com_measured[index], + ) + + else: + ( + self._datacube[index], + _, + _, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=None, + dp_mask=None, + com_shifts=force_com_shifts[index], + com_measured=force_com_measured[index], + ) + + # calibrations + intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube[index], + require_calibrations=True, + 
force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # calculate CoM + ( + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, + com_measured=force_com_measured[index], + ) + + # estimate rotation / transpose using first measurement + if index == 0: + # silence warnings to play nice with progress bar + verbose = self._verbose + self._verbose = False + + ( + self._rotation_best_rad, + self._rotation_best_transpose, + _com_x, + _com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + com_measured_x, + com_measured_y, + com_normalized_x, + com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=False, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + self._verbose = verbose + + # corner-center amplitudes + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + ( + amplitudes, + mean_diffraction_intensity_temp, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + intensities, + com_fitted_x, + com_fitted_y, + self._positions_mask[index], + crop_patterns, + ) + + self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + + del ( + intensities, + amplitudes, + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) + + # initialize probe positions + ( + self._positions_px_all[idx_start:idx_end], + 
self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions[index], + self._positions_mask[index], + self._object_padding_px, + self._positions_offset_ang[index], + ) + + if progress_bar: + # reset verbosity + self._verbose = verbose + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # Object Initialization + obj = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + ) + + if self._object is None: + self._object = xp.full((2,) + obj.shape, obj) + else: + self._object = obj + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + # center probe positions + self._positions_px_all = xp_storage.asarray( + self._positions_px_all, dtype=xp_storage.float32 + ) + + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 + self._positions_px_all[idx_start:idx_end] = positions_px.copy() + + self._positions_px_initial_all = self._positions_px_all.copy() + self._positions_initial_all = self._positions_px_initial_all.copy() + self._positions_initial_all[:, 0] *= self.sampling[0] + self._positions_initial_all[:, 1] *= self.sampling[1] + + self._positions_initial = self._return_average_positions() + if self._positions_initial is not None: + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probes_all = [] + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + list_Q = isinstance(self._probe_init, (list, tuple)) + + for 
index in range(self._num_measurements): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[index], + self._semiangle_cutoff, + crop_patterns, + ) + + self._probes_all.append(_probe) + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + + del self._probe_init + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + if object_fov_mask is None or plot_probe_overlaps: + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift( + self._probes_all[0], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + # initialize object_fov_mask + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + elif object_fov_mask is True: + self._object_fov_mask = np.full(self._object_shape, True) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + # plot probe overlaps + if plot_probe_overlaps: + probe_overlap = 
asnumpy(probe_overlap) + figsize = kwargs.pop("figsize", (9, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax1, chroma_boost=chroma_boost) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + probe_overlap, + extent=extent, + cmap="gray", + ) + ax2.scatter( + self.positions[0, :, 1], + self.positions[0, :, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_xlim((extent[0], extent[1])) + ax2.set_ylim((extent[2], extent[3])) + ax2.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ): + """ + Ptychographic overlap projection method. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + object_patches = xp.empty( + (self._num_measurements,) + shifted_probes.shape, dtype=current_object.dtype + ) + object_patches[0] = current_object[ + 0, vectorized_patch_indices_row, vectorized_patch_indices_col + ] + object_patches[1] = current_object[ + 1, vectorized_patch_indices_row, vectorized_patch_indices_col + ] + + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + + overlap_base = shifted_probes * object_patches[0] + + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + overlap = overlap_base * xp.conj(object_patches[1]) + case (0, 1) | (1, 2) | (2, 1): # forward + overlap = overlap_base * object_patches[1] + case (1, 1) | (2, 0): # neutral + overlap = overlap_base + case _: + raise ValueError() + + return shifted_probes, object_patches, overlap + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_conj = xp.conj(shifted_probes) # P* + electrostatic_conj = xp.conj(object_patches[0]) # V* = exp(-i v) + + probe_electrostatic_abs = xp.abs(shifted_probes * object_patches[0]) + probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts( + probe_electrostatic_abs**2, + positions_px, + ) + probe_electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2 + + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2 + ) + + probe_magnetic_abs = xp.abs(shifted_probes * object_patches[1]) + probe_magnetic_normalization = self._sum_overlapping_patches_bincounts( + probe_magnetic_abs**2, + positions_px, + ) + probe_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_magnetic_normalization) ** 2 + + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2 + ) + + if not fix_probe: + electrostatic_magnetic_abs = xp.abs(object_patches[0] * object_patches[1]) + electrostatic_magnetic_normalization = xp.sum( + electrostatic_magnetic_abs**2, + axis=0, + ) + electrostatic_magnetic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2 + + (normalization_min * xp.max(electrostatic_magnetic_normalization)) + ** 2 + ) + + if 
self._recon_mode > 0: + electrostatic_abs = xp.abs(object_patches[0]) + electrostatic_normalization = xp.sum( + electrostatic_abs**2, + axis=0, + ) + electrostatic_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * electrostatic_normalization) ** 2 + + (normalization_min * xp.max(electrostatic_normalization)) ** 2 + ) + + match (self._recon_mode, self._active_measurement_index): + case (0, 0) | (1, 0): # reverse + if self._object_type == "potential": + # -i exp(-i v) exp(i m) P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * object_patches[1] + * electrostatic_conj + * probe_conj + * exit_waves + ), + positions_px, + ) + + # i exp(-i v) exp(i m) P* + magnetic_update = -electrostatic_update + + else: + # M P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * object_patches[1] * exit_waves, + positions_px, + ) + + # V* P* + magnetic_update = xp.conj( + self._sum_overlapping_patches_bincounts( + probe_conj * electrostatic_conj * exit_waves, + positions_px, + ) + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_magnetic_normalization + ) + current_object[1] += ( + step_size * magnetic_update * probe_electrostatic_normalization + ) + + if not fix_probe: + # M V* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * object_patches[1] * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization + ) + + case (0, 1) | (1, 2) | (2, 1): # forward + magnetic_conj = xp.conj(object_patches[1]) # M* = exp(-i m) + + if self._object_type == "potential": + # -i exp(-i v) exp(-i m) P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * magnetic_conj + * electrostatic_conj + * probe_conj + * exit_waves + ), + positions_px, + ) + + # -i exp(-i v) exp(-i m) P* + magnetic_update = electrostatic_update + + else: + # M* P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * magnetic_conj * 
exit_waves, + positions_px, + ) + + # V* P* + magnetic_update = self._sum_overlapping_patches_bincounts( + probe_conj * electrostatic_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_magnetic_normalization + ) + current_object[1] += ( + step_size * magnetic_update * probe_electrostatic_normalization + ) + + if not fix_probe: + # M* V* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * magnetic_conj * exit_waves, + axis=0, + ) + * electrostatic_magnetic_normalization + ) + + case (1, 1) | (2, 0): # neutral + probe_abs = xp.abs(shifted_probes) + probe_normalization = self._sum_overlapping_patches_bincounts( + probe_abs**2, + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + # -i exp(-i v) P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + xp.real(-1j * electrostatic_conj * probe_conj * exit_waves), + positions_px, + ) + + else: + # P* + electrostatic_update = self._sum_overlapping_patches_bincounts( + probe_conj * exit_waves, + positions_px, + ) + + current_object[0] += ( + step_size * electrostatic_update * probe_normalization + ) + + if not fix_probe: + # V* + current_probe += step_size * ( + xp.sum( + electrostatic_conj * exit_waves, + axis=0, + ) + * electrostatic_normalization + ) + + case _: + raise ValueError() + + return current_object, current_probe + + def _object_constraints( + self, + current_object, + pure_phase_object, + gaussian_filter, + gaussian_filter_sigma_e, + gaussian_filter_sigma_m, + butterworth_filter, + butterworth_order, + q_lowpass_e, + q_lowpass_m, + q_highpass_e, + q_highpass_m, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """MagneticObjectNDConstraints wrapper function""" + + # 
smoothness + if gaussian_filter: + current_object[0] = self._object_gaussian_constraint( + current_object[0], gaussian_filter_sigma_e, pure_phase_object + ) + current_object[1] = self._object_gaussian_constraint( + current_object[1], gaussian_filter_sigma_m, True + ) + if butterworth_filter: + current_object[0] = self._object_butterworth_constraint( + current_object[0], + q_lowpass_e, + q_highpass_e, + butterworth_order, + ) + current_object[1] = self._object_butterworth_constraint( + current_object[1], + q_lowpass_m, + q_highpass_m, + butterworth_order, + ) + if tv_denoise: + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], tv_denoise_weight, tv_denoise_inner_iter + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object[0] = self._object_shrinkage_constraint( + current_object[0], + shrinkage_rad, + object_mask, + ) + + # amplitude threshold (complex) or positivity (potential) + if self._object_type == "complex": + current_object[0] = self._object_threshold_constraint( + current_object[0], pure_phase_object + ) + current_object[1] = self._object_threshold_constraint( + current_object[1], True + ) + elif object_positivity: + current_object[0] = self._object_positivity_constraint(current_object[0]) + + return current_object + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + pure_phase_object: bool = False, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + 
constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma_e: float = None, + gaussian_filter_sigma_m: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass_e: float = None, + q_lowpass_m: float = None, + q_highpass_e: float = None, + q_highpass_m: float = None, + butterworth_order: float = 2, + tv_denoise: bool = True, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + detector_fourier_mask: np.ndarray = None, + store_iterations: bool = False, + collective_measurement_updates: bool = True, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. 
+ + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + pure_phase_object: bool, optional + If True, object amplitude is set to unity + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vacuum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. 
+ constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma_e: float + Standard deviation of gaussian kernel for electrostatic object in A + gaussian_filter_sigma_m: float + Standard deviation of gaussian kernel for magnetic object in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. 
This is more stable, but occasionally leads + to a documented bug where the kernel hangs. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass_e: float + Cut-off frequency in A^-1 for low-pass filtering electrostatic object + q_lowpass_m: float + Cut-off frequency in A^-1 for low-pass filtering magnetic object + q_highpass_e: float + Cut-off frequency in A^-1 for high-pass filtering electrostatic object + q_highpass_m: float + Cut-off frequency in A^-1 for high-pass filtering magnetic object + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + detector_fourier_mask: np.ndarray + Corner-centered mask to multiply the detector-plane gradients with (a value of zero suppresses those pixels). + Useful when detector has artifacts such as dead-pixels. Usually binary. 
+ store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + collective_measurement_updates: bool + if True perform collective updates for all measurements + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probes_all", + "_probes_all_initial", + "_probes_all_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + if not collective_measurement_updates and self._verbose: + warnings.warn( + "Magnetic ptychography is much more robust with `collective_measurement_updates=True`.", + UserWarning, + ) + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if use_projection_scheme: + raise NotImplementedError( + "Magnetic ptychography is currently only implemented for gradient descent." 
+ ) + + # initialization + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) + + if object_type is not None: + self._switch_object_type(object_type) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + step_size, + max_batch_size, + ) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + if detector_fourier_mask is None: + detector_fourier_mask = xp.ones(self._amplitudes[0].shape) + else: + detector_fourier_mask = xp.asarray(detector_fourier_mask) + + if gaussian_filter_sigma_m is None: + gaussian_filter_sigma_m = gaussian_filter_sigma_e + + if q_lowpass_m is None: + q_lowpass_m = q_lowpass_e + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if collective_measurement_updates: + collective_object = xp.zeros_like(self._object) + + # randomize + measurement_indices = np.arange(self._num_measurements) + np.random.shuffle(measurement_indices) + + for measurement_index in measurement_indices: + self._active_measurement_index = measurement_index + + measurement_error = 0.0 + + _probe = self._probes_all[self._active_measurement_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_measurement_index + ] + + start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] + + num_diffraction_patterns = end_idx - start_idx + shuffled_indices = np.arange(start_idx, end_idx) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + 
batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + _probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + detector_fourier_mask, + use_projection_scheme=use_projection_scheme, + projection_a=projection_a, + projection_b=projection_b, + projection_c=projection_c, + ) + + # adjoint operator + object_update, _probe = self._adjoint( + self._object.copy(), + _probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + object_update -= self._object + + # position correction + if not fix_positions and a0 > 0: + self._positions_px_all[batch_indices] = ( + self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + ) + + measurement_error += batch_error + + if collective_measurement_updates: + collective_object += object_update + else: + self._object += object_update + + # Normalize Error + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] + * num_diffraction_patterns + ) + error += measurement_error + + # constraints 
+ + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions + ( + self._object, + _probe, + self._positions_px_all[batch_indices], + ) = self._constraints( + self._object, + _probe, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, 
+ constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=( + self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None + ), + pure_phase_object=pure_phase_object + and self._object_type == "complex", + ) + + # Normalize Error Over Tilts + error /= self._num_measurements + + if 
collective_measurement_updates: + self._object += collective_object / self._num_measurements + + # object only + self._object = self._object_constraints( + self._object, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma_m is not None, + gaussian_filter_sigma_e=gaussian_filter_sigma_e, + gaussian_filter_sigma_m=gaussian_filter_sigma_m, + butterworth_filter=butterworth_filter + and (q_lowpass_m is not None or q_highpass_m is not None), + q_lowpass_e=q_lowpass_e, + q_lowpass_m=q_lowpass_m, + q_highpass_e=q_highpass_e, + q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=( + self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None + ), + pure_phase_object=pure_phase_object + and self._object_type == "complex", + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _visualize_all_iterations(self, **kwargs): + raise NotImplementedError() + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. 
+ + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool, optional + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + + asnumpy = self._asnumpy + + figsize = kwargs.pop("figsize", (12, 5)) + cmap_e = kwargs.pop("cmap_e", "magma") + cmap_m = kwargs.pop("cmap_m", "PuOr") + chroma_boost = kwargs.pop("chroma_boost", 1) + + # get scaled arrays + probe = self._return_single_probe() + obj = self.object_cropped + if self._object_type == "complex": + obj = np.angle(obj) + + vmin_e = kwargs.pop("vmin_e", None) + vmax_e = kwargs.pop("vmax_e", None) + obj[0], vmin_e, vmax_e = return_scaled_histogram_ordering( + obj[0], vmin_e, vmax_e + ) + + _, _, _vmax_m = return_scaled_histogram_ordering(np.abs(obj[1])) + vmin_m = kwargs.pop("vmin_m", -_vmax_m) + vmax_m = kwargs.pop("vmax_m", _vmax_m) + + extent = [ + 0, + self.sampling[1] * obj.shape[2], + self.sampling[0] * obj.shape[1], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=3, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + 
(extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=2, nrows=2, height_ratios=[4, 1], hspace=0.15) + + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=3, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=2, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object_e + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[0], + extent=extent, + cmap=cmap_e, + vmin=vmin_e, + vmax=vmax_e, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Electrostatic potential") + elif self._object_type == "complex": + ax.set_title("Electrostatic phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Object_m + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[1], + extent=extent, + cmap=cmap_m, + vmin=vmin_m, + vmax=vmax_m, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Magnetic potential") + elif self._object_type == "complex": + ax.set_title("Magnetic phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + ax = fig.add_subplot(spec[0, 2]) + if plot_fourier_probe: + probe = asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB( + probe, + chroma_boost=chroma_boost, + ) + + 
ax.set_title("Reconstructed Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(probe)), + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title("Reconstructed probe intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + else: + # Object_e + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj[0], + extent=extent, + cmap=cmap_e, + vmin=vmin_e, + vmax=vmax_e, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Electrostatic potential") + elif self._object_type == "complex": + ax.set_title("Electrostatic phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Object_e + ax = fig.add_subplot(spec[0, 1]) + im = ax.imshow( + obj[1], + extent=extent, + cmap=cmap_m, + vmin=vmin_m, + vmax=vmax_m, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Magnetic potential") + elif self._object_type == "complex": + ax.set_title("Magnetic phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + + ax = fig.add_subplot(spec[1, :]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + 
@property + def object_cropped(self): + """Cropped and rotated object""" + avg_pos = self._return_average_positions() + cropped_e = self._crop_rotate_object_fov(self._object[0], positions_px=avg_pos) + cropped_m = self._crop_rotate_object_fov(self._object[1], positions_px=avg_pos) + + return np.array([cropped_e, cropped_m]) diff --git a/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py new file mode 100644 index 000000000..119fc3a3c --- /dev/null +++ b/py4DSTEM/process/phase/mixedstate_multislice_ptychography.py @@ -0,0 +1,1122 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely multislice ptychography. +""" + +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = None + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ProbeMixedConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + Object2p5DMethodsMixin, + Object2p5DProbeMixedMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ObjectNDProbeMixedMethodsMixin, + ProbeMethodsMixin, + ProbeMixedMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class MixedstateMultislicePtychography( + VisualizationsMixin, + 
PositionsConstraintsMixin, + ProbeMixedConstraintsMixin, + ProbeConstraintsMixin, + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + Object2p5DProbeMixedMethodsMixin, + ObjectNDProbeMixedMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMixedMethodsMixin, + ProbeMethodsMixin, + Object2p5DMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Mixed-State Multislice Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (N,Sx,Sy) + Reconstructed object dimensions : (T,Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes + and (Px,Py) is the padded-object size we position our ROI around in + each of the T slices. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + num_probes: int, optional + Number of mixed-state probes + num_slices: int + Number of slices to use in the forward model + slice_thicknesses: float or Sequence[float] + Slice thicknesses in angstroms. If float, all slices are assigned the same thickness + datacube: DataCube, optional + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. 
+ object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + positions_offset_ang: np.ndarray, optional + Offset of positions in A + theta_x: float + x tilt of propagator in mrad + theta_y: float + y tilt of propagator in mrad + middle_focus: bool + if True, adds half the sample thickness to the defocus + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be performed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. 
+ """ + + # Class-specific Metadata + _class_specific_metadata = ( + "_num_probes", + "_num_slices", + "_slice_thicknesses", + "_theta_x", + "_theta_y", + ) + + def __init__( + self, + energy: float, + num_slices: int, + slice_thicknesses: Union[float, Sequence[float]], + num_probes: int = None, + datacube: DataCube = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + positions_offset_ang: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, + middle_focus: bool = False, + object_type: str = "complex", + positions_mask: np.ndarray = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "multi-slice_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): + if num_probes is None: + raise ValueError( + ( + "If initial_probe_guess is None, or a ComplexProbe object, " + "num_probes must be specified." + ) + ) + else: + if len(initial_probe_guess.shape) != 3: + raise ValueError( + "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." 
+ ) + num_probes = initial_probe_guess.shape[0] + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + slice_thicknesses = np.array(slice_thicknesses) + if slice_thicknesses.shape == (): + slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) + elif slice_thicknesses.shape[0] != (num_slices - 1): + raise ValueError( + ( + f"slice_thicknesses must have length {num_slices - 1}, " + f"not {slice_thicknesses.shape[0]}." + ) + ) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if middle_focus: + half_thickness = slice_thicknesses.mean() * num_slices / 2 + self._polar_parameters["C10"] -= half_thickness + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._positions_offset_ang = positions_offset_ang + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._positions_mask = positions_mask + self._object_padding_px = object_padding_px + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._num_probes = num_probes + self._num_slices = num_slices + self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + 
reshaping_method: str = "bilinear", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: float = None, + force_com_measured: Sequence[np.ndarray] = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (T,Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intensities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin', 'bilinear' (default), or 'fourier' + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. 
+ If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.ndarray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. 
+ force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + Force CoM measured shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + If not None, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: MixedstateMultislicePtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." 
+ ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + # preprocess datacube + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + force_com_measured, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + com_measured=force_com_measured, + ) + + # calibrations + _intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + _intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, + com_measured=force_com_measured, + ) + + # estimate rotation / transpose + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + 
force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + # explicitly transfer arrays to storage + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) + + # corner-center amplitudes + ( + self._amplitudes, + self._mean_diffraction_intensity, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + _intensities, + self._com_fitted_x, + self._com_fitted_y, + self._positions_mask, + crop_patterns, + ) + + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(self._amplitudes, storage) + del _intensities + + self._num_diffraction_patterns = self._amplitudes.shape[0] + self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, + self._positions_offset_ang, + ) + + # initialize object + self._object = self._initialize_object( + self._object, + self._num_slices, + self._positions_px, + self._object_type, + ) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + # center probe positions + self._positions_px = xp_storage.asarray( + self._positions_px, dtype=xp_storage.float32 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = 
self._positions_px.mean(0) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + + # precompute propagator arrays + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + self._theta_x, + self._theta_y, + ) + + if object_fov_mask is None or plot_probe_overlaps: + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe[0], positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + 
elif object_fov_mask is True: + self._object_fov_mask = np.full(self._object_shape, True) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + # plot probe overlaps + if plot_probe_overlaps: + probe_overlap = asnumpy(probe_overlap) + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=power, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probe[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, + chroma_boost=chroma_boost, + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe[0] intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax2, chroma_boost=chroma_boost) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe[0] intensity") + + ax3.imshow( + probe_overlap, + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[:, 1], + 
self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_probe_com: bool = True, + orthogonalize_probe: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + num_probes_fit_aberrations: int = np.inf, + butterworth_filter: bool = True, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + kz_regularization_filter: bool = True, + 
kz_regularization_gamma: Union[float, np.ndarray] = None, + identical_slices: bool = False, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + detector_fourier_mask: np.ndarray = None, + pure_phase_object: bool = False, + tv_denoise_chambolle: bool = True, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=1, + tv_denoise: bool = True, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + num_iter: int, optional + Maximum number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. 
+ max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vacuum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. 
+ fix_positions: bool, optional + If True, probe-positions are fixed + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + num_probes_fit_aberrations: int + The number of probes on which to apply the probe fitting + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. 
Smaller gives a smoother filter + kz_regularization_filter: bool, optional + If True and kz_regularization_gamma is not None, applies kz regularization filter + kz_regularization_gamma: float, optional + kz regularization strength + identical_slices: bool, optional + If True, object forced to identical slices + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + detector_fourier_mask: np.ndarray + Corner-centered mask to multiply the detector-plane gradients with (a value of zero suppresses those pixels). + Useful when detector has artifacts such as dead-pixels. Usually binary. + pure_phase_object: bool, optional + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True and tv_denoise_weight_chambolle is not None, object is smoothed using TV denoising + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: int + If not None, pads object at top and bottom with this many zeros before applying denoising + tv_denoise: bool + If True and tv_denoise_weights is not None, object is smoothed using TV denoising + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. 
+ tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + "_propagator_arrays", + ] + self.copy_attributes_to_device(attrs, device) + + # initialization + self._reset_reconstruction(store_iterations, reset) + + if object_type is not None: + self._switch_object_type(object_type) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + # batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + if detector_fourier_mask is None: + detector_fourier_mask = xp.ones(self._amplitudes[0].shape) + 
else: + detector_fourier_mask = xp.asarray(detector_fourier_mask) + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + detector_fourier_mask, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions: + self._positions_px[batch_indices] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + + error += batch_error + + # Normalize 
Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + self._positions_px_initial, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + num_probes_fit_aberrations=num_probes_fit_aberrations, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + kz_regularization_filter=kz_regularization_filter + and kz_regularization_gamma is not None, + kz_regularization_gamma=( + 
kz_regularization_gamma[a0] + if kz_regularization_gamma is not None + and isinstance(kz_regularization_gamma, np.ndarray) + else kz_regularization_gamma + ), + identical_slices=identical_slices, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=( + self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None + ), + pure_phase_object=pure_phase_object and self._object_type == "complex", + tv_denoise_chambolle=tv_denoise_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + orthogonalize_probe=orthogonalize_probe, + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self diff --git a/py4DSTEM/process/phase/mixedstate_ptychography.py b/py4DSTEM/process/phase/mixedstate_ptychography.py new file mode 100644 index 000000000..bd650a931 --- /dev/null +++ b/py4DSTEM/process/phase/mixedstate_ptychography.py @@ -0,0 +1,1003 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely mixed-state ptychography. 
+""" + +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + ProbeMixedConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ObjectNDProbeMixedMethodsMixin, + ProbeMethodsMixin, + ProbeMixedMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class MixedstatePtychography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeMixedConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + ObjectNDProbeMixedMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMixedMethodsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Mixed-State Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (N,Sx,Sy) + Reconstructed object dimensions : (Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes + and (Px,Py) is the padded-object size we position our ROI around in. 
+ + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + datacube: DataCube + Input 4D diffraction pattern intensities + num_probes: int, optional + Number of mixed-state probes + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + positions_offset_ang: np.ndarray, optional + Offset of positions in A + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. 
+ """ + + # Class-specific Metadata + _class_specific_metadata = ("_num_probes",) + + def __init__( + self, + energy: float, + datacube: DataCube = None, + num_probes: int = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + positions_offset_ang: np.ndarray = None, + object_type: str = "complex", + positions_mask: np.ndarray = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "mixed-state_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): + if num_probes is None: + raise ValueError( + ( + "If initial_probe_guess is None, or a ComplexProbe object, " + "num_probes must be specified." + ) + ) + else: + if len(initial_probe_guess.shape) != 3: + raise ValueError( + "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." 
+ ) + num_probes = initial_probe_guess.shape[0] + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._positions_offset_ang = positions_offset_ang + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._num_probes = num_probes + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "bilinear", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: Union[Sequence[np.ndarray], Sequence[float]] = None, + force_com_measured: Sequence[np.ndarray] = None, + 
vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. 
One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + Force CoM measured shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + # preprocess datacube + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + force_com_measured, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + com_measured=force_com_measured, + ) + + # calibrations + _intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + 
self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + _intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, + com_measured=force_com_measured, + ) + + # estimate rotation / transpose + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + # explicitly transfer arrays to storage + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) + + # corner-center amplitudes + ( + self._amplitudes, + self._mean_diffraction_intensity, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + _intensities, + self._com_fitted_x, + self._com_fitted_y, + self._positions_mask, + crop_patterns, + ) + + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(self._amplitudes, storage) + del _intensities + + self._num_diffraction_patterns = self._amplitudes.shape[0] + self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + 
self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, + self._positions_offset_ang, + ) + + # initialize object + self._object = self._initialize_object( + self._object, + self._positions_px, + self._object_type, + ) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape + + # center probe positions + self._positions_px = xp_storage.asarray( + self._positions_px, dtype=xp_storage.float32 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + + if object_fov_mask is None or plot_probe_overlaps: + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end 
in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe[0], positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + elif object_fov_mask is True: + self._object_fov_mask = np.full(self._object_shape, True) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + # plot probe overlaps + if plot_probe_overlaps: + probe_overlap = asnumpy(probe_overlap) + figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered, + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, axs = plt.subplots(1, self._num_probes + 1, figsize=figsize) + + for i in range(self._num_probes): + axs[i].imshow( + complex_probe_rgb[i], + extent=probe_extent, + ) + axs[i].set_ylabel("x [A]") + axs[i].set_xlabel("y [A]") + axs[i].set_title(f"Initial probe[{i}] intensity") + + divider = make_axes_locatable(axs[i]) + cax = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax, chroma_boost=chroma_boost) 
+ + axs[-1].imshow( + probe_overlap, + extent=extent, + cmap="Greys_r", + ) + axs[-1].scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + axs[-1].set_ylabel("x [A]") + axs[-1].set_xlabel("y [A]") + axs[-1].set_xlim((extent[0], extent[1])) + axs[-1].set_ylim((extent[2], extent[3])) + axs[-1].set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + pure_phase_object: bool = False, + fix_probe_com: bool = True, + orthogonalize_probe: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + global_affine_transformation: bool = False, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + gaussian_filter_sigma: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + num_probes_fit_aberrations: int = 
np.inf, + butterworth_filter: bool = True, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + tv_denoise: bool = True, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + detector_fourier_mask: np.ndarray = None, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. 
+ max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + pure_phase_object: bool, optional + If True, object amplitude is set to unity + fix_probe: bool, optional + If True, probe is fixed + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. 
+ fix_positions: int, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + num_probes_fit_aberrations: int + The number of probes on which to apply the probe fitting + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. 
Smaller gives a smoother filter + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + detector_fourier_mask: np.ndarray + Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels). + Useful when detector has artifacts such as dead-pixels. Usually binary. + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + If not none, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + If true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + + # initialization + self._reset_reconstruction(store_iterations, reset) + + if object_type is not None: + self._switch_object_type(object_type) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + # Batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + if detector_fourier_mask is None: + detector_fourier_mask = xp.ones(self._amplitudes[0].shape) + else: + detector_fourier_mask = xp.asarray(detector_fourier_mask) + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + # randomize + if not 
use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + detector_fourier_mask, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions: + self._positions_px[batch_indices] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + 
self._positions_px_initial, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + num_probes_fit_aberrations=num_probes_fit_aberrations, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + orthogonalize_probe=orthogonalize_probe, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=( + self._object_fov_mask_inverse + if fix_potential_baseline + and 
self._object_fov_mask_inverse.sum() > 0 + else None + ), + pure_phase_object=pure_phase_object and self._object_type == "complex", + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self diff --git a/py4DSTEM/process/phase/multislice_ptychography.py b/py4DSTEM/process/phase/multislice_ptychography.py new file mode 100644 index 000000000..db17cb1a8 --- /dev/null +++ b/py4DSTEM/process/phase/multislice_ptychography.py @@ -0,0 +1,1091 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely multislice ptychography. 
+""" + +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + Object2p5DMethodsMixin, + Object2p5DProbeMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class MultislicePtychography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + Object2p5DProbeMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + Object2p5DMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Multislice Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed object dimensions : (T,Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py) is the padded-object size we position our ROI around in + each of the T slices. 
+ + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + num_slices: int + Number of slices to use in the forward model + slice_thicknesses: float or Sequence[float] + Slice thicknesses in angstroms. If float, all slices are assigned the same thickness + datacube: DataCube, optional + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + positions_offset_ang: np.ndarray, optional + Offset of positions in A + theta_x: float + x tilt of propagator in mrad + theta_y: float + y tilt of propagator in mrad + middle_focus: bool + if True, adds half the sample thickness to the defocus + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Device calculation will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. 
+ """ + + # Class-specific Metadata + _class_specific_metadata = ( + "_num_slices", + "_slice_thicknesses", + "_theta_x", + "_theta_y", + ) + + def __init__( + self, + energy: float, + num_slices: int, + slice_thicknesses: Union[float, Sequence[float]], + datacube: DataCube = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + positions_offset_ang: np.ndarray = None, + theta_x: float = None, + theta_y: float = None, + middle_focus: bool = False, + object_type: str = "complex", + positions_mask: np.ndarray = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "multi-slice_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + slice_thicknesses = np.array(slice_thicknesses) + if slice_thicknesses.shape == (): + slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) + elif slice_thicknesses.shape[0] != (num_slices - 1): + raise ValueError( + ( + f"slice_thicknesses must have length {num_slices - 1}, " + f"not {slice_thicknesses.shape[0]}." 
+ ) + ) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if middle_focus: + half_thickness = slice_thicknesses.mean() * num_slices / 2 + self._polar_parameters["C10"] -= half_thickness + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._positions_offset_ang = positions_offset_ang + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._positions_mask = positions_mask + self._object_padding_px = object_padding_px + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._num_slices = num_slices + self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "bilinear", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: Union[Sequence[np.ndarray], Sequence[float]] = None, + force_com_measured: 
Sequence[np.ndarray] = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (T,Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. 
One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + Force CoM measured shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + If not None, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + # preprocess datacube + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + force_com_measured, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + com_measured=force_com_measured, + ) + + # calibrations + _intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + 
self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + _intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, + com_measured=force_com_measured, + ) + + # estimate rotation / transpose + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + # explicitly transfer arrays to storage + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) + + # corner-center amplitudes + ( + self._amplitudes, + self._mean_diffraction_intensity, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + _intensities, + self._com_fitted_x, + self._com_fitted_y, + self._positions_mask, + crop_patterns, + ) + + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(self._amplitudes, storage) + del _intensities + + self._num_diffraction_patterns = self._amplitudes.shape[0] + self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + 
self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, + self._positions_offset_ang, + ) + + # initialize object + self._object = self._initialize_object( + self._object, + self._num_slices, + self._positions_px, + self._object_type, + ) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + # center probe positions + self._positions_px = xp_storage.asarray( + self._positions_px, dtype=xp_storage.float32 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + + # precompute propagator arrays + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + 
self._theta_x, + self._theta_y, + ) + + if object_fov_mask is None or plot_probe_overlaps: + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe, positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + elif object_fov_mask is True: + self._object_fov_mask = np.full(self._object_shape, True) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + # plot probe overlaps + if plot_probe_overlaps: + probe_overlap = asnumpy(probe_overlap) + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered, + power=power, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probe.copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + 
probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax1, chroma_boost=chroma_boost) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax2, + chroma_boost=chroma_boost, + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe intensity") + + ax3.imshow( + probe_overlap, + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + 
constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + kz_regularization_filter: bool = True, + kz_regularization_gamma: float = None, + identical_slices: bool = False, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + detector_fourier_mask: np.ndarray = None, + pure_phase_object: bool = False, + tv_denoise_chambolle: bool = True, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=1, + tv_denoise: bool = True, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. 
+ + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vacuum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. 
+ constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. 
+ If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter: bool, optional + If True and kz_regularization_gamma is not None, applies kz regularization filter + kz_regularization_gamma, float, optional + kz regularization strength + identical_slices: int, optional + If True, object forced to identical slices + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + detector_fourier_mask: np.ndarray + Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels). + Useful when detector has artifacts such as dead-pixels. Usually binary. + pure_phase_object: bool, optional + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True and tv_denoise_weight_chambolle is not None, object is smoothed using TV denoisining + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: int + If not None, pads object at top and bottom with this many zeros before applying denoising + tv_denoise: bool + If True and tv_denoise_weights is not None, object is smoothed using TV denoising + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. 
+ tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + If not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + If true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + "_propagator_arrays", + ] + self.copy_attributes_to_device(attrs, device) + + # initialization + self._reset_reconstruction(store_iterations, reset) + + if object_type is not None: + self._switch_object_type(object_type) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + # Batching + shuffled_indices = 
np.arange(self._num_diffraction_patterns) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + if detector_fourier_mask is None: + detector_fourier_mask = xp.ones(self._amplitudes[0].shape) + else: + detector_fourier_mask = xp.asarray(detector_fourier_mask) + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + detector_fourier_mask, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions: + self._positions_px[batch_indices] = self._position_correction( + self._object, + 
vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + self._positions_px_initial, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not 
None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + kz_regularization_filter=kz_regularization_filter + and kz_regularization_gamma is not None, + kz_regularization_gamma=kz_regularization_gamma, + identical_slices=identical_slices, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=( + self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None + ), + pure_phase_object=pure_phase_object and self._object_type == "complex", + tv_denoise_chambolle=tv_denoise_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/parallax.py similarity index 60% rename from py4DSTEM/process/phase/iterative_parallax.py rename to py4DSTEM/process/phase/parallax.py index 4877512d7..060a151aa 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -10,15 +10,24 @@ import numpy as np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec +from matplotlib.ticker import PercentFormatter from mpl_toolkits.axes_grid1.axes_divider import 
make_axes_locatable from py4DSTEM import Calibration, DataCube from py4DSTEM.preprocess.utils import get_shifted_ar -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction -from py4DSTEM.process.phase.utils import AffineTransform +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction +from py4DSTEM.process.phase.utils import ( + AffineTransform, + bilinear_kernel_density_estimate, + bilinearly_interpolate_array, + lanczos_interpolate_array, + lanczos_kernel_density_estimate, + pixel_rolling_kernel_density_estimate, +) from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom -from py4DSTEM.visualize import show +from py4DSTEM.visualize import return_scaled_histogram_ordering, show from scipy.linalg import polar +from scipy.ndimage import distance_transform_edt from scipy.optimize import minimize from scipy.special import comb @@ -27,8 +36,6 @@ except (ModuleNotFoundError, ImportError): cp = np -warnings.simplefilter(action="always", category=UserWarning) - _aberration_names = { (1, 0): "C1 ", (1, 2): "stig ", @@ -47,7 +54,7 @@ } -class ParallaxReconstruction(PhaseReconstruction): +class Parallax(PhaseReconstruction): """ Iterative parallax reconstruction class. 
@@ -70,27 +77,23 @@ def __init__( self, energy: float, datacube: DataCube = None, - verbose: bool = False, + verbose: bool = True, object_padding_px: Tuple[int, int] = (32, 32), device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, name: str = "parallax_reconstruction", ): Custom.__init__(self, name=name) - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter + if storage is None: + storage = device - self._gaussian_filter = gaussian_filter - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter + if storage != device: + raise NotImplementedError() - self._gaussian_filter = gaussian_filter - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_device(device, clear_fft_cache) + self.set_storage(storage) self.set_save_defaults() @@ -100,7 +103,6 @@ def __init__( # Metadata self._energy = energy self._verbose = verbose - self._device = device self._object_padding_px = object_padding_px self._preprocessed = False @@ -151,13 +153,10 @@ def to_h5(self, group): data=self._asnumpy(self._recon_BF_subpixel_aligned), ) - if hasattr(self, "aberration_dict"): + if hasattr(self, "aberration_dict_cartesian"): self.metadata = Metadata( - name="aberrations_metadata", - data={ - v["aberration name"]: v["value [Ang]"] - for k, v in self.aberration_dict.items() - }, + name="aberrations_polar_metadata", + data=self.aberration_dict_polar, ) self.metadata = Metadata( @@ -210,6 +209,8 @@ def _get_constructor_args(cls, group): "name": instance_md["name"], "verbose": True, # for compatibility "device": "cpu", # for compatibility + "storage": "cpu", # for compatibility + "clear_fft_cache": True, # for compatibility } return kwargs @@ -251,14 +252,20 @@ def _populate_instance(self, group): def preprocess( self, - edge_blend: int = 16, + edge_blend: float = 16.0, + dp_mask: np.ndarray = None, threshold_intensity: 
float = 0.8, normalize_images: bool = True, normalize_order=0, - descan_correct: bool = True, + descan_correction_fit_function: str = None, defocus_guess: float = None, rotation_guess: float = None, plot_average_bf: bool = True, + realspace_mask: np.ndarray = None, + apply_realspace_mask_to_stack: bool = True, + vectorized_com_calculation: bool = True, + device: str = None, + clear_fft_cache: bool = None, **kwargs, ): """ @@ -266,8 +273,10 @@ def preprocess( Parameters ---------- - edge_blend: int, optional - Pixels to blend image at the border + edge_blend: float, optional + Number of pixels to blend image at the border + dp_mask: np.ndarray, bool + Bright-field pixels mask used for cross-correlation, boolean array same shape as DPs threshold: float, optional Fraction of max of dp_mean for bright-field pixels normalize_images: bool, optional @@ -278,13 +287,26 @@ def preprocess( defocus_guess: float, optional Initial guess of defocus value (defocus dF) in A If None, first iteration is assumed to be in-focus - descan_correct: float, optional - If True, aligns bright field stack based on measured descan + descan_correction_fit_function: str, optional + If not None, descan correction will be performed using fit function. + One of "constant", "plane", "parabola", or "bezier_two". rotation_guess: float, optional Initial guess of defocus value in degrees If None, first iteration assumed to be 0 plot_average_bf: bool, optional If True, plots the average bright field image, using defocus_guess + realspace_mask: np.array, optional + If this array is provided, pixels in real space set to false will be + set to zero in the virtual bright field images. + apply_realspace_mask_to_stack: bool, optional + If this value is set to true, output BF images will be masked by + the edge filter and realspace_mask if it is passed in. 
+ vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -292,7 +314,11 @@ def preprocess( Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp + device = self._device asnumpy = self._asnumpy if self._datacube is None: @@ -303,35 +329,19 @@ def preprocess( ) ) - # get mean diffraction pattern - try: - self._dp_mean = xp.asarray( - self._datacube.tree("dp_mean").data, dtype=xp.float32 - ) - except AssertionError: - self._dp_mean = xp.asarray( - self._datacube.get_dp_mean().data, dtype=xp.float32 - ) - # extract calibrations self._intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, ) + self._intensities = xp.asarray(self._intensities) + self._region_of_interest_shape = np.array(self._intensities.shape[-2:]) self._scan_shape = np.array(self._intensities.shape[:2]) - # make sure mean diffraction pattern is shaped correctly - if (self._dp_mean.shape[0] != self._intensities.shape[2]) or ( - self._dp_mean.shape[1] != self._intensities.shape[3] - ): - raise ValueError( - "dp_mean must match the datacube shape. Try setting dp_mean = None." 
- ) - - # descan correct - if descan_correct: + # descan correction + if descan_correction_fit_function is not None: ( _, _, @@ -342,9 +352,10 @@ def preprocess( ) = self._calculate_intensities_center_of_mass( self._intensities, dp_mask=None, - fit_function="plane", + fit_function=descan_correction_fit_function, com_shifts=None, com_measured=None, + vectorized_calculation=vectorized_com_calculation, ) com_fitted_x = asnumpy(com_fitted_x) @@ -352,7 +363,8 @@ def preprocess( intensities = asnumpy(self._intensities) intensities_shifted = np.zeros_like(intensities) - center_x, center_y = self._region_of_interest_shape / 2 + center_x = com_fitted_x.mean() + center_y = com_fitted_y.mean() for rx in range(intensities_shifted.shape[0]): for ry in range(intensities_shifted.shape[1]): @@ -367,10 +379,14 @@ def preprocess( intensities_shifted[rx, ry] = intensity_shifted self._intensities = xp.asarray(intensities_shifted, xp.float32) - self._dp_mean = self._intensities.mean((0, 1)) + + if dp_mask is not None: + self._dp_mask = xp.asarray(dp_mask) + else: + dp_mean = self._intensities.mean((0, 1)) + self._dp_mask = dp_mean >= (xp.max(dp_mean) * threshold_intensity) # select virtual detector pixels - self._dp_mask = self._dp_mean >= (xp.max(self._dp_mean) * threshold_intensity) self._num_bf_images = int(xp.count_nonzero(self._dp_mask)) self._wavelength = electron_wavelength_angstrom(self._energy) @@ -384,7 +400,13 @@ def preprocess( self._probe_angles = self._kxy * self._wavelength self._kr = xp.sqrt(xp.sum(self._kxy**2, axis=1)) - # Window function + # real space mask blending function + if realspace_mask is not None: + im_edge_dist = xp.array(distance_transform_edt(realspace_mask)) + self._window_mask = xp.minimum(im_edge_dist / edge_blend, 1.0) + self._window_mask = xp.sin(self._window_mask * (np.pi / 2)) ** 2 + + # edge window function x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1, dtype=xp.float32)[1:] x -= (x[1] - x[0]) / 2 wx = ( @@ -408,7 +430,12 @@ def preprocess( 
** 2 ) self._window_edge = wx[:, None] * wy[None, :] - self._window_inv = 1 - self._window_edge + + # if needed, combine edge mask with the input real space mask + if realspace_mask is not None: + self._window_edge *= self._window_mask + + # derived window functions self._window_pad = xp.zeros( ( self._grid_scan_shape[0] + self._object_padding_px[0], @@ -422,6 +449,8 @@ def preprocess( self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + self._object_padding_px[1] // 2, ] = self._window_edge + self._window_inv = 1 - self._window_edge + self._window_inv_pad = 1 - self._window_pad # Collect BF images all_bfs = xp.moveaxis( @@ -430,19 +459,27 @@ def preprocess( (1, 2, 0), ) - # initalize + # initialize stack_shape = ( self._num_bf_images, self._grid_scan_shape[0] + self._object_padding_px[0], self._grid_scan_shape[1] + self._object_padding_px[1], ) if normalize_images: - self._stack_BF = xp.ones(stack_shape, dtype=xp.float32) - self._stack_BF_no_window = xp.ones(stack_shape, xp.float32) + self._normalized_stack = True + self._stack_BF_shifted = xp.ones(stack_shape, dtype=xp.float32) + self._stack_BF_unshifted = xp.ones(stack_shape, xp.float32) if normalize_order == 0: - all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] - self._stack_BF[ + # all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] + weights = xp.average( + all_bfs.reshape((self._num_bf_images, -1)), + weights=self._window_edge.ravel(), + axis=1, + ) + all_bfs /= weights[:, None, None] + + self._stack_BF_shifted[ :, self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + self._object_padding_px[0] // 2, @@ -452,13 +489,16 @@ def preprocess( self._window_inv[None] + self._window_edge[None] * all_bfs ) - self._stack_BF_no_window[ - :, - self._object_padding_px[0] // 2 : self._grid_scan_shape[0] - + self._object_padding_px[0] // 2, - self._object_padding_px[1] // 2 : self._grid_scan_shape[1] - + self._object_padding_px[1] // 2, - ] = all_bfs + if apply_realspace_mask_to_stack: + 
self._stack_BF_unshifted = self._stack_BF_shifted.copy() + else: + self._stack_BF_unshifted[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs elif normalize_order == 1: x = xp.linspace(-0.5, 0.5, all_bfs.shape[1], xp.float32) @@ -466,15 +506,23 @@ def preprocess( ya, xa = xp.meshgrid(y, x) basis = np.vstack( ( - xp.ones_like(xa), + xp.ones_like(xa.ravel()), xa.ravel(), ya.ravel(), ) ).T + weights = np.sqrt(self._window_edge).ravel() + for a0 in range(all_bfs.shape[0]): - coefs = np.linalg.lstsq(basis, all_bfs[a0].ravel(), rcond=None) + # coefs = np.linalg.lstsq(basis, all_bfs[a0].ravel(), rcond=None) + # weighted least squares + coefs = np.linalg.lstsq( + weights[:, None] * basis, + weights * all_bfs[a0].ravel(), + rcond=None, + ) - self._stack_BF[ + self._stack_BF_shifted[ a0, self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + self._object_padding_px[0] // 2, @@ -486,19 +534,81 @@ def preprocess( basis @ coefs[0], all_bfs.shape[1:3] ) - self._stack_BF_no_window[ + if apply_realspace_mask_to_stack: + self._stack_BF_unshifted = self._stack_BF_shifted.copy() + else: + self._stack_BF_unshifted[ + a0, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs[a0] / xp.reshape( + basis @ coefs[0], all_bfs.shape[1:3] + ) + + elif normalize_order == 2: + x = xp.linspace(-0.5, 0.5, all_bfs.shape[1], xp.float32) + y = xp.linspace(-0.5, 0.5, all_bfs.shape[2], xp.float32) + ya, xa = xp.meshgrid(y, x) + basis = np.vstack( + ( + 1 * xa.ravel() ** 2 * ya.ravel() ** 2, + 2 * xa.ravel() ** 2 * ya.ravel() * (1 - ya.ravel()), + 1 * xa.ravel() ** 2 * (1 - ya.ravel()) ** 2, + 2 * xa.ravel() * (1 - xa.ravel()) * ya.ravel() ** 2, + 4 + * xa.ravel() + * (1 - 
xa.ravel()) + * ya.ravel() + * (1 - ya.ravel()), + 2 * xa.ravel() * (1 - xa.ravel()) * (1 - ya.ravel()) ** 2, + 1 * (1 - xa.ravel()) ** 2 * ya.ravel() ** 2, + 2 * (1 - xa.ravel()) ** 2 * ya.ravel() * (1 - ya.ravel()), + 1 * (1 - xa.ravel()) ** 2 * (1 - ya.ravel()) ** 2, + ) + ).T + weights = np.sqrt(self._window_edge).ravel() + + for a0 in range(all_bfs.shape[0]): + # coefs = np.linalg.lstsq(basis, all_bfs[a0].ravel(), rcond=None) + # weighted least squares + coefs = np.linalg.lstsq( + weights[:, None] * basis, + weights * all_bfs[a0].ravel(), + rcond=None, + ) + + self._stack_BF_shifted[ a0, self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + self._object_padding_px[0] // 2, self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + self._object_padding_px[1] // 2, - ] = all_bfs[a0] / xp.reshape(basis @ coefs[0], all_bfs.shape[1:3]) + ] = self._window_inv[None] + self._window_edge[None] * all_bfs[ + a0 + ] / xp.reshape( + basis @ coefs[0], all_bfs.shape[1:3] + ) + if apply_realspace_mask_to_stack: + self._stack_BF_unshifted = self._stack_BF_shifted.copy() + else: + self._stack_BF_unshifted[ + a0, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs[a0] / xp.reshape( + basis @ coefs[0], all_bfs.shape[1:3] + ) else: + self._normalized_stack = False all_means = xp.mean(all_bfs, axis=(1, 2)) - self._stack_BF = xp.full(stack_shape, all_means[:, None, None]) - self._stack_BF_no_window = xp.full(stack_shape, all_means[:, None, None]) - self._stack_BF[ + self._stack_BF_shifted = xp.full(stack_shape, all_means[:, None, None]) + self._stack_BF_unshifted = xp.full(stack_shape, all_means[:, None, None]) + self._stack_BF_shifted[ :, self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + self._object_padding_px[0] // 2, @@ -508,20 +618,22 @@ def preprocess( self._window_inv[None] * all_means[:, None, None] 
+ self._window_edge[None] * all_bfs ) - - self._stack_BF_no_window[ - :, - self._object_padding_px[0] // 2 : self._grid_scan_shape[0] - + self._object_padding_px[0] // 2, - self._object_padding_px[1] // 2 : self._grid_scan_shape[1] - + self._object_padding_px[1] // 2, - ] = all_bfs + if apply_realspace_mask_to_stack: + self._stack_BF_unshifted = self._stack_BF_shifted.copy() + else: + self._stack_BF_unshifted[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs # Fourier space operators for image shifts - qx = xp.fft.fftfreq(self._stack_BF.shape[1], d=1) + qx = xp.fft.fftfreq(self._stack_BF_shifted.shape[1], d=1) qx = xp.asarray(qx, dtype=xp.float32) - qy = xp.fft.fftfreq(self._stack_BF.shape[2], d=1) + qy = xp.fft.fftfreq(self._stack_BF_shifted.shape[2], d=1) qy = xp.asarray(qy, dtype=xp.float32) qxa, qya = xp.meshgrid(qx, qy, indexing="ij") @@ -531,7 +643,7 @@ def preprocess( # Initialization utilities self._stack_mask = xp.tile(self._window_pad[None], (self._num_bf_images, 1, 1)) if defocus_guess is not None: - Gs = xp.fft.fft2(self._stack_BF) + Gs = xp.fft.fft2(self._stack_BF_shifted) self._xy_shifts = ( -self._probe_angles * defocus_guess / xp.array(self._scan_sampling) @@ -551,7 +663,7 @@ def preprocess( self._qx_shift[None] * dx[:, None, None] + self._qy_shift[None] * dy[:, None, None] ) - self._stack_BF = xp.real(xp.fft.ifft2(Gs * shift_op)) + self._stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) self._stack_mask = xp.real( xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) ) @@ -560,7 +672,7 @@ def preprocess( else: self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) - self._stack_mean = xp.mean(self._stack_BF) + self._stack_mean = xp.mean(self._stack_BF_shifted) self._mask_sum = xp.sum(self._window_edge) * self._num_bf_images self._recon_mask = 
xp.sum(self._stack_mask, axis=0) @@ -568,18 +680,21 @@ def preprocess( self._recon_BF = ( self._stack_mean * mask_inv - + xp.sum(self._stack_BF * self._stack_mask, axis=0) + + xp.sum(self._stack_BF_shifted * self._stack_mask, axis=0) ) / (self._recon_mask + mask_inv) self._recon_error = ( xp.atleast_1d( - xp.sum(xp.abs(self._stack_BF - self._recon_BF[None]) * self._stack_mask) + xp.sum( + xp.abs(self._stack_BF_shifted - self._recon_BF[None]) + * self._stack_mask + ) ) / self._mask_sum ) self._recon_BF_initial = self._recon_BF.copy() - self._stack_BF_initial = self._stack_BF.copy() + self._stack_BF_shifted_initial = self._stack_BF_shifted.copy() self._stack_mask_initial = self._stack_mask.copy() self._recon_mask_initial = self._recon_mask.copy() self._xy_shifts_initial = self._xy_shifts.copy() @@ -588,7 +703,6 @@ def preprocess( if plot_average_bf: figsize = kwargs.pop("figsize", (8, 4)) - fig, ax = plt.subplots(1, 2, figsize=figsize) self._visualize_figax(fig, ax[0], **kwargs) @@ -610,177 +724,18 @@ def preprocess( ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") plt.tight_layout() - self._preprocessed = True - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) return self - def tune_angle_and_defocus( - self, - angle_guess=None, - defocus_guess=None, - angle_step_size=5, - defocus_step_size=100, - num_angle_values=5, - num_defocus_values=5, - return_values=False, - plot_reconstructions=True, - plot_convergence=True, - **kwargs, - ): - """ - Run parallax reconstruction over a parameters space of pre-determined angles - and defocus - - Parameters - ---------- - angle_guess: float (degrees), optional - initial starting guess for rotation angle between real and reciprocal space - if None, uses 0 - defocus_guess: float (A), optional - initial starting guess for defocus (defocus dF) - if None, uses 0 - angle_step_size: 
float (degrees), optional - size of change of rotation angle between real and reciprocal space for - each step in parameter space - defocus_step_size: float (A), optional - size of change of defocus for each step in parameter space - num_angle_values: int, optional - number of values of angle to test, must be >= 1. - num_defocus_values: int,optional - number of values of defocus to test, must be >= 1 - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, makes 2D plot of error metrix - return_values: bool, optional - if True, returns objects, convergence - - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - asnumpy = self._asnumpy - - if angle_guess is None: - angle_guess = 0 - if defocus_guess is None: - defocus_guess = 0 - - if num_angle_values == 1: - angle_step_size = 0 - - if num_defocus_values == 1: - defocus_step_size = 0 - - angles = np.linspace( - angle_guess - angle_step_size * (num_angle_values - 1) / 2, - angle_guess + angle_step_size * (num_angle_values - 1) / 2, - num_angle_values, - ) - - defocus_values = np.linspace( - defocus_guess - defocus_step_size * (num_defocus_values - 1) / 2, - defocus_guess + defocus_step_size * (num_defocus_values - 1) / 2, - num_defocus_values, - ) - if return_values or plot_convergence: - recon_BF = [] - convergence = [] - - if plot_reconstructions: - spec = GridSpec( - ncols=num_defocus_values, - nrows=num_angle_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_defocus_values, 4 * num_angle_values) - ) - - fig = plt.figure(figsize=figsize) - - # run loop and plot along the way - self._verbose = False - for flat_index, (angle, defocus) in enumerate( - tqdmnd(angles, defocus_values, desc="Tuning angle and defocus") - ): - self.preprocess( - defocus_guess=defocus, - rotation_guess=angle, - plot_average_bf=False, - 
**kwargs, - ) - - if plot_reconstructions: - row_index, col_index = np.unravel_index( - flat_index, (num_angle_values, num_defocus_values) - ) - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_figax( - fig, - ax=object_ax, - ) - - object_ax.set_title( - f" angle = {angle:.1f} °, defocus = {defocus:.1f} A \n error = {self._recon_error[0]:.3e}" - ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - recon_BF.append(self.recon_BF) - if return_values or plot_convergence: - convergence.append(asnumpy(self._recon_error[0])) - - if plot_convergence: - fig, ax = plt.subplots() - ax.set_title("convergence") - im = ax.imshow( - np.array(convergence).reshape(angles.shape[0], defocus_values.shape[0]), - cmap="magma", - ) - - if angles.shape[0] > 1: - ax.set_ylabel("angles") - ax.set_yticks(np.arange(angles.shape[0])) - ax.set_yticklabels([f"{angle:.1f} °" for angle in angles]) - else: - ax.set_yticks([]) - ax.set_ylabel(f"angle {angles[0]:.1f}") - - if defocus_values.shape[0] > 1: - ax.set_xlabel("defocus values") - ax.set_xticks(np.arange(defocus_values.shape[0])) - ax.set_xticklabels([f"{df:.1f}" for df in defocus_values]) - else: - ax.set_xticks([]) - ax.set_xlabel(f"defocus value: {defocus_values[0]:.1f}") - - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.05) - fig.colorbar(im, cax=cax) - - fig.tight_layout() - - if return_values: - convergence = np.array(convergence).reshape( - angles.shape[0], defocus_values.shape[0] - ) - return recon_BF, convergence - def reconstruct( self, max_alignment_bin: int = None, min_alignment_bin: int = 1, - max_iter_at_min_bin: int = 2, + num_iter_at_min_bin: int = 2, + alignment_bin_values: list = None, cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), regularize_shifts: bool = True, @@ -789,6 +744,8 @@ def reconstruct( plot_aligned_bf: bool = True, plot_convergence: bool = True, reset: bool = None, + 
device: str = None, + clear_fft_cache: bool = None, **kwargs, ): """ @@ -801,8 +758,10 @@ def reconstruct( If None, the bright field disk radius is used min_alignment_bin: int, optional Minimum bin size for bright field alignment - max_iter_at_min_bin: int, optional + num_iter_at_min_bin: int, optional Number of iterations to run at the smallest bin size + alignment_bin_values: list, optional + If not None, explicitly sets the iteration bin values cross_correlation_upsample_factor: int, optional DFT upsample factor for subpixel alignment regularizer_matrix_size: Tuple[int,int], optional @@ -819,6 +778,10 @@ def reconstruct( If True, the convergence error is also plotted reset: bool, optional If True, the reconstruction is reset + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls Returns -------- @@ -826,17 +789,22 @@ def reconstruct( Self to accommodate chaining """ + # handle device/storage + self.set_device(device, clear_fft_cache) + xp = self._xp asnumpy = self._asnumpy if reset: + self.error_iterations = [] self._recon_BF = self._recon_BF_initial.copy() - self._stack_BF = self._stack_BF_initial.copy() + self._stack_BF_shifted = self._stack_BF_shifted_initial.copy() self._stack_mask = self._stack_mask_initial.copy() self._recon_mask = self._recon_mask_initial.copy() self._xy_shifts = self._xy_shifts_initial.copy() + elif reset is None: - if hasattr(self, "_basis"): + if hasattr(self, "error_iterations"): warnings.warn( ( "Continuing reconstruction from previous result. 
" @@ -844,6 +812,8 @@ def reconstruct( ), UserWarning, ) + else: + self.error_iterations = [] if not regularize_shifts: self._basis = self._kxy @@ -890,14 +860,17 @@ def reconstruct( else: max_alignment_bin = diameter_pixels - bin_min = np.ceil(np.log(min_alignment_bin) / np.log(2)) - bin_max = np.ceil(np.log(max_alignment_bin) / np.log(2)) - bin_vals = 2 ** np.arange(bin_min, bin_max)[::-1] + if alignment_bin_values is not None: + bin_vals = np.array(alignment_bin_values).clip(1, max_alignment_bin) + else: + bin_min = np.ceil(np.log(min_alignment_bin) / np.log(2)) + bin_max = np.ceil(np.log(max_alignment_bin) / np.log(2)) + bin_vals = 2 ** np.arange(bin_min, bin_max)[::-1] - if max_iter_at_min_bin > 1: - bin_vals = np.hstack( - (bin_vals, np.repeat(bin_vals[-1], max_iter_at_min_bin - 1)) - ) + if num_iter_at_min_bin > 1: + bin_vals = np.hstack( + (bin_vals, np.repeat(bin_vals[-1], num_iter_at_min_bin - 1)) + ) if plot_aligned_bf: num_plots = bin_vals.shape[0] @@ -905,7 +878,6 @@ def reconstruct( ncols = int(np.ceil(num_plots / nrows)) if plot_convergence: - errors = [] spec = GridSpec( ncols=ncols, nrows=nrows + 1, @@ -914,7 +886,7 @@ def reconstruct( height_ratios=[1] * nrows + [1 / 4], ) - figsize = kwargs.get("figsize", (4 * ncols, 4 * nrows + 1)) + figsize = kwargs.pop("figsize", (4 * ncols, 4 * nrows + 1)) else: spec = GridSpec( ncols=ncols, @@ -923,9 +895,8 @@ def reconstruct( wspace=0.15, ) - figsize = kwargs.get("figsize", (4 * ncols, 4 * nrows)) + figsize = kwargs.pop("figsize", (4 * ncols, 4 * nrows)) - kwargs.pop("figsize", None) fig = plt.figure(figsize=figsize) xy_center = (self._xy_inds - xp.median(self._xy_inds, axis=0)).astype("float") @@ -958,7 +929,7 @@ def reconstruct( xy_inds[:, 1] == xy_vals[ind_align, 1], ) - G = xp.fft.fft2(xp.mean(self._stack_BF[sub], axis=0)) + G = xp.fft.fft2(xp.mean(self._stack_BF_shifted[sub], axis=0)) # Get best fit alignment xy_shift = align_images_fourier( @@ -970,17 +941,17 @@ def reconstruct( dx = ( xp.mod( - 
xy_shift[0] + self._stack_BF.shape[1] / 2, - self._stack_BF.shape[1], + xy_shift[0] + self._stack_BF_shifted.shape[1] / 2, + self._stack_BF_shifted.shape[1], ) - - self._stack_BF.shape[1] / 2 + - self._stack_BF_shifted.shape[1] / 2 ) dy = ( xp.mod( - xy_shift[1] + self._stack_BF.shape[2] / 2, - self._stack_BF.shape[2], + xy_shift[1] + self._stack_BF_shifted.shape[2] / 2, + self._stack_BF_shifted.shape[2], ) - - self._stack_BF.shape[2] / 2 + - self._stack_BF_shifted.shape[2] / 2 ) # output shifts @@ -1000,7 +971,7 @@ def reconstruct( shifts_update = xy_shifts_fit - self._xy_shifts # apply shifts - Gs = xp.fft.fft2(self._stack_BF) + Gs = xp.fft.fft2(self._stack_BF_shifted) dx = shifts_update[:, 0] dy = shifts_update[:, 1] @@ -1012,13 +983,13 @@ def reconstruct( + self._qy_shift[None] * dy[:, None, None] ) - self._stack_BF = xp.real(xp.fft.ifft2(Gs * shift_op)) + self._stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op)) self._stack_mask = xp.real( xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) ) - self._stack_BF = xp.asarray( - self._stack_BF, dtype=xp.float32 + self._stack_BF_shifted = xp.asarray( + self._stack_BF_shifted, dtype=xp.float32 ) # numpy fft upcasts? 
self._stack_mask = xp.asarray( self._stack_mask, dtype=xp.float32 @@ -1029,7 +1000,9 @@ def reconstruct( # Center the shifts xy_shifts_median = xp.round(xp.median(self._xy_shifts, axis=0)).astype(int) self._xy_shifts -= xy_shifts_median[None, :] - self._stack_BF = xp.roll(self._stack_BF, -xy_shifts_median, axis=(1, 2)) + self._stack_BF_shifted = xp.roll( + self._stack_BF_shifted, -xy_shifts_median, axis=(1, 2) + ) self._stack_mask = xp.roll(self._stack_mask, -xy_shifts_median, axis=(1, 2)) # Generate new estimate @@ -1038,18 +1011,21 @@ def reconstruct( mask_inv = 1 - np.clip(self._recon_mask, 0, 1) self._recon_BF = ( self._stack_mean * mask_inv - + xp.sum(self._stack_BF * self._stack_mask, axis=0) + + xp.sum(self._stack_BF_shifted * self._stack_mask, axis=0) ) / (self._recon_mask + mask_inv) self._recon_error = ( xp.atleast_1d( xp.sum( - xp.abs(self._stack_BF - self._recon_BF[None]) * self._stack_mask + xp.abs(self._stack_BF_shifted - self._recon_BF[None]) + * self._stack_mask ) ) / self._mask_sum ) + self.error_iterations.append(float(self._recon_error)) + if plot_aligned_bf: row_index, col_index = np.unravel_index(a0, (nrows, ncols)) @@ -1060,64 +1036,112 @@ def reconstruct( ax.set_yticks([]) ax.set_title(f"Aligned BF at bin {int(bin_vals[a0])}") - if plot_convergence: - errors.append(float(self._recon_error)) - if plot_aligned_bf: if plot_convergence: ax = fig.add_subplot(spec[-1, :]) - ax.plot(np.arange(num_plots), errors) - ax.set_xticks(np.arange(num_plots)) + x_range = np.arange(len(self.error_iterations)) + ax.plot(x_range, self.error_iterations) + ax.set_xticks(x_range) ax.set_ylabel("Error") spec.tight_layout(fig) self.recon_BF = asnumpy(self._recon_BF) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(self._device, self._clear_fft_cache) return self def subpixel_alignment( self, + virtual_detector_mask=None, kde_upsample_factor=None, - kde_sigma=0.125, + kde_sigma_px=0.125, + 
kde_lowpass_filter=False, + lanczos_interpolation_order=None, + integer_pixel_rolling_alignment=False, plot_upsampled_BF_comparison: bool = True, plot_upsampled_FFT_comparison: bool = False, + position_correction_num_iter=None, + position_correction_initial_step_size=1.0, + position_correction_min_step_size=0.1, + position_correction_step_size_factor=0.75, + position_correction_checkerboard_steps=False, + position_correction_gaussian_filter_sigma=None, + position_correction_butterworth_q_lowpass=None, + position_correction_butterworth_q_highpass=None, + position_correction_butterworth_order=(2, 2), + plot_position_correction_convergence: bool = True, + progress_bar: bool = True, **kwargs, ): """ Upsample and subpixel-align BFs using the measured image shifts. - Uses kernel density estimation (KDE) to align upsampled BFs. + Uses kernel density estimation (KDE) to interpolate the upsampled BFs. Parameters ---------- + virtual_detector_mask: np.ndarray, bool + Virtual detector mask, as a boolean array the same size as dp_mask kde_upsample_factor: int, optional Real-space upsampling factor - kde_sigma: float, optional - KDE gaussian kernel bandwidth + kde_sigma_px: float, optional + KDE gaussian kernel bandwidth in non-upsampled pixels + kde_lowpass_filter: bool, optional + If True, the resulting KDE upsampled image is lowpass-filtered using a sinc-function + lanczos_interpolation_order: int, optional + If not None, Lanczos interpolation with the specified order is used instead of bilinear + fourier_upsampling_additional_factor: int, optional + If not None, Fourier upsampling with integer rolling is used instead of bilinear/Lanczos plot_upsampled_BF_comparison: bool, optional If True, the pre/post alignment BF images are plotted for comparison plot_upsampled_FFT_comparison: bool, optional If True, the pre/post alignment BF FFTs are plotted for comparison + position_correction_num_iter: int, optional + If not None, parallax positions are corrected iteratively for this 
many iterations + position_correction_initial_step_size: float, optional + Initial position correction step-size in pixels + position_correction_min_step_size: float, optional + Minimum position correction step-size in pixels + position_correction_step_size_factor: float, optional + Factor to multiply step-size by between iterations + position_correction_checkerboard_steps: bool, optional + If True, uses steepest-descent checkerboarding steps, as opposed to gradient direction + position_correction_gaussian_filter_sigma: tuple(float, float), optional + Standard deviation of gaussian kernel in A + position_correction_butterworth_q_lowpass: tuple(float, float), optional + Cut-off frequency in A^-1 for low-pass butterworth filter + position_correction_butterworth_q_highpass: tuple(float, float), optional + Cut-off frequency in A^-1 for high-pass butterworth filter + position_correction_butterworth_order: tuple(int,int), optional + Butterworth filter order. Smaller gives a smoother filter + plot_position_correction_convergence: bool, optional + If True, position correction convergence is plotted + progress_bar: bool, optional + If True, a progress bar is printed with position correction progress """ xp = self._xp asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter + gaussian_filter = self._scipy.ndimage.gaussian_filter - xy_shifts = self._xy_shifts - BF_size = np.array(self._stack_BF_no_window.shape[-2:]) - - self._DF_upsample_limit = np.max( - 2 * self._region_of_interest_shape / self._scan_shape + BF_sampling = 1 / asnumpy(self._kr).max() / 2 + DF_sampling = 1 / ( + self._reciprocal_sampling[0] * self._region_of_interest_shape[0] ) - self._BF_upsample_limit = ( - 4 * self._kr.max() / self._reciprocal_sampling[0] - ) / self._scan_shape.max() - if self._device == "gpu": - self._BF_upsample_limit = self._BF_upsample_limit.item() + + self._BF_upsample_limit = self._scan_sampling[0] / BF_sampling + self._DF_upsample_limit = self._scan_sampling[0] / 
DF_sampling + + if self._DF_upsample_limit < 1: + warnings.warn( + ( + f"Dark-field upsampling limit of {self._DF_upsample_limit:.2f} " + "is less than 1, implying a scan step-size smaller than Nyquist. " + "setting to 1." + ), + UserWarning, + ) + self._DF_upsample_limit = 1 if kde_upsample_factor is None: if self._BF_upsample_limit * 3 / 2 > self._DF_upsample_limit: @@ -1126,7 +1150,7 @@ def subpixel_alignment( warnings.warn( ( f"Upsampling factor set to {kde_upsample_factor:.2f} (the " - f"dark-field upsampling limit)." + "dark-field upsampling limit)." ), UserWarning, ) @@ -1142,7 +1166,7 @@ def subpixel_alignment( UserWarning, ) else: - kde_upsample_factor = self._DF_upsample_limit * 2 / 3 + kde_upsample_factor = np.maximum(self._DF_upsample_limit * 2 / 3, 1) warnings.warn( ( @@ -1165,129 +1189,451 @@ def subpixel_alignment( ) self._kde_upsample_factor = kde_upsample_factor - pixel_output = np.round(BF_size * self._kde_upsample_factor).astype("int") - pixel_size = pixel_output.prod() - - # shifted coordinates - x = xp.arange(BF_size[0]) - y = xp.arange(BF_size[1]) - - xa, ya = xp.meshgrid(x, y, indexing="ij") - xa = ((xa + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor).ravel() - ya = ((ya + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor).ravel() - - # bilinear sampling - xF = xp.floor(xa).astype("int") - yF = xp.floor(ya).astype("int") - dx = xa - xF - dy = ya - yF - - # resampling - inds_1D = xp.ravel_multi_index( - xp.hstack( - [ - [xF, yF], - [xF + 1, yF], - [xF, yF + 1], - [xF + 1, yF + 1], - ] - ), - pixel_output, - mode=["wrap", "wrap"], - ) - weights = xp.hstack( - ( - (1 - dx) * (1 - dy), - (dx) * (1 - dy), - (1 - dx) * (dy), - (dx) * (dy), + # virtual detector + if virtual_detector_mask is None: + xy_shifts = self._xy_shifts + stack_BF_unshifted = self._stack_BF_unshifted + else: + virtual_detector_mask = np.asarray(virtual_detector_mask, dtype="bool") + xy_inds_np = asnumpy(self._xy_inds) + inds = 
virtual_detector_mask[xy_inds_np[:, 0], xy_inds_np[:, 1]] + + xy_shifts = self._xy_shifts[inds] + stack_BF_unshifted = self._stack_BF_unshifted[inds] + + BF_size = np.array(stack_BF_unshifted.shape[-2:]) + pixel_output_shape = np.round(BF_size * self._kde_upsample_factor).astype("int") + + if ( + not integer_pixel_rolling_alignment + or position_correction_num_iter is not None + ): + # shifted coordinates + x = xp.arange(BF_size[0], dtype=xp.float32) + y = xp.arange(BF_size[1], dtype=xp.float32) + xa_init, ya_init = xp.meshgrid(x, y, indexing="ij") + + # kernel density output the upsampled BF image + xa = (xa_init + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor + ya = (ya_init + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor + + pix_output = self._kernel_density_estimate( + xa, + ya, + stack_BF_unshifted, + pixel_output_shape, + kde_sigma_px * self._kde_upsample_factor, + lanczos_alpha=lanczos_interpolation_order, + lowpass_filter=kde_lowpass_filter, ) - ) + else: + upsample_fraction, upsample_int = np.modf(self._kde_upsample_factor) - pix_count = xp.reshape( - xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output - ) - pix_output = xp.reshape( - xp.bincount( - inds_1D, - weights=weights * xp.tile(self._stack_BF_no_window.ravel(), 4), - minlength=pixel_size, - ), - pixel_output, - ) + if upsample_fraction: + upsample_nearest = np.round(self._kde_upsample_factor).astype("int") + + warnings.warn( + ( + f"Upsampling factor of {self._kde_upsample_factor} " + f"rounded to nearest integer {upsample_nearest}." 
+ ), + UserWarning, + ) + + self._kde_upsample_factor = upsample_nearest + + pix_output = pixel_rolling_kernel_density_estimate( + stack_BF_unshifted, + xy_shifts, + self._kde_upsample_factor, + kde_sigma_px * self._kde_upsample_factor, + xp=xp, + gaussian_filter=gaussian_filter, + ) + + # Perform probe position correction if needed + if position_correction_num_iter is not None: + if integer_pixel_rolling_alignment: + interpolation_method = ( + "bilinear" if lanczos_interpolation_order is None else "Lanczos" + ) + warnings.warn( + ( + "Integer pixel rolling is not compatible with position-correction, " + f"{interpolation_method} KDE interpolation will be used instead." + ), + UserWarning, + ) + + recon_BF_subpixel_aligned_reference = pix_output.copy() + + # init position shift array + self._probe_dx = xp.zeros_like(xa_init) + self._probe_dy = xp.zeros_like(xa_init) + + # step size of initial search, cost function + step = xp.ones_like(xa_init) * position_correction_initial_step_size + + # init scores and stats + position_correction_stats = np.zeros(position_correction_num_iter + 1) + + scores = ( + xp.mean( + xp.abs( + self._interpolate_array( + pix_output, + xa, + ya, + lanczos_alpha=None, + ) + - stack_BF_unshifted + ), + axis=0, + ) + * self._window_pad + ) + + position_correction_stats[0] = scores.mean() + + # gradient search directions + + if position_correction_checkerboard_steps: + # checkerboard steps + dxy = np.array( + [ + [-1.0, 0.0], + [1.0, 0.0], + [0.0, -1.0], + [0.0, 1.0], + ] + ) + + else: + # centered finite-difference directions + dxy = np.array( + [ + [-0.5, 0.0], + [0.5, 0.0], + [0.0, -0.5], + [0.0, 0.5], + ] + ) + + scores_test = xp.zeros( + ( + dxy.shape[0], + scores.shape[0], + scores.shape[1], + ) + ) + + # main loop for position correction + for a0 in tqdmnd( + position_correction_num_iter, + desc="Correcting positions: ", + unit=" iteration", + disable=not progress_bar, + ): + # Evaluate scores for step directions and magnitudes + + for a1 
in range(dxy.shape[0]): + xa = ( + xa_init + + self._probe_dx + + dxy[a1, 0] * step + + xy_shifts[:, 0, None, None] + ) * self._kde_upsample_factor + + ya = ( + ya_init + + self._probe_dy + + dxy[a1, 1] * step + + xy_shifts[:, 1, None, None] + ) * self._kde_upsample_factor + + scores_test[a1] = xp.mean( + xp.abs( + self._interpolate_array( + pix_output, + xa, + ya, + lanczos_alpha=None, + ) + - stack_BF_unshifted + ), + axis=0, + ) + + if position_correction_checkerboard_steps: + # Check where cost function has improved + + scores_test *= self._window_pad[None] + update = np.min(scores_test, axis=0) < scores + scores_ind = np.argmin(scores_test, axis=0) + + for a1 in range(dxy.shape[0]): + sub = np.logical_and(update, scores_ind == a1) + self._probe_dx[sub] += ( + dxy[a1, 0] * step[sub] * self._window_pad[sub] + ) + self._probe_dy[sub] += ( + dxy[a1, 1] * step[sub] * self._window_pad[sub] + ) + + else: + # Check where cost function has improved + dx = scores_test[0] - scores_test[1] + dy = scores_test[2] - scores_test[3] + + dr = xp.sqrt(dx**2 + dy**2) / step + dx *= self._window_pad / dr + dy *= self._window_pad / dr + + # Fixed-size step + xa = ( + xa_init + self._probe_dx + dx + xy_shifts[:, 0, None, None] + ) * self._kde_upsample_factor + + ya = ( + ya_init + self._probe_dy + dy + xy_shifts[:, 1, None, None] + ) * self._kde_upsample_factor + + fixed_step_scores = ( + xp.mean( + xp.abs( + self._interpolate_array( + pix_output, + xa, + ya, + lanczos_alpha=None, + ) + - stack_BF_unshifted + ), + axis=0, + ) + * self._window_pad + ) - # kernel density estimate - sigma = kde_sigma * self._kde_upsample_factor - pix_count = gaussian_filter(pix_count, sigma) - pix_count[pix_count == 0.0] = np.inf - pix_output = gaussian_filter(pix_output, sigma) - pix_output /= pix_count + update = fixed_step_scores < scores + self._probe_dx[update] += dx[update] + self._probe_dy[update] += dy[update] + + # reduce gradient step for sites which did not improve + 
step[xp.logical_not(update)] *= position_correction_step_size_factor + + # enforce minimum step size + step = xp.maximum(step, position_correction_min_step_size) + + # apply regularization if needed + if position_correction_gaussian_filter_sigma is not None: + self._probe_dx = gaussian_filter( + self._probe_dx, + position_correction_gaussian_filter_sigma[0] + / self._scan_sampling[0], + # mode="nearest", + ) + self._probe_dy = gaussian_filter( + self._probe_dy, + position_correction_gaussian_filter_sigma[1] + / self._scan_sampling[1], + # mode="nearest", + ) + + if ( + position_correction_butterworth_q_lowpass is not None + or position_correction_butterworth_q_highpass is not None + ): + qx = xp.fft.fftfreq(BF_size[0], self._scan_sampling[0]) + qy = xp.fft.fftfreq(BF_size[1], self._scan_sampling[1]) + + qya, qxa = xp.meshgrid(qy, qx) + qra = xp.sqrt(qxa**2 + qya**2) + + if position_correction_butterworth_q_lowpass: + ( + q_lowpass_x, + q_lowpass_y, + ) = position_correction_butterworth_q_lowpass + else: + q_lowpass_x, q_lowpass_y = (None, None) + if position_correction_butterworth_q_highpass: + ( + q_highpass_x, + q_highpass_y, + ) = position_correction_butterworth_q_highpass + else: + q_highpass_x, q_highpass_y = (None, None) + + order_x, order_y = position_correction_butterworth_order + + # dx + env = xp.ones_like(qra) + if q_highpass_x: + env *= 1 - 1 / (1 + (qra / q_highpass_x) ** (2 * order_x)) + if q_lowpass_x: + env *= 1 / (1 + (qra / q_lowpass_x) ** (2 * order_x)) + + probe_dx_mean = xp.mean(self._probe_dx) + self._probe_dx -= probe_dx_mean + self._probe_dx = xp.real( + xp.fft.ifft2(xp.fft.fft2(self._probe_dx) * env) + ) + self._probe_dx += probe_dx_mean + + # dy + env = xp.ones_like(qra) + if q_highpass_y: + env *= 1 - 1 / (1 + (qra / q_highpass_y) ** (2 * order_y)) + if q_lowpass_y: + env *= 1 / (1 + (qra / q_lowpass_y) ** (2 * order_y)) + + probe_dy_mean = xp.mean(self._probe_dy) + self._probe_dy -= probe_dy_mean + self._probe_dy = xp.real( + 
xp.fft.ifft2(xp.fft.fft2(self._probe_dy) * env) + ) + self._probe_dy += probe_dy_mean + + # kernel density output the upsampled BF image + xa = ( + xa_init + self._probe_dx + xy_shifts[:, 0, None, None] + ) * self._kde_upsample_factor + + ya = ( + ya_init + self._probe_dy + xy_shifts[:, 1, None, None] + ) * self._kde_upsample_factor + + pix_output = self._kernel_density_estimate( + xa, + ya, + stack_BF_unshifted, + pixel_output_shape, + kde_sigma_px * self._kde_upsample_factor, + lanczos_alpha=lanczos_interpolation_order, + lowpass_filter=kde_lowpass_filter, + ) + + # update cost function and stats + scores = ( + xp.mean( + xp.abs( + self._interpolate_array( + pix_output, + xa, + ya, + lanczos_alpha=None, + ) + - stack_BF_unshifted + ), + axis=0, + ) + * self._window_pad + ) + + position_correction_stats[a0 + 1] = scores.mean() + + else: + plot_position_correction_convergence = False self._recon_BF_subpixel_aligned = pix_output self.recon_BF_subpixel_aligned = asnumpy(self._recon_BF_subpixel_aligned) + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + # plotting - if plot_upsampled_BF_comparison: - if plot_upsampled_FFT_comparison: - figsize = kwargs.pop("figsize", (8, 8)) - fig, axs = plt.subplots(2, 2, figsize=figsize) - else: - figsize = kwargs.pop("figsize", (8, 4)) - fig, axs = plt.subplots(1, 2, figsize=figsize) + nrows = np.count_nonzero( + np.array( + [ + plot_upsampled_BF_comparison, + plot_upsampled_FFT_comparison, + plot_position_correction_convergence, + ] + ) + ) + if nrows > 0: + ncols = 3 if position_correction_num_iter is not None else 2 + height_ratios = ( + [4, 4, 2][-nrows:] + if plot_position_correction_convergence + else [4, 4, 2][:nrows] + ) + spec = GridSpec( + ncols=ncols, nrows=nrows, height_ratios=height_ratios, hspace=0.15 + ) - axs = axs.flat + figsize = kwargs.pop("figsize", (4 * ncols, sum(height_ratios))) cmap = kwargs.pop("cmap", "magma") + fig = plt.figure(figsize=figsize) - cropped_object = 
self._crop_padded_object(self._recon_BF) - cropped_object_aligned = self._crop_padded_object( - self._recon_BF_subpixel_aligned, upsampled=True - ) + row_index = 0 - extent = [ - 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], - 0, - ] + if plot_upsampled_BF_comparison: + ax1 = fig.add_subplot(spec[row_index, 0]) + ax2 = fig.add_subplot(spec[row_index, 1]) - axs[0].imshow( - cropped_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - axs[0].set_title("Aligned Bright Field") + cropped_object = self._crop_padded_object(self._recon_BF) - axs[1].imshow( - cropped_object_aligned, - extent=extent, - cmap=cmap, - **kwargs, - ) - axs[1].set_title("Upsampled Bright Field") + if ncols == 3: + ax3 = fig.add_subplot(spec[row_index, 2]) - for ax in axs[:2]: - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") + cropped_object_reference_aligned = self._crop_padded_object( + recon_BF_subpixel_aligned_reference, upsampled=True + ) + cropped_object_aligned = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + axs = [ax1, ax2, ax3] - if plot_upsampled_FFT_comparison: - recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) - pad_x = np.round( - BF_size[0] * (self._kde_upsample_factor - 1) / 2 - ).astype("int") - pad_y = np.round( - BF_size[1] * (self._kde_upsample_factor - 1) / 2 - ).astype("int") - pad_recon_fft = asnumpy( - xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) + else: + cropped_object_reference_aligned = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + axs = [ax1, ax2] + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] + + axs[0].imshow( + cropped_object, + extent=extent, + cmap=cmap, + **kwargs, ) + axs[0].set_title("Aligned Bright Field") - upsampled_fft = asnumpy( - xp.fft.fftshift( - xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) - ) + 
axs[1].imshow( + cropped_object_reference_aligned, + extent=extent, + cmap=cmap, + **kwargs, ) + axs[1].set_title("Upsampled Bright Field") + + if ncols == 3: + axs[2].imshow( + cropped_object_aligned, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[2].set_title("Probe-Corrected Bright Field") + + for ax in axs: + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + row_index += 1 + + if plot_upsampled_FFT_comparison: + ax1 = fig.add_subplot(spec[row_index, 0]) + ax2 = fig.add_subplot(spec[row_index, 1]) reciprocal_extent = [ -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), @@ -1296,30 +1642,180 @@ def subpixel_alignment( -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), ] - show( + nx, ny = self._recon_BF_subpixel_aligned.shape + kx = xp.fft.fftfreq(nx, d=1) + ky = xp.fft.fftfreq(ny, d=1) + k = xp.fft.fftshift(xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2)) + + recon_fft = xp.fft.fftshift( + xp.abs(xp.fft.fft2(self._recon_BF)) / np.prod(self._recon_BF.shape) + ) + sx, sy = recon_fft.shape + + pad_x_post = (nx - sx) // 2 + pad_x_pre = nx - sx - pad_x_post + pad_y_post = (ny - sy) // 2 + pad_y_pre = ny - sy - pad_y_post + + pad_recon_fft = asnumpy( + xp.pad( + recon_fft, ((pad_x_pre, pad_x_post), (pad_y_pre, pad_y_post)) + ) + * k + ) + + if ncols == 3: + ax3 = fig.add_subplot(spec[row_index, 2]) + upsampled_fft_reference = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(recon_BF_subpixel_aligned_reference)) + / (nx * ny) + ) + * k + ) + + upsampled_fft = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + / (nx * ny) + ) + * k + ) + axs = [ax1, ax2, ax3] + else: + upsampled_fft_reference = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + / (nx * ny) + ) + * k + ) + axs = [ax1, ax2] + + _, vmin, vmax = return_scaled_histogram_ordering( + upsampled_fft_reference + ) + + axs[0].imshow( pad_recon_fft, - figax=(fig, axs[2]), extent=reciprocal_extent, + vmin=vmin, + vmax=vmax, 
cmap="gray", - title="Aligned Bright Field FFT", **kwargs, ) + axs[0].set_title("Aligned Bright Field FFT") - show( - upsampled_fft, - figax=(fig, axs[3]), + axs[1].imshow( + upsampled_fft_reference, extent=reciprocal_extent, + vmin=vmin, + vmax=vmax, cmap="gray", - title="Upsampled Bright Field FFT", **kwargs, ) + axs[1].set_title("Upsampled Bright Field FFT") + + if ncols == 3: + axs[2].imshow( + upsampled_fft, + extent=reciprocal_extent, + vmin=vmin, + vmax=vmax, + cmap="gray", + **kwargs, + ) + axs[2].set_title("Probe-Corrected Bright Field FFT") - for ax in axs[2:]: + for ax in axs: ax.set_ylabel(r"$k_x$ [$A^{-1}$]") ax.set_xlabel(r"$k_y$ [$A^{-1}$]") - ax.xaxis.set_ticks_position("bottom") - fig.tight_layout() + row_index += 1 + + if plot_position_correction_convergence: + axs = fig.add_subplot(spec[row_index, :]) + + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + color = kwargs.pop("color", (1, 0, 0)) + + axs.semilogy( + np.arange(position_correction_num_iter + 1), + position_correction_stats / position_correction_stats[0], + color=color, + **kwargs, + ) + axs.set_xlabel("Iteration number") + axs.set_ylabel("NMSE") + axs.yaxis.set_major_formatter(PercentFormatter(1.0, decimals=0)) + axs.yaxis.set_minor_formatter(PercentFormatter(1.0, decimals=0)) + + spec.tight_layout(fig) + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def _interpolate_array( + self, + image, + xa, + ya, + lanczos_alpha, + ): + """ """ + + xp = self._xp + + if lanczos_alpha is not None: + return lanczos_interpolate_array(image, xa, ya, lanczos_alpha, xp=xp) + else: + return bilinearly_interpolate_array( + image, + xa, + ya, + xp=xp, + ) + + def _kernel_density_estimate( + self, + xa, + ya, + intensities, + output_shape, + kde_sigma, + lanczos_alpha=None, + lowpass_filter=False, + ): + """ """ + + xp = self._xp + gaussian_filter = self._scipy.ndimage.gaussian_filter + + if lanczos_alpha is not None: + return lanczos_kernel_density_estimate( + xa, + 
ya, + intensities, + output_shape, + kde_sigma, + lanczos_alpha, + lowpass_filter=lowpass_filter, + xp=xp, + gaussian_filter=gaussian_filter, + ) + else: + return bilinear_kernel_density_estimate( + xa, + ya, + intensities, + output_shape, + kde_sigma, + lowpass_filter=lowpass_filter, + xp=xp, + gaussian_filter=gaussian_filter, + ) def aberration_fit( self, @@ -1329,13 +1825,15 @@ def aberration_fit( fit_aberrations_max_angular_order: int = 4, fit_aberrations_min_radial_order: int = 2, fit_aberrations_min_angular_order: int = 0, + fit_aberrations_mn: list = None, fit_max_thon_rings: int = 6, - fit_power_alpha: float = 2.0, + fit_power_alpha: float = 1.0, plot_CTF_comparison: bool = None, plot_BF_shifts_comparison: bool = None, upsampled: bool = True, force_transpose: bool = False, force_rotation_deg: float = None, + **kwargs, ): """ Fit aberrations to the measured image shifts. @@ -1355,6 +1853,8 @@ def aberration_fit( Min radial order for fitting of aberrations. fit_aberrations_min_angular_order: int Min angular order for fitting of aberrations. + fit_aberrations_mn: list + If not None, sets aberrations mn explicitly. fit_max_thon_rings: int Max number of Thon rings to search for during CTF FFT fitting. 
fit_power_alpha: int @@ -1432,17 +1932,21 @@ def aberration_fit( ### Second pass # Aberration coefs - mn = [] - for m in range( - fit_aberrations_min_radial_order - 1, fit_aberrations_max_radial_order - ): - n_max = np.minimum(fit_aberrations_max_angular_order, m + 1) - for n in range(fit_aberrations_min_angular_order, n_max + 1): - if (m + n) % 2: - mn.append([m, n, 0]) - if n > 0: - mn.append([m, n, 1]) + if fit_aberrations_mn is None: + mn = [] + + for m in range( + fit_aberrations_min_radial_order - 1, fit_aberrations_max_radial_order + ): + n_max = np.minimum(fit_aberrations_max_angular_order, m + 1) + for n in range(fit_aberrations_min_angular_order, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) + if n > 0: + mn.append([m, n, 1]) + else: + mn = fit_aberrations_mn self._aberrations_mn = np.array(mn) self._aberrations_mn = self._aberrations_mn[ @@ -1458,14 +1962,6 @@ def aberration_fit( ] self._aberrations_num = self._aberrations_mn.shape[0] - if plot_CTF_comparison is None: - if fit_CTF_FFT: - plot_CTF_comparison = True - - if plot_BF_shifts_comparison is None: - if fit_BF_shifts: - plot_BF_shifts_comparison = True - # Thon Rings Fitting if fit_CTF_FFT or plot_CTF_comparison: if upsampled and hasattr(self, "_kde_upsample_factor"): @@ -1699,16 +2195,16 @@ def score_CTF(coefs): measured_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._xy_shifts_Ang[:, 0] + measured_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + self._xy_shifts_Ang[:, 0] + ) measured_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - measured_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._xy_shifts_Ang[:, 1] + measured_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + self._xy_shifts_Ang[:, 1] + ) fitted_shifts = ( xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) @@ -1719,16 +2215,16 @@ def score_CTF(coefs): 
fitted_shifts_sx = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = fitted_shifts[:, 0] + fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + fitted_shifts[:, 0] + ) fitted_shifts_sy = xp.zeros( self._region_of_interest_shape, dtype=xp.float32 ) - fitted_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = fitted_shifts[:, 1] + fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = ( + fitted_shifts[:, 1] + ) max_shift = xp.max( xp.array( @@ -1741,16 +2237,21 @@ def score_CTF(coefs): ) ) + axsize = kwargs.pop("axsize", (4, 4)) + cmap = kwargs.pop("cmap", "PiYG") + vmin = kwargs.pop("vmin", -max_shift) + vmax = kwargs.pop("vmax", max_shift) + show( [ [asnumpy(measured_shifts_sx), asnumpy(fitted_shifts_sx)], [asnumpy(measured_shifts_sy), asnumpy(fitted_shifts_sy)], ], - cmap="PiYG", - vmin=-max_shift, - vmax=max_shift, + cmap=cmap, + vmin=vmin, + vmax=vmax, intensity_range="absolute", - axsize=(4, 4), + axsize=axsize, ticks=False, title=[ "Measured Vertical Shifts", @@ -1785,7 +2286,9 @@ def score_CTF(coefs): im_CTF = calculate_CTF_FFT( self._aberrations_surface_shape_FFT, *self._aberrations_coefs ) - im_CTF_cos = xp.cos(xp.abs(im_CTF)) ** 4 + + im_CTF_plot = xp.abs(xp.sin(im_CTF)) + im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings + 0.5) * np.pi] = np.pi / 2 im_CTF = xp.abs(xp.sin(im_CTF)) < 0.15 im_CTF[xp.logical_not(plot_mask)] = 0 @@ -1801,7 +2304,7 @@ def score_CTF(coefs): im_plot, vmin=int_range[0], vmax=int_range[1], extent=reciprocal_extent ) ax2.imshow( - np.fft.fftshift(asnumpy(im_CTF_cos)), + np.fft.fftshift(asnumpy(im_CTF_plot)), cmap="gray", extent=reciprocal_extent, ) @@ -1811,11 +2314,11 @@ def score_CTF(coefs): ax.set_xlabel(r"$k_y$ [$A^{-1}$]") ax1.set_title("Aligned Bright Field FFT") - ax2.set_title("Fitted CTF Zero-Crossings") + ax2.set_title("Fitted CTF ") fig.tight_layout() - self.aberration_dict = { + self.aberration_dict_cartesian = { 
tuple(self._aberrations_mn[a0]): { "aberration name": _aberration_names.get( tuple(self._aberrations_mn[a0, :2]), "-" @@ -1887,9 +2390,9 @@ def score_CTF(coefs): + str(np.round(self._aberrations_coefs[a0]).astype("int")) ) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self def _calculate_CTF(self, alpha_shape, sampling, *coefs): xp = self._xp @@ -1938,9 +2441,6 @@ def aberration_correct( plot_corrected_phase: bool = True, k_info_limit: float = None, k_info_power: float = 1.0, - Wiener_filter=False, - Wiener_signal_noise_ratio: float = 1.0, - Wiener_filter_low_only: bool = False, upsampled: bool = True, **kwargs, ): @@ -1958,12 +2458,6 @@ def aberration_correct( maximum allowed frequency in butterworth filter k_info_power: float, optional power of butterworth filter - Wiener_filter: bool, optional - Use Wiener filtering instead of CTF sign correction. - Wiener_signal_noise_ratio: float, optional - Signal to noise radio at k = 0 for Wiener filter - Wiener_filter_low_only: bool, optional - Apply Wiener filtering only to the CTF portions before the 1st CTF maxima. 
""" xp = self._xp @@ -1997,68 +2491,40 @@ def aberration_correct( use_CTF_fit = True if use_CTF_fit: - sin_chi = np.sin( - self._calculate_CTF(im.shape, (sx, sy), *self._aberrations_coefs) - ) + even_radial_orders = (self._aberrations_mn[:, 0] % 2) == 1 + odd_radial_orders = (self._aberrations_mn[:, 0] % 2) == 0 - CTF_corr = xp.sign(sin_chi) - CTF_corr[0, 0] = 0 + odd_coefs = self._aberrations_coefs.copy() + odd_coefs[even_radial_orders] = 0 + chi_odd = self._calculate_CTF(im.shape, (sx, sy), *odd_coefs) - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(im) * CTF_corr + even_coefs = self._aberrations_coefs.copy() + even_coefs[odd_radial_orders] = 0 + chi_even = self._calculate_CTF(im.shape, (sx, sy), *even_coefs) - # if needed, add low pass filter output image - if k_info_limit is not None: - im_fft_corr /= 1 + (kra2**k_info_power) / ( - (k_info_limit) ** (2 * k_info_power) - ) - else: - # CTF - sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + if not chi_even.any(): # check if all zeros + chi_even = xp.ones_like(chi_even) - if Wiener_filter: - SNR_inv = ( - xp.sqrt( - 1 - + (kra2**k_info_power) - / ((k_info_limit) ** (2 * k_info_power)) - ) - / Wiener_signal_noise_ratio - ) - CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) - if Wiener_filter_low_only: - # limit Wiener filter to only the part of the CTF before 1st maxima - k_thresh = 1 / xp.sqrt( - 2.0 * self._wavelength * xp.abs(self.aberration_C1) - ) - k_mask = kra2 >= k_thresh**2 - CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) - - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(im) * CTF_corr + else: + chi_even = (xp.pi * self._wavelength * self.aberration_C1) * kra2 + chi_odd = xp.zeros_like(chi_even) - else: - # CTF without tilt correction (beyond the parallax operator) - CTF_corr = xp.sign(sin_chi) - CTF_corr[0, 0] = 0 + CTF_corr = xp.sign(xp.sin(chi_even)) * xp.exp(-1j * chi_odd) + CTF_corr[0, 0] = 0 - # apply correction 
to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(im) * CTF_corr + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr - # if needed, add low pass filter output image - if k_info_limit is not None: - im_fft_corr /= 1 + (kra2**k_info_power) / ( - (k_info_limit) ** (2 * k_info_power) - ) + # if needed, add low pass filter output image + if k_info_limit is not None: + im_fft_corr /= 1 + (kra2**k_info_power) / ( + (k_info_limit) ** (2 * k_info_power) + ) # Output phase image self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) self.recon_phase_corrected = asnumpy(self._recon_phase_corrected) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - # plotting if plot_corrected_phase: figsize = kwargs.pop("figsize", (6, 6)) @@ -2088,9 +2554,13 @@ def aberration_correct( ax.set_xlabel("y [A]") ax.set_title("Parallax-Corrected Phase Image") + self.clear_device_mem(self._device, self._clear_fft_cache) + return self + def depth_section( self, - depth_angstroms=np.arange(-250, 260, 100), + depth_angstroms=None, + use_CTF_fit=True, plot_depth_sections=True, k_info_limit: float = None, k_info_power: float = 1.0, @@ -2119,7 +2589,6 @@ def depth_section( xp = self._xp asnumpy = self._asnumpy - depth_angstroms = xp.atleast_1d(depth_angstroms) if not hasattr(self, "aberration_C1"): raise ValueError( @@ -2129,16 +2598,26 @@ def depth_section( ) ) + if depth_angstroms is None: + depth_angstroms = np.linspace(-256, 256, 33) + depth_angstroms = xp.atleast_1d(depth_angstroms) + # Fourier coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) + sx, sy = self._scan_sampling + nx, ny = self._recon_BF.shape + kx = xp.fft.fftfreq(nx, sx) + ky = xp.fft.fftfreq(ny, sy) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 - # information limit - if k_info_limit is not None: - k_filt = 1 / ( - 1 + 
(kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) + if use_CTF_fit: + sin_chi = xp.sin( + self._calculate_CTF((nx, ny), (sx, sy), *self._aberrations_coefs) ) + else: + sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + + CTF_corr = xp.sign(sin_chi) + CTF_corr[0, 0] = 0 # init stack_depth = xp.zeros( @@ -2173,27 +2652,21 @@ def depth_section( dz = depth_angstroms[a0] # Parallax - im_depth = xp.zeros_like(self._recon_BF, dtype="complex") - for a1 in range(self._stack_BF.shape[0]): - dx = self._probe_angles[a1, 0] * dz - dy = self._probe_angles[a1, 1] * dz - im_depth += xp.fft.fft2(self._stack_BF[a1]) * xp.exp( - self._qx_shift * dx + self._qy_shift * dy - ) - - # CTF correction - sin_chi = xp.sin( - (xp.pi * self._wavelength * (self.aberration_C1 + dz)) * kra2 + im_depth = xp.zeros_like(self._recon_BF, dtype=xp.complex64) + dx = -self._probe_angles[:, 0] * dz / self._scan_sampling[0] + dy = -self._probe_angles[:, 1] * dz / self._scan_sampling[1] + shift_op = xp.exp( + self._qx_shift[None] * dx[:, None, None] + + self._qy_shift[None] * dy[:, None, None] ) - CTF_corr = xp.sign(sin_chi) - CTF_corr[0, 0] = 0 + im_depth = xp.fft.fft2(self._stack_BF_shifted) * shift_op * CTF_corr + if k_info_limit is not None: - CTF_corr *= k_filt + im_depth /= 1 + (kra2**k_info_power) / ( + (k_info_limit) ** (2 * k_info_power) + ) - # apply correction to mean reconstructed BF image - stack_depth[a0] = ( - xp.real(xp.fft.ifft2(im_depth * CTF_corr)) / self._stack_BF.shape[0] - ) + stack_depth[a0] = xp.real(xp.fft.ifft2(im_depth)).mean(0) if plot_depth_sections: row_index, col_index = np.unravel_index(a0, (nrows, ncols)) @@ -2217,13 +2690,9 @@ def depth_section( ax.set_xticks([]) ax.set_yticks([]) - ax.set_title(f"Depth section: {dz}A") - - if self._device == "gpu": - xp = self._xp - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + ax.set_title(f"Depth section: {dz} A") + self.clear_device_mem(self._device, self._clear_fft_cache) return 
stack_depth def _crop_padded_object( @@ -2253,19 +2722,33 @@ def _crop_padded_object( if upsampled: pad_x = np.round( + self._object_padding_px[0] * self._kde_upsample_factor + ).astype("int") + pad_x_left = np.round( self._object_padding_px[0] / 2 * self._kde_upsample_factor ).astype("int") + pad_x_right = pad_x - pad_x_left + pad_y = np.round( + self._object_padding_px[1] * self._kde_upsample_factor + ).astype("int") + pad_y_left = np.round( self._object_padding_px[1] / 2 * self._kde_upsample_factor ).astype("int") + pad_y_right = pad_y - pad_y_left + else: - pad_x = self._object_padding_px[0] // 2 - pad_y = self._object_padding_px[1] // 2 + pad_x_left = self._object_padding_px[0] // 2 + pad_x_right = self._object_padding_px[0] - pad_x_left + pad_y_left = self._object_padding_px[1] // 2 + pad_y_right = self._object_padding_px[1] - pad_y_left - pad_x -= remaining_padding - pad_y -= remaining_padding + pad_x_left -= remaining_padding + pad_x_right -= remaining_padding + pad_y_left -= remaining_padding + pad_y_right -= remaining_padding - return asnumpy(padded_object[pad_x:-pad_x, pad_y:-pad_y]) + return asnumpy(padded_object[pad_x_left:-pad_x_right, pad_y_left:-pad_y_right]) def _visualize_figax( self, @@ -2364,7 +2847,8 @@ def show_shifts( dp_mask_ind = xp.nonzero(self._dp_mask) yy, xx = xp.meshgrid( - xp.arange(self._dp_mean.shape[1]), xp.arange(self._dp_mean.shape[0]) + xp.arange(self._region_of_interest_shape[1]), + xp.arange(self._region_of_interest_shape[0]), ) freq_mask = xp.logical_and(xx % plot_arrow_freq == 0, yy % plot_arrow_freq == 0) masked_ind = xp.logical_and(freq_mask, self._dp_mask) @@ -2433,6 +2917,48 @@ def show_shifts( fig.tight_layout() + def show_probe_position_shifts( + self, + **kwargs, + ): + """ + Utility function to visualize probe-position shifts. 
+ """ + probe_dx = self._crop_padded_object(self._probe_dx) + probe_dy = self._crop_padded_object(self._probe_dy) + max_shift = np.abs(np.dstack((probe_dx, probe_dy))).max() + + figsize = kwargs.pop("figsize", (9, 4)) + vmin = kwargs.pop("vmin", -max_shift) + vmax = kwargs.pop("vmax", max_shift) + cmap = kwargs.pop("cmap", "PuOr") + + extent = [ + 0, + self._scan_sampling[1] * probe_dx.shape[1], + self._scan_sampling[0] * probe_dx.shape[0], + 0, + ] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + im1 = ax1.imshow(probe_dx, extent=extent, vmin=vmin, vmax=vmax, cmap=cmap) + im2 = ax2.imshow(probe_dy, extent=extent, vmin=vmin, vmax=vmax, cmap=cmap) + + for ax, im in zip([ax1, ax2], [im1, im2]): + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + cb = fig.colorbar(im, cax=ax_cb) + cb.set_label("pix", rotation=0, ha="center", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + ax1.set_title("Probe Position Vertical Shifts") + ax2.set_title("Probe Position Horizontal Shifts") + + fig.tight_layout() + def visualize( self, **kwargs, @@ -2475,3 +3001,26 @@ def object_cropped(self): ) else: return self._crop_padded_object(self._recon_BF) + + @property + def aberration_dict_polar(self): + """converts cartesian aberration dictionary to the polar convention used in ptycho""" + polar_dict = {} + unique_aberrations = np.unique(self._aberrations_mn[:, :2], axis=0) + aberrations_dict = self.aberration_dict_cartesian + + for aberration_order in unique_aberrations: + m, n = aberration_order + modulus_name = "C" + str(m) + str(n) + + if n != 0: + value_a = aberrations_dict[(m, n, 0)]["value [Ang]"] + value_b = aberrations_dict[(m, n, 1)]["value [Ang]"] + polar_dict[modulus_name] = np.sqrt(value_a**2 + value_b**2) + + argument_name = "phi" + str(m) + str(n) + polar_dict[argument_name] = np.arctan2(value_b, value_a) / n + else: + 
polar_dict[modulus_name] = aberrations_dict[(m, n, 0)]["value [Ang]"] + + return polar_dict diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py index 91a71cb30..652e1046e 100644 --- a/py4DSTEM/process/phase/parameter_optimize.py +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -1,10 +1,11 @@ from functools import partial +from itertools import product from typing import Callable, Union import matplotlib.pyplot as plt import numpy as np from matplotlib.gridspec import GridSpec -from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import AffineTransform from skopt import gp_minimize from skopt.plots import plot_convergence as skopt_plot_convergence @@ -102,6 +103,151 @@ def __init__( self._set_optimizer_defaults() + def _generate_inclusive_boundary_grid( + self, + parameter, + n_points, + ): + """ """ + + # Categorical + if hasattr(parameter, "categories"): + return np.array(parameter.categories) + + # Real or Integer + else: + return np.unique( + np.linspace(parameter.low, parameter.high, n_points).astype( + parameter.dtype + ) + ) + + def grid_search( + self, + n_points: Union[tuple, int] = 3, + error_metric: Union[Callable, str] = "log", + plot_reconstructed_objects: bool = True, + return_reconstructed_objects: bool = False, + **kwargs: dict, + ): + """ + Run optimizer + + Parameters + ---------- + n_initial_points: int + Number of uniformly spaced trial points to run on a grid + error_metric: Callable or str + Function used to compute the reconstruction error. 
+ When passed as a string, may be one of: + 'log': log(NMSE) of final object + 'linear': NMSE of final object + 'log-converged': log(NMSE) of final object if + NMSE is decreasing, 0 if NMSE increasing + 'linear-converged': NMSE of final object if + NMSE is decreasing, 1 if NMSE increasing + 'TV': sum( abs( grad( object ) ) ) / sum( abs( object ) ) + 'std': negative standard deviation of cropped object + 'std-phase': negative standard deviation of + phase of the cropped object + 'entropy-phase': entropy of the phase of the + cropped object + When passed as a Callable, a function that takes the + PhaseReconstruction object as its only argument + and returns the error metric as a single float + + """ + + num_params = len(self._parameter_list) + + if isinstance(n_points, int): + n_points = [n_points] * num_params + elif len(n_points) != num_params: + raise ValueError() + + params_grid = [ + self._generate_inclusive_boundary_grid(param, n_pts) + for param, n_pts in zip(self._parameter_list, n_points) + ] + params_grid = list(product(*params_grid)) + num_evals = len(params_grid) + + error_metric = self._get_error_metric(error_metric) + pbar = tqdm(total=num_evals, desc="Searching parameters") + + def evaluation_callback(ptycho): + if plot_reconstructed_objects or return_reconstructed_objects: + pbar.update(1) + return ( + ptycho._return_projected_cropped_potential(), + error_metric(ptycho), + ) + else: + pbar.update(1) + error_metric(ptycho) + + grid_search_function = self._get_optimization_function( + self._reconstruction_type, + self._parameter_list, + self._init_static_args, + self._affine_static_args, + self._preprocess_static_args, + self._reconstruction_static_args, + self._init_optimize_args, + self._affine_optimize_args, + self._preprocess_optimize_args, + self._reconstruction_optimize_args, + evaluation_callback, + ) + + grid_search_res = list(map(grid_search_function, params_grid)) + pbar.close() + + if plot_reconstructed_objects: + if len(n_points) == 2: + 
nrows, ncols = n_points + else: + nrows = kwargs.pop("nrows", int(np.sqrt(num_evals))) + ncols = kwargs.pop("ncols", int(np.ceil(num_evals / nrows))) + if nrows * ncols < num_evals: + raise ValueError() + + spec = GridSpec( + ncols=ncols, + nrows=nrows, + hspace=0.15, + wspace=0.15, + ) + + sx, sy = grid_search_res[0][0].shape + + separator = kwargs.pop("separator", "\n") + cmap = kwargs.pop("cmap", "magma") + figsize = kwargs.pop("figsize", (2.5 * ncols, 3 / sy * sx * nrows)) + fig = plt.figure(figsize=figsize) + + for index, (params, res) in enumerate(zip(params_grid, grid_search_res)): + row_index, col_index = np.unravel_index(index, (nrows, ncols)) + + ax = fig.add_subplot(spec[row_index, col_index]) + ax.imshow(res[0], cmap=cmap) + + title_substrings = [ + f"{param.name}: {val}" + for param, val in zip(self._parameter_list, params) + ] + title_substrings.append(f"error: {res[1]:.3e}") + title = separator.join(title_substrings) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(title) + spec.tight_layout(fig) + + if return_reconstructed_objects: + return grid_search_res + else: + return grid_search_res + def optimize( self, n_calls: int = 50, @@ -144,7 +290,7 @@ def optimize( error_metric = self._get_error_metric(error_metric) - self._optimization_function = self._get_optimization_function( + optimization_function = self._get_optimization_function( self._reconstruction_type, self._parameter_list, self._init_static_args, @@ -166,22 +312,37 @@ def optimize( def callback(*args, **kwargs): pbar.update(1) - self._skopt_result = gp_minimize( - self._optimization_function, - self._parameter_list, - n_calls=n_calls, - n_initial_points=n_initial_points, - x0=self._x0, - callback=callback, - **skopt_kwargs, - ) + try: + self._skopt_result = gp_minimize( + optimization_function, + self._parameter_list, + n_calls=n_calls, + n_initial_points=n_initial_points, + x0=self._x0, + callback=callback, + **skopt_kwargs, + ) - print("Optimized parameters:") - for p, x in 
zip(self._parameter_list, self._skopt_result.x): - print(f"{p.name}: {x}") + # Remove the optimization result's reference to the function, as it potentially contains a + # copy of the ptycho object + del self._skopt_result["fun"] - # Finish the tqdm progressbar so subsequent things behave nicely - pbar.close() + # If using the GPU, free some cached stuff + if self._init_args.get("device", "cpu") == "gpu": + import cupy as cp + + cp.get_default_memory_pool().free_all_blocks() + cp.get_default_pinned_memory_pool().free_all_blocks() + from cupy.fft.config import get_plan_cache + + get_plan_cache().clear() + + print("Optimized parameters:") + for p, x in zip(self._parameter_list, self._skopt_result.x): + print(f"{p.name}: {x}") + finally: + # close the pbar gracefully on interrupt + pbar.close() return self @@ -498,6 +659,7 @@ def f(**kwargs): def _set_optimizer_defaults( self, verbose=False, + clear_fft_cache=False, plot_center_of_mass=False, plot_rotation=False, plot_probe_overlaps=False, @@ -509,6 +671,7 @@ def _set_optimizer_defaults( Set all of the verbose and plotting to False, allowing for user-overwrite. """ self._init_static_args["verbose"] = verbose + self._init_static_args["clear_fft_cache"] = clear_fft_cache self._preprocess_static_args["plot_center_of_mass"] = plot_center_of_mass self._preprocess_static_args["plot_rotation"] = plot_rotation diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/phase_base_class.py similarity index 59% rename from py4DSTEM/process/phase/iterative_base_class.py rename to py4DSTEM/process/phase/phase_base_class.py index a0ed485ba..36e0c598a 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -2,17 +2,18 @@ Module for reconstructing phase objects from 4DSTEM datasets using iterative methods. 
""" +import sys import warnings import matplotlib.pyplot as plt import numpy as np -from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid -from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex -from scipy.ndimage import rotate +from py4DSTEM.visualize import show_complex +from scipy.ndimage import zoom try: import cupy as cp + from cupy.fft.config import get_plan_cache except (ModuleNotFoundError, ImportError): cp = np @@ -20,12 +21,10 @@ from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube from py4DSTEM.process.calibration import fit_origin -from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( - PtychographicConstraints, -) from py4DSTEM.process.phase.utils import ( AffineTransform, - generate_batches, + copy_to_device, + get_array_module, polar_aliases, ) from py4DSTEM.process.utils import ( @@ -34,7 +33,8 @@ get_shifted_ar, ) -warnings.simplefilter(action="always", category=UserWarning) +warnings.showwarning = lambda msg, *args, **kwargs: print(msg, file=sys.stderr) +warnings.simplefilter("always", UserWarning) class PhaseReconstruction(Custom): @@ -43,6 +43,94 @@ class PhaseReconstruction(Custom): Defines various common functions and properties for subclasses to inherit. """ + def set_device(self, device, clear_fft_cache): + """ + Sets calculation device. + + Parameters + ---------- + device: str + Calculation device will be perfomed on. 
Must be 'cpu' or 'gpu' + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if clear_fft_cache is not None: + self._clear_fft_cache = clear_fft_cache + + if device is None: + return self + + if device == "cpu": + import scipy + + self._xp = np + self._scipy = scipy + + elif device == "gpu": + from cupyx import scipy + + self._xp = cp + self._scipy = scipy + + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + + self._device = device + + return self + + def set_storage(self, storage): + """ + Sets storage device. + + Parameters + ---------- + storage: str + Device arrays will be stored on. Must be 'cpu' or 'gpu' + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if storage == "cpu": + self._xp_storage = np + + elif storage == "gpu": + if self._xp is np: + raise ValueError("storage='gpu' and device='cpu' is not supported") + self._xp_storage = cp + + else: + raise ValueError(f"storage must be either 'cpu' or 'gpu', not {storage}") + + self._asnumpy = copy_to_device + self._storage = storage + + return self + + def clear_device_mem(self, device, clear_fft_cache): + """ """ + if device == "gpu": + if clear_fft_cache: + cache = get_plan_cache() + cache.clear() + + xp = self._xp + xp._default_memory_pool.free_all_blocks() + xp._default_pinned_memory_pool.free_all_blocks() + + def copy_attributes_to_device(self, attrs, device): + """Utility function to copy a set of attrs to device""" + for attr in attrs: + array = copy_to_device(getattr(self, attr), device) + setattr(self, attr, array) + def attach_datacube(self, datacube: DataCube): """ Attaches a datacube to a class initialized without one. 
@@ -60,7 +148,13 @@ def attach_datacube(self, datacube: DataCube): self._datacube = datacube return self - def reinitialize_parameters(self, device: str = None, verbose: bool = None): + def reinitialize_parameters( + self, + device: str = None, + storage: str = None, + clear_fft_cache: bool = None, + verbose: bool = None, + ): """ Reinitializes common parameters. This is useful when loading a previously-saved reconstruction (which set device='cpu' and verbose=True for compatibility) , @@ -69,7 +163,11 @@ def reinitialize_parameters(self, device: str = None, verbose: bool = None): Parameters ---------- device: str, optional - If not None, imports and assigns appropriate device modules + If not None, assigns appropriate device modules + storage: str, optional + If not None, assigns appropriate storage modules + clear_fft_cache: bool, optional + If not None, sets the FFT caching parameter verbose: bool, optional If not None, sets the verbosity to verbose @@ -80,27 +178,10 @@ def reinitialize_parameters(self, device: str = None, verbose: bool = None): """ if device is not None: - if device == "cpu": - self._xp = np - self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter - - self._gaussian_filter = gaussian_filter - from scipy.special import erf - - self._erf = erf - elif device == "gpu": - self._xp = cp - self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter + self.set_device(device, clear_fft_cache) - self._gaussian_filter = gaussian_filter - from cupyx.scipy.special import erf - - self._erf = erf - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - self._device = device + if storage is not None: + self.set_storage(storage) if verbose is not None: self._verbose = verbose @@ -144,10 +225,11 @@ def _preprocess_datacube_and_vacuum_probe( datacube, diffraction_intensities_shape=None, reshaping_method="fourier", - probe_roi_shape=None, + padded_diffraction_intensities_shape=None, 
vacuum_probe_intensity=None, dp_mask=None, com_shifts=None, + com_measured=None, ): """ Datacube preprocessing step, to set the reciprocal- and real-space sampling. @@ -161,13 +243,10 @@ def _preprocess_datacube_and_vacuum_probe( Note this does not affect the maximum scattering wavevector (Qx*dkx,Qy*dky) = (Sx*dkx',Sy*dky'), and thus the real-space sampling stays fixed. - The real space sampling, (dx, dy), combined with the resampled diffraction_intensities_shape, - sets the real-space probe region of interest (ROI) extent (dx*Sx, dy*Sy). - Occasionally, one may also want to specify a larger probe ROI extent, e.g when the probe - does not comfortably fit without self-ovelap artifacts, or when the scan step sizes are much - smaller than the real-space sampling (dx,dy). This can be achieved by specifying a - probe_roi_shape, which is larger than diffraction_intensities_shape, which will result in - zero-padding of the diffraction intensities. + Additionally, one may wish to zero-pad the diffraction intensity data. Note this does not increase + the information or resolution, but might be beneficial in a limited number of cases, e.g. when the + scan step sizes are much smaller than the real-space sampling (dx,dy). This can be achieved by specifying + a padded_diffraction_intensities_shape which is larger than diffraction_intensities_shape. Parameters ---------- @@ -178,7 +257,7 @@ def _preprocess_datacube_and_vacuum_probe( If None, no resamping is performed reshaping method: str, optional Reshaping method to use, one of 'bin', 'bilinear' or 'fourier' (default) - probe_roi_shape, (int,int), optional + padded_diffraction_intensities_shape, (int,int), optional Padded diffraction intensities shape. 
If None, no padding is performed vacuum_probe_intensity, np.ndarray, optional @@ -193,11 +272,19 @@ def _preprocess_datacube_and_vacuum_probe( datacube: Datacube Resampled and Padded datacube """ + if com_shifts is not None: if np.isscalar(com_shifts[0]): com_shifts = ( - np.ones(self._datacube.Rshape) * com_shifts[0], - np.ones(self._datacube.Rshape) * com_shifts[1], + np.ones(datacube.Rshape) * com_shifts[0], + np.ones(datacube.Rshape) * com_shifts[1], + ) + + if com_measured is not None: + if np.isscalar(com_measured[0]): + com_measured = ( + np.ones(datacube.Rshape) * com_measured[0], + np.ones(datacube.Rshape) * com_measured[1], ) if diffraction_intensities_shape is not None: @@ -218,6 +305,12 @@ def _preprocess_datacube_and_vacuum_probe( com_shifts[1] * resampling_factor_x, ) + if com_measured is not None: + com_measured = ( + com_measured[0] * resampling_factor_x, + com_measured[1] * resampling_factor_x, + ) + if reshaping_method == "bin": bin_factor = int(1 / resampling_factor_x) if bin_factor < 1: @@ -227,32 +320,89 @@ def _preprocess_datacube_and_vacuum_probe( datacube = datacube.bin_Q(N=bin_factor) if vacuum_probe_intensity is not None: - vacuum_probe_intensity = vacuum_probe_intensity[ - ::bin_factor, ::bin_factor - ] + # crop edges if necessary + if Qx % bin_factor == 0: + vacuum_probe_intensity = vacuum_probe_intensity[ + : -(Qx % bin_factor), : + ] + if Qy % bin_factor == 0: + vacuum_probe_intensity = vacuum_probe_intensity[ + :, : -(Qy % bin_factor) + ] + + vacuum_probe_intensity = vacuum_probe_intensity.reshape( + Qx // bin_factor, bin_factor, Qy // bin_factor, bin_factor + ).sum(axis=(1, 3)) if dp_mask is not None: - dp_mask = dp_mask[::bin_factor, ::bin_factor] - else: + # crop edges if necessary + if Qx % bin_factor == 0: + dp_mask = dp_mask[: -(Qx % bin_factor), :] + if Qy % bin_factor == 0: + dp_mask = dp_mask[:, : -(Qy % bin_factor)] + + dp_mask = dp_mask.reshape( + Qx // bin_factor, bin_factor, Qy // bin_factor, bin_factor + ).sum(axis=(1, 
3)) + + elif reshaping_method == "fourier": datacube = datacube.resample_Q( - N=resampling_factor_x, method=reshaping_method + N=resampling_factor_x, + method=reshaping_method, + conserve_array_sums=True, ) if vacuum_probe_intensity is not None: vacuum_probe_intensity = fourier_resample( vacuum_probe_intensity, output_size=diffraction_intensities_shape, force_nonnegative=True, + conserve_array_sums=True, ) if dp_mask is not None: dp_mask = fourier_resample( dp_mask, output_size=diffraction_intensities_shape, force_nonnegative=True, + conserve_array_sums=False, + ) + + elif reshaping_method == "bilinear": + datacube = datacube.resample_Q( + N=resampling_factor_x, + method=reshaping_method, + conserve_array_sums=True, + ) + if vacuum_probe_intensity is not None: + vacuum_probe_intensity = zoom( + vacuum_probe_intensity, + (resampling_factor_x, resampling_factor_x), + order=1, + mode="grid-wrap", + grid_mode=True, ) + vacuum_probe_intensity = ( + vacuum_probe_intensity / resampling_factor_x**2 + ) + if dp_mask is not None: + dp_mask = zoom( + dp_mask, + (resampling_factor_x, resampling_factor_x), + order=1, + mode="grid-wrap", + grid_mode=True, + ) + + else: + raise ValueError( + ( + "reshaping_method needs to be one of 'bilinear', 'fourier', or 'bin', " + f"not {reshaping_method}." 
+ ) + ) - if probe_roi_shape is not None: + if padded_diffraction_intensities_shape is not None: Qx, Qy = datacube.shape[-2:] - Sx, Sy = probe_roi_shape - datacube = datacube.pad_Q(output_size=probe_roi_shape) + Sx, Sy = padded_diffraction_intensities_shape + datacube = datacube.pad_Q(output_size=padded_diffraction_intensities_shape) if vacuum_probe_intensity is not None or dp_mask is not None: pad_kx = Sx - Qx @@ -269,7 +419,7 @@ def _preprocess_datacube_and_vacuum_probe( if dp_mask is not None: dp_mask = np.pad(dp_mask, pad_width=(pad_kx, pad_ky), mode="constant") - return datacube, vacuum_probe_intensity, dp_mask, com_shifts + return datacube, vacuum_probe_intensity, dp_mask, com_shifts, com_measured def _extract_intensities_and_calibrations_from_datacube( self, @@ -328,10 +478,11 @@ def _extract_intensities_and_calibrations_from_datacube( If require_calibrations is False and calibrations are not set """ - # Copies intensities to device casting to float32 - xp = self._xp + # explicit read-only self attributes up-front + verbose = self._verbose + energy = self._energy - intensities = xp.asarray(datacube.data, dtype=xp.float32) + intensities = np.asarray(datacube.data, dtype=np.float32) self._grid_scan_shape = intensities.shape[:2] # Extracts calibrations @@ -342,13 +493,13 @@ def _extract_intensities_and_calibrations_from_datacube( # Real-space if force_scan_sampling is not None: self._scan_sampling = (force_scan_sampling, force_scan_sampling) - self._scan_units = "A" + self._scan_units = ("A",) * 2 else: if real_space_units == "pixels": if require_calibrations: raise ValueError("Real-space calibrations must be given in 'A'") - if self._verbose: + if verbose: warnings.warn( ( "Iterative reconstruction will not be quantitative unless you specify " @@ -373,35 +524,36 @@ def _extract_intensities_and_calibrations_from_datacube( # Reciprocal-space if force_angular_sampling is not None or force_reciprocal_sampling is not None: - # there is no xor keyword in Python! 
- angular = force_angular_sampling is not None - reciprocal = force_reciprocal_sampling is not None - assert (angular and not reciprocal) or ( - not angular and reciprocal - ), "Only one of angular or reciprocal calibration can be forced!" + if ( + force_angular_sampling is not None + and force_reciprocal_sampling is not None + ): + raise ValueError( + "Only one of angular or reciprocal calibration can be forced." + ) # angular calibration specified - if angular: + if force_angular_sampling is not None: self._angular_sampling = (force_angular_sampling,) * 2 self._angular_units = ("mrad",) * 2 - if self._energy is not None: + if energy is not None: self._reciprocal_sampling = ( force_angular_sampling - / electron_wavelength_angstrom(self._energy) + / electron_wavelength_angstrom(energy) / 1e3, ) * 2 self._reciprocal_units = ("A^-1",) * 2 # reciprocal calibration specified - if reciprocal: + if force_reciprocal_sampling is not None: self._reciprocal_sampling = (force_reciprocal_sampling,) * 2 self._reciprocal_units = ("A^-1",) * 2 - if self._energy is not None: + if energy is not None: self._angular_sampling = ( force_reciprocal_sampling - * electron_wavelength_angstrom(self._energy) + * electron_wavelength_angstrom(energy) * 1e3, ) * 2 self._angular_units = ("mrad",) * 2 @@ -413,7 +565,7 @@ def _extract_intensities_and_calibrations_from_datacube( "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'" ) - if self._verbose: + if verbose: warnings.warn( ( "Iterative reconstruction will not be quantitative unless you specify " @@ -432,11 +584,9 @@ def _extract_intensities_and_calibrations_from_datacube( self._reciprocal_sampling = (reciprocal_size,) * 2 self._reciprocal_units = ("A^-1",) * 2 - if self._energy is not None: + if energy is not None: self._angular_sampling = ( - reciprocal_size - * electron_wavelength_angstrom(self._energy) - * 1e3, + reciprocal_size * electron_wavelength_angstrom(energy) * 1e3, ) * 2 self._angular_units = ("mrad",) * 2 @@ 
-445,9 +595,9 @@ def _extract_intensities_and_calibrations_from_datacube( self._angular_sampling = (angular_size,) * 2 self._angular_units = ("mrad",) * 2 - if self._energy is not None: + if energy is not None: self._reciprocal_sampling = ( - angular_size / electron_wavelength_angstrom(self._energy) / 1e3, + angular_size / electron_wavelength_angstrom(energy) / 1e3, ) * 2 self._reciprocal_units = ("A^-1",) * 2 else: @@ -467,6 +617,7 @@ def _calculate_intensities_center_of_mass( fit_function: str = "plane", com_shifts: np.ndarray = None, com_measured: np.ndarray = None, + vectorized_calculation=True, ): """ Common preprocessing function to compute and fit diffraction intensities CoM @@ -483,6 +634,8 @@ def _calculate_intensities_center_of_mass( If not None, com_shifts are fitted on the measured CoM values. com_measured: tuple of ndarrays (CoMx measured, CoMy measured) If not None, com_measured are passed as com_measured_x, com_measured_y + vectorized_calculation: bool, optional + If True (default), the calculation is vectorized Returns ------- @@ -500,20 +653,18 @@ def _calculate_intensities_center_of_mass( Normalized vertical center of mass gradient """ + # explicit read-only self attributes up-front xp = self._xp + device = self._device asnumpy = self._asnumpy - # for ptycho + reciprocal_sampling = self._reciprocal_sampling + if com_measured: - com_measured_x, com_measured_y = com_measured + com_measured_x = xp.asarray(com_measured[0], dtype=xp.float32) + com_measured_y = xp.asarray(com_measured[1], dtype=xp.float32) else: - # Coordinates - kx = xp.arange(intensities.shape[-2], dtype=xp.float32) - ky = xp.arange(intensities.shape[-1], dtype=xp.float32) - kya, kxa = xp.meshgrid(ky, kx) - - # calculate CoM if dp_mask is not None: if dp_mask.shape != intensities.shape[-2:]: raise ValueError( @@ -522,41 +673,87 @@ def _calculate_intensities_center_of_mass( f"not {dp_mask.shape}" ) ) - intensities_mask = intensities * xp.asarray(dp_mask, dtype=xp.float32) - else: - 
intensities_mask = intensities + dp_mask = xp.asarray(dp_mask, dtype=xp.float32) - intensities_sum = xp.sum(intensities_mask, axis=(-2, -1)) - com_measured_x = ( - xp.sum(intensities_mask * kxa[None, None], axis=(-2, -1)) - / intensities_sum - ) - com_measured_y = ( - xp.sum(intensities_mask * kya[None, None], axis=(-2, -1)) - / intensities_sum - ) + # Coordinates + kx = xp.arange(intensities.shape[-2], dtype=xp.float32) + ky = xp.arange(intensities.shape[-1], dtype=xp.float32) + kya, kxa = xp.meshgrid(ky, kx) + + if vectorized_calculation: + # copy to device + intensities = copy_to_device(intensities, device) + + # calculate CoM + if dp_mask is not None: + intensities_mask = intensities * dp_mask + else: + intensities_mask = intensities + + intensities_sum = xp.sum(intensities_mask, axis=(-2, -1)) + + com_measured_x = ( + xp.sum(intensities_mask * kxa[None, None], axis=(-2, -1)) + / intensities_sum + ) + com_measured_y = ( + xp.sum(intensities_mask * kya[None, None], axis=(-2, -1)) + / intensities_sum + ) + + else: + sx, sy = intensities.shape[:2] + com_measured_x = xp.zeros((sx, sy), dtype=xp.float32) + com_measured_y = xp.zeros((sx, sy), dtype=xp.float32) + + # loop of dps + for rx, ry in tqdmnd( + sx, + sy, + desc="Calculating center of mass", + unit="probe position", + disable=not self._verbose, + ): + masked_intensity = copy_to_device(intensities[rx, ry], device) + if dp_mask is not None: + masked_intensity *= dp_mask + summed_intensity = masked_intensity.sum() + com_measured_x[rx, ry] = ( + xp.sum(masked_intensity * kxa) / summed_intensity + ) + com_measured_y[rx, ry] = ( + xp.sum(masked_intensity * kya) / summed_intensity + ) if com_shifts is None: - com_measured_x_np = asnumpy(com_measured_x) - com_measured_y_np = asnumpy(com_measured_y) - finite_mask = np.isfinite(com_measured_x_np) - - com_shifts = fit_origin( - (com_measured_x_np, com_measured_y_np), - fitfunction=fit_function, - mask=finite_mask, - ) + if fit_function is not None: + com_measured_x_np = 
asnumpy(com_measured_x) + com_measured_y_np = asnumpy(com_measured_y) + finite_mask = np.isfinite(com_measured_x_np) + + com_shifts = fit_origin( + (com_measured_x_np, com_measured_y_np), + fitfunction=fit_function, + mask=finite_mask, + ) + + com_fitted_x = xp.asarray(com_shifts[0], dtype=xp.float32) + com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32) + else: + com_fitted_x = xp.asarray(com_measured_x, dtype=xp.float32) + com_fitted_y = xp.asarray(com_measured_y, dtype=xp.float32) + else: + com_fitted_x = xp.asarray(com_shifts[0], dtype=xp.float32) + com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32) # Fit function to center of mass - com_fitted_x = xp.asarray(com_shifts[0], dtype=xp.float32) - com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32) # fix CoM units com_normalized_x = ( - xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0] + xp.nan_to_num(com_measured_x - com_fitted_x) * reciprocal_sampling[0] ) com_normalized_y = ( - xp.nan_to_num(com_measured_y - com_fitted_y) * self._reciprocal_sampling[1] + xp.nan_to_num(com_measured_y - com_fitted_y) * reciprocal_sampling[1] ) return ( @@ -574,7 +771,7 @@ def _solve_for_center_of_mass_relative_rotation( _com_measured_y: np.ndarray, _com_normalized_x: np.ndarray, _com_normalized_y: np.ndarray, - rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + rotation_angles_deg: np.ndarray = None, plot_rotation: bool = True, plot_center_of_mass: str = "default", maximize_divergence: bool = False, @@ -636,15 +833,22 @@ def _solve_for_center_of_mass_relative_rotation( Summary statistics """ + # explicit read-only self attributes up-front xp = self._xp asnumpy = self._asnumpy + verbose = self._verbose + scan_sampling = self._scan_sampling + scan_units = self._scan_units + + if rotation_angles_deg is None: + rotation_angles_deg = np.arange(-89.0, 90.0, 1.0) if force_com_rotation is not None: # Rotation known _rotation_best_rad = np.deg2rad(force_com_rotation) - if 
self._verbose: + if verbose: warnings.warn( ( "Best fit rotation forced to " @@ -658,7 +862,7 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose = force_com_transpose - if self._verbose: + if verbose: warnings.warn( f"Transpose of intensities forced to {force_com_transpose}.", UserWarning, @@ -709,11 +913,12 @@ def _solve_for_center_of_mass_relative_rotation( else: _rotation_best_transpose = rotation_curl_transpose < rotation_curl - if self._verbose: + if verbose: if _rotation_best_transpose: - print("Diffraction intensities should be transposed.") - else: - print("No need to transpose diffraction intensities.") + warnings.warn( + "Diffraction intensities should be transposed.", + UserWarning, + ) else: # Rotation unknown @@ -722,7 +927,7 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose = force_com_transpose - if self._verbose: + if verbose: warnings.warn( f"Transpose of intensities forced to {force_com_transpose}.", UserWarning, @@ -817,8 +1022,11 @@ def _solve_for_center_of_mass_relative_rotation( rotation_best_deg = rotation_angles_deg[ind_min] _rotation_best_rad = rotation_angles_rad[ind_min] - if self._verbose: - print(("Best fit rotation = " f"{rotation_best_deg:.0f} degrees.")) + if verbose: + warnings.warn( + f"Best fit rotation = {rotation_best_deg:.0f} degrees.", + UserWarning, + ) if plot_rotation: figsize = kwargs.get("figsize", (8, 2)) @@ -827,17 +1035,21 @@ def _solve_for_center_of_mass_relative_rotation( if _rotation_best_transpose: ax.plot( rotation_angles_deg, - asnumpy(rotation_div_transpose) - if maximize_divergence - else asnumpy(rotation_curl_transpose), + ( + asnumpy(rotation_div_transpose) + if maximize_divergence + else asnumpy(rotation_curl_transpose) + ), label="CoM after transpose", ) else: ax.plot( rotation_angles_deg, - asnumpy(rotation_div) - if maximize_divergence - else asnumpy(rotation_curl), + ( + asnumpy(rotation_div) + if maximize_divergence + else asnumpy(rotation_curl) + ), 
label="CoM", ) @@ -959,8 +1171,6 @@ def _solve_for_center_of_mass_relative_rotation( # Minimize Curl ind_min = xp.argmin(rotation_curl).item() ind_trans_min = xp.argmin(rotation_curl_transpose).item() - self._rotation_curl = rotation_curl - self._rotation_curl_transpose = rotation_curl_transpose if rotation_curl[ind_min] <= rotation_curl_transpose[ind_trans_min]: rotation_best_deg = rotation_angles_deg[ind_min] _rotation_best_rad = rotation_angles_rad[ind_min] @@ -971,13 +1181,18 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose = True self._rotation_angles_deg = rotation_angles_deg + # Print summary - if self._verbose: - print(("Best fit rotation = " f"{rotation_best_deg:.0f} degrees.")) + if verbose: + warnings.warn( + f"Best fit rotation = {rotation_best_deg:.0f} degrees.", + UserWarning, + ) if _rotation_best_transpose: - print("Diffraction intensities should be transposed.") - else: - print("No need to transpose diffraction intensities.") + warnings.warn( + "Diffraction intensities should be transposed.", + UserWarning, + ) # Plot Curl/Div rotation if plot_rotation: @@ -986,16 +1201,20 @@ def _solve_for_center_of_mass_relative_rotation( ax.plot( rotation_angles_deg, - asnumpy(rotation_div) - if maximize_divergence - else asnumpy(rotation_curl), + ( + asnumpy(rotation_div) + if maximize_divergence + else asnumpy(rotation_curl) + ), label="CoM", ) ax.plot( rotation_angles_deg, - asnumpy(rotation_div_transpose) - if maximize_divergence - else asnumpy(rotation_curl_transpose), + ( + asnumpy(rotation_div_transpose) + if maximize_divergence + else asnumpy(rotation_curl_transpose) + ), label="CoM after transpose", ) y_r = ax.get_ylim() @@ -1049,18 +1268,14 @@ def _solve_for_center_of_mass_relative_rotation( + xp.cos(_rotation_best_rad) * _com_normalized_y ) - # 'Public'-facing attributes as numpy arrays - com_x = asnumpy(_com_x) - com_y = asnumpy(_com_y) - # Optionally, plot CoM if plot_center_of_mass == "all": figsize = 
kwargs.pop("figsize", (8, 12)) cmap = kwargs.pop("cmap", "RdBu_r") extent = [ 0, - self._scan_sampling[1] * _com_measured_x.shape[1], - self._scan_sampling[0] * _com_measured_x.shape[0], + scan_sampling[1] * _com_measured_x.shape[1], + scan_sampling[0] * _com_measured_x.shape[0], 0, ] @@ -1074,8 +1289,8 @@ def _solve_for_center_of_mass_relative_rotation( _com_measured_y, _com_normalized_x, _com_normalized_y, - com_x, - com_y, + _com_x, + _com_y, ], [ "CoM_x", @@ -1087,18 +1302,18 @@ def _solve_for_center_of_mass_relative_rotation( ], ): ax.imshow(asnumpy(arr), extent=extent, cmap=cmap, **kwargs) - ax.set_ylabel(f"x [{self._scan_units[0]}]") - ax.set_xlabel(f"y [{self._scan_units[1]}]") + ax.set_ylabel(f"x [{scan_units[0]}]") + ax.set_xlabel(f"y [{scan_units[1]}]") ax.set_title(title) - elif plot_center_of_mass == "default": + elif plot_center_of_mass == "default" or plot_center_of_mass is True: figsize = kwargs.pop("figsize", (8, 4)) cmap = kwargs.pop("cmap", "RdBu_r") extent = [ 0, - self._scan_sampling[1] * com_x.shape[1], - self._scan_sampling[0] * com_x.shape[0], + scan_sampling[1] * _com_x.shape[1], + scan_sampling[0] * _com_x.shape[0], 0, ] @@ -1108,17 +1323,17 @@ def _solve_for_center_of_mass_relative_rotation( for ax, arr, title in zip( grid, [ - com_x, - com_y, + _com_x, + _com_y, ], [ "Corrected CoM_x", "Corrected CoM_y", ], ): - ax.imshow(arr, extent=extent, cmap=cmap, **kwargs) - ax.set_ylabel(f"x [{self._scan_units[0]}]") - ax.set_xlabel(f"y [{self._scan_units[1]}]") + ax.imshow(asnumpy(arr), extent=extent, cmap=cmap, **kwargs) + ax.set_ylabel(f"x [{scan_units[0]}]") + ax.set_xlabel(f"y [{scan_units[1]}]") ax.set_title(title) return ( @@ -1126,8 +1341,6 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_transpose, _com_x, _com_y, - com_x, - com_y, ) def _normalize_diffraction_intensities( @@ -1135,8 +1348,8 @@ def _normalize_diffraction_intensities( diffraction_intensities, com_fitted_x, com_fitted_y, - crop_patterns, positions_mask, + 
crop_patterns, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1149,11 +1362,11 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering - positions_mask: np.ndarray, optional - Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1163,15 +1376,21 @@ def _normalize_diffraction_intensities( Mean intensity value """ - xp = self._xp + # explicit read-only self attributes up-front + asnumpy = self._asnumpy + mean_intensity = 0 - diffraction_intensities = self._asnumpy(diffraction_intensities) + diffraction_intensities = asnumpy(diffraction_intensities) + com_fitted_x = asnumpy(com_fitted_x) + com_fitted_y = asnumpy(com_fitted_y) + if positions_mask is not None: number_of_patterns = np.count_nonzero(positions_mask.ravel()) else: number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + # Aggressive cropping for when off-centered high scattering angle data was recorded if crop_patterns: crop_x = int( np.minimum( @@ -1202,44 +1421,43 @@ def _normalize_diffraction_intensities( crop_mask[-crop_w:, :crop_w] = True crop_mask[:crop_w:, -crop_w:] = True crop_mask[-crop_w:, -crop_w:] = True - self._crop_mask = crop_mask else: + crop_mask = None region_of_interest_shape = diffraction_intensities.shape[-2:] amplitudes = np.zeros( (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 ) - com_fitted_x = self._asnumpy(com_fitted_x) - com_fitted_y = self._asnumpy(com_fitted_y) - counter = 0 - for rx in range(diffraction_intensities.shape[0]): - for ry in range(diffraction_intensities.shape[1]): - if positions_mask is not None: - if not positions_mask[rx, ry]: - 
continue - intensities = get_shifted_ar( - diffraction_intensities[rx, ry], - -com_fitted_x[rx, ry], - -com_fitted_y[rx, ry], - bilinear=True, - device="cpu", - ) + for rx, ry in tqdmnd( + diffraction_intensities.shape[0], + diffraction_intensities.shape[1], + desc="Normalizing amplitudes", + unit="probe position", + disable=not self._verbose, + ): + if positions_mask is not None: + if not positions_mask[rx, ry]: + continue + intensities = get_shifted_ar( + diffraction_intensities[rx, ry], + -com_fitted_x[rx, ry], + -com_fitted_y[rx, ry], + bilinear=True, + device="cpu", + ) - if crop_patterns: - intensities = intensities[crop_mask].reshape( - region_of_interest_shape - ) + if crop_patterns: + intensities = intensities[crop_mask].reshape(region_of_interest_shape) - mean_intensity += np.sum(intensities) - amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) - counter += 1 + mean_intensity += np.sum(intensities) + amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) + counter += 1 - amplitudes = xp.asarray(amplitudes) mean_intensity /= amplitudes.shape[0] - return amplitudes, mean_intensity + return amplitudes, mean_intensity, crop_mask def show_complex_CoM( self, @@ -1268,13 +1486,18 @@ def show_complex_CoM( default is scan sampling """ + # explicit read-only self attributes up-front + asnumpy = self._asnumpy + scan_sampling = self._scan_sampling + scan_units = self._scan_units + if com is None: - com = (self.com_x, self.com_y) + com = (self._com_x, self._com_y) if pixelsize is None: - pixelsize = self._scan_sampling[0] + pixelsize = scan_sampling[0] if pixelunits is None: - pixelunits = self._scan_units[0] + pixelunits = scan_units[0] figsize = kwargs.pop("figsize", (6, 6)) fig, ax = plt.subplots(figsize=figsize) @@ -1282,7 +1505,7 @@ def show_complex_CoM( complex_com = com[0] + 1j * com[1] show_complex( - complex_com, + asnumpy(complex_com), cbar=cbar, figax=(fig, ax), scalebar=scalebar, @@ -1293,10 +1516,10 @@ def show_complex_CoM( ) -class 
PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints): +class PtychographicReconstruction(PhaseReconstruction): """ Base ptychographic reconstruction class. - Inherits from PhaseReconstruction and PtychographicConstraints. + Inherits from PhaseReconstruction. Defines various common functions and properties for subclasses to inherit. """ @@ -1331,6 +1554,8 @@ def to_h5(self, group): "object_type": self._object_type, "verbose": self._verbose, "device": self._device, + "storage": self._storage, + "clear_fft_cache": self._clear_fft_cache, "name": self.name, "vacuum_probe_intensity": vacuum_probe_intensity, "positions": scan_positions, @@ -1373,15 +1598,7 @@ def to_h5(self, group): # reconstruction metadata is_stack = self._save_iterations and hasattr(self, "object_iterations") - if is_stack: - num_iterations = len(self.object_iterations) - iterations = list(range(0, num_iterations, self._save_iterations_frequency)) - if num_iterations - 1 not in iterations: - iterations.append(num_iterations - 1) - - error = [self.error_iterations[i] for i in iterations] - else: - error = getattr(self, "error", 0.0) + error = self.error_iterations self.metadata = Metadata( name="reconstruction_metadata", @@ -1406,6 +1623,8 @@ def to_h5(self, group): self._probe_emd = Array(name="reconstruction_probe", data=asnumpy(self._probe)) if is_stack: + num_iterations = len(self.object_iterations) + iterations = list(range(0, num_iterations, self._save_iterations_frequency)) iterations_labels = [f"iteration_{i:03}" for i in iterations] # object @@ -1485,6 +1704,8 @@ def _get_constructor_args(cls, group): "polar_parameters": polar_params, "verbose": True, # for compatibility "device": "cpu", # for compatibility + "storage": "cpu", # for compatibility + "clear_fft_cache": True, # for compatibility } class_specific_kwargs = {} @@ -1525,13 +1746,12 @@ def _populate_instance(self, group): self._exit_waves = None # Check if stack - if hasattr(error, "__len__"): + if 
"_object_iterations_emd" in dict_data.keys(): self.object_iterations = list(dict_data["_object_iterations_emd"].data) self.probe_iterations = list(dict_data["_probe_iterations_emd"].data) - self.error_iterations = error - self.error = error[-1] - else: - self.error = error + + self.error_iterations = error + self.error = error[-1] # Slim preprocessing to enable visualize self._positions_px_com = xp.mean(self._positions_px, axis=0) @@ -1539,6 +1759,29 @@ def _populate_instance(self, group): self.probe = self.probe_centered self._preprocessed = True + def _switch_object_type(self, object_type): + """ + Switches object type to/from "potential"/"complex" + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + xp = self._xp + + match (self._object_type, object_type): + case ("potential", "complex"): + self._object_type = "complex" + self._object = xp.exp(1j * self._object, dtype=xp.complex64) + case ("complex", "potential"): + self._object_type = "potential" + self._object = xp.angle(self._object) + case _: + self._object_type = self._object_type + + return self + def _set_polar_parameters(self, parameters: dict): """ Set the probe aberrations dictionary. @@ -1563,7 +1806,11 @@ def _set_polar_parameters(self, parameters: dict): raise ValueError("{} not a recognized parameter".format(symbol)) def _calculate_scan_positions_in_pixels( - self, positions: np.ndarray, positions_mask + self, + positions: np.ndarray, + positions_mask, + object_padding_px, + positions_offset_ang, ): """ Method to compute the initial guess of scan positions in pixels. @@ -1575,16 +1822,27 @@ def _calculate_scan_positions_in_pixels( If None, a raster scan using experimental parameters is constructed. 
positions_mask: np.ndarray, optional Boolean real space mask to select positions in datacube to skip for reconstruction + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + positions_offset_ang, np.ndarray, optional + Offset of positions in A Returns ------- positions_in_px: (J,2) np.ndarray Initial guess of scan positions in pixels + object_padding_px: Tupe[int,int] + Updated object_padding_px """ + # explicit read-only self attributes up-front grid_scan_shape = self._grid_scan_shape rotation_angle = self._rotation_best_rad + transpose = self._rotation_best_transpose step_sizes = self._scan_sampling + region_of_interest_shape = self._region_of_interest_shape + sampling = self.sampling if positions is None: if grid_scan_shape is not None: @@ -1599,47 +1857,59 @@ def _calculate_scan_positions_in_pixels( else: raise ValueError() - if self._rotation_best_transpose: - x = (x - np.ptp(x) / 2) / self.sampling[1] - y = (y - np.ptp(y) / 2) / self.sampling[0] + if transpose: + x = (x - np.ptp(x) / 2) / sampling[1] + y = (y - np.ptp(y) / 2) / sampling[0] else: - x = (x - np.ptp(x) / 2) / self.sampling[0] - y = (y - np.ptp(y) / 2) / self.sampling[1] + x = (x - np.ptp(x) / 2) / sampling[0] + y = (y - np.ptp(y) / 2) / sampling[1] x, y = np.meshgrid(x, y, indexing="ij") + + if positions_offset_ang is not None: + if transpose: + x += positions_offset_ang[0] / sampling[1] + y += positions_offset_ang[1] / sampling[0] + else: + x += positions_offset_ang[0] / sampling[0] + y += positions_offset_ang[1] / sampling[1] + if positions_mask is not None: x = x[positions_mask] y = y[positions_mask] else: positions -= np.mean(positions, axis=0) - x = positions[:, 0] / self.sampling[1] - y = positions[:, 1] / self.sampling[0] + x = positions[:, 0] / sampling[1] + y = positions[:, 1] / sampling[0] if rotation_angle is not None: x, y = x * np.cos(rotation_angle) + y * np.sin(rotation_angle), -x * 
np.sin( rotation_angle ) + y * np.cos(rotation_angle) - if self._rotation_best_transpose: + if transpose: positions = np.array([y.ravel(), x.ravel()]).T else: positions = np.array([x.ravel(), y.ravel()]).T + positions -= np.min(positions, axis=0) - if self._object_padding_px is None: - float_padding = self._region_of_interest_shape / 2 - self._object_padding_px = (float_padding, float_padding) - elif np.isscalar(self._object_padding_px[0]): - self._object_padding_px = ( - (self._object_padding_px[0],) * 2, - (self._object_padding_px[1],) * 2, + if object_padding_px is None: + float_padding = region_of_interest_shape / 2 + object_padding_px = (float_padding, float_padding) + elif np.isscalar(object_padding_px[0]): + object_padding_px = ( + (object_padding_px[0],) * 2, + (object_padding_px[1],) * 2, ) - positions[:, 0] += self._object_padding_px[0][0] - positions[:, 1] += self._object_padding_px[1][0] + positions[:, 0] += object_padding_px[0][0] + positions[:, 1] += object_padding_px[1][0] - return positions + return positions, object_padding_px - def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray): + def _sum_overlapping_patches_bincounts_base( + self, patches: np.ndarray, positions_px + ): """ Base bincouts overlapping patches sum function, operating on real-valued arrays. Note this assumes the probe is corner-centered. 
@@ -1654,28 +1924,28 @@ def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray): out_array: (Px,Py) np.ndarray Summed array """ - xp = self._xp - x0 = xp.round(self._positions_px[:, 0]).astype("int") - y0 = xp.round(self._positions_px[:, 1]).astype("int") - + # explicit read-only self attributes up-front + xp = get_array_module(patches) roi_shape = self._region_of_interest_shape + object_shape = self._object_shape + + x0 = xp.round(positions_px[:, 0]).astype("int") + y0 = xp.round(positions_px[:, 1]).astype("int") + x_ind = xp.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") y_ind = xp.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") flat_weights = patches.ravel() - indices = ( - (y0[:, None, None] + y_ind[None, None, :]) % self._object_shape[1] - ) + ( - (x0[:, None, None] + x_ind[None, :, None]) % self._object_shape[0] - ) * self._object_shape[ - 1 - ] + indices = ((y0[:, None, None] + y_ind[None, None, :]) % object_shape[1]) + ( + (x0[:, None, None] + x_ind[None, :, None]) % object_shape[0] + ) * object_shape[1] counts = xp.bincount( - indices.ravel(), weights=flat_weights, minlength=np.prod(self._object_shape) + indices.ravel(), weights=flat_weights, minlength=np.prod(object_shape) ) - return xp.reshape(counts, self._object_shape) + counts = xp.reshape(counts, object_shape).astype(xp.float32) + return counts - def _sum_overlapping_patches_bincounts(self, patches: np.ndarray): + def _sum_overlapping_patches_bincounts(self, patches: np.ndarray, positions_px): """ Sum overlapping patches defined into object shaped array using bincounts. Calls _sum_overlapping_patches_bincounts_base on Real and Imaginary parts. 
@@ -1691,15 +1961,21 @@ def _sum_overlapping_patches_bincounts(self, patches: np.ndarray): Summed array """ - xp = self._xp - if xp.iscomplexobj(patches): - real = self._sum_overlapping_patches_bincounts_base(xp.real(patches)) - imag = self._sum_overlapping_patches_bincounts_base(xp.imag(patches)) + if np.iscomplexobj(patches): + real = self._sum_overlapping_patches_bincounts_base( + patches.real, positions_px + ) + imag = self._sum_overlapping_patches_bincounts_base( + patches.imag, positions_px + ) return real + 1.0j * imag else: - return self._sum_overlapping_patches_bincounts_base(patches) + return self._sum_overlapping_patches_bincounts_base(patches, positions_px) - def _extract_vectorized_patch_indices(self): + def _extract_vectorized_patch_indices( + self, + positions_px, + ): """ Sets the vectorized row/col indices used for the overlap projection Note this assumes the probe is corner-centered. @@ -1711,15 +1987,17 @@ def _extract_vectorized_patch_indices(self): self._vectorized_patch_indices_col: np.ndarray Column indices for probe patches inside object array """ - xp = self._xp - x0 = xp.round(self._positions_px[:, 0]).astype("int") - y0 = xp.round(self._positions_px[:, 1]).astype("int") - + # explicit read-only self attributes up-front + xp_storage = self._xp_storage roi_shape = self._region_of_interest_shape - x_ind = xp.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") - y_ind = xp.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") - obj_shape = self._object_shape + + x0 = xp_storage.round(positions_px[:, 0]).astype("int") + y0 = xp_storage.round(positions_px[:, 1]).astype("int") + + x_ind = xp_storage.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") + y_ind = xp_storage.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") + vectorized_patch_indices_row = ( x0[:, None, None] + x_ind[None, :, None] ) % obj_shape[0] @@ -1729,903 +2007,197 @@ def _extract_vectorized_patch_indices(self): return 
vectorized_patch_indices_row, vectorized_patch_indices_col - def _crop_rotate_object_fov( + def _set_reconstruction_method_parameters( self, - array, - padding=0, + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, ): - """ - Crops and rotated object to FOV bounded by current pixel positions. - - Parameters - ---------- - array: np.ndarray - Object array to crop and rotate. Only operates on numpy arrays for comptatibility. - padding: int, optional - Optional padding outside pixel positions + """""" - Returns - cropped_rotated_array: np.ndarray - Cropped and rotated object array - """ + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): + raise ValueError( + ( + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." 
+ ) + ) - asnumpy = self._asnumpy - angle = ( - self._rotation_best_rad - if self._rotation_best_transpose - else -self._rotation_best_rad - ) + use_projection_scheme = True + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "DM_AP" + or reconstruction_method == "difference-map_alternating-projections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = 1 + projection_c = 1 + reconstruction_parameter + step_size = None + elif ( + reconstruction_method == "RAAR" + or reconstruction_method == "relaxed-averaged-alternating-reflections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = 1 - 2 * reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "RRR" + or reconstruction_method == "relax-reflect-reflect" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: + raise ValueError("reconstruction_parameter must be between 0-2.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "SUPERFLIP" + or reconstruction_method == "charge-flipping" + ): + use_projection_scheme = True + projection_a = 0 + projection_b = 1 + projection_c = 2 + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "GD" or reconstruction_method == "gradient-descent" + ): + use_projection_scheme = False + projection_a = None + 
projection_b = None + projection_c = None + reconstruction_parameter = None + else: + raise ValueError( + ( + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " + "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " + "'RRR' (or 'relax-reflect-reflect'), " + "'SUPERFLIP' (or 'charge-flipping'), " + f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." + ) + ) - tf = AffineTransform(angle=angle) - rotated_points = tf( - asnumpy(self._positions_px), origin=asnumpy(self._positions_px_com), xp=np + return ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, ) - min_x, min_y = np.floor(np.amin(rotated_points, axis=0) - padding).astype("int") - min_x = min_x if min_x > 0 else 0 - min_y = min_y if min_y > 0 else 0 - max_x, max_y = np.ceil(np.amax(rotated_points, axis=0) + padding).astype("int") - - rotated_array = rotate( - asnumpy(array), np.rad2deg(-angle), reshape=False, axes=(-2, -1) - )[..., min_x:max_x, min_y:max_y] - - if self._rotation_best_transpose: - rotated_array = rotated_array.swapaxes(-2, -1) - - return rotated_array - - def tune_angle_and_defocus( + def _report_reconstruction_summary( self, - angle_guess=None, - defocus_guess=None, - transpose=None, - angle_step_size=1, - defocus_step_size=20, - num_angle_values=5, - num_defocus_values=5, - max_iter=5, - plot_reconstructions=True, - plot_convergence=True, - return_values=False, - **kwargs, + max_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, ): - """ - Run reconstructions over a parameters space of angles and - defocus values. Should be run after preprocess step. 
- - Parameters - ---------- - angle_guess: float (degrees), optional - initial starting guess for rotation angle between real and reciprocal space - if None, uses current initialized values - defocus_guess: float (A), optional - initial starting guess for defocus - if None, uses current initialized values - angle_step_size: float (degrees), optional - size of change of rotation angle between real and reciprocal space for - each step in parameter space - defocus_step_size: float (A), optional - size of change of defocus for each step in parameter space - num_angle_values: int, optional - number of values of angle to test, must be >= 1. - num_defocus_values: int,optional - number of values of defocus to test, must be >= 1 - max_iter: int, optional - number of iterations to run in ptychographic reconstruction - plot_reconstructions: bool, optional - if True, plot phase of reconstructed objects - plot_convergence: bool, optional - if True, plots error for each iteration for each reconstruction. 
- return_values: bool, optional - if True, returns objects, convergence + """ """ - Returns - ------- - objects: list - reconstructed objects - convergence: np.ndarray - array of convergence values from reconstructions - """ - # calculate angles and defocus values to test - if angle_guess is None: - angle_guess = self._rotation_best_rad * 180 / np.pi - if defocus_guess is None: - defocus_guess = -self._polar_parameters["C10"] - if transpose is None: - transpose = self._rotation_best_transpose - - if num_angle_values == 1: - angle_step_size = 0 - - if num_defocus_values == 1: - defocus_step_size = 0 - - angles = np.linspace( - angle_guess - angle_step_size * (num_angle_values - 1) / 2, - angle_guess + angle_step_size * (num_angle_values - 1) / 2, - num_angle_values, - ) + # object type + first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " - defocus_values = np.linspace( - defocus_guess - defocus_step_size * (num_defocus_values - 1) / 2, - defocus_guess + defocus_step_size * (num_defocus_values - 1) / 2, - num_defocus_values, - ) - - if return_values: - convergence = [] - objects = [] - - # current initialized values - current_verbose = self._verbose - current_defocus = -self._polar_parameters["C10"] - current_rotation_deg = self._rotation_best_rad * 180 / np.pi - current_transpose = self._rotation_best_transpose - - # Gridspec to plot on - if plot_reconstructions: - if plot_convergence: - spec = GridSpec( - ncols=num_defocus_values, - nrows=num_angle_values * 2, - height_ratios=[1, 1 / 4] * num_angle_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_defocus_values, 5 * num_angle_values) + # stochastic gradient descent + if max_batch_size is not None: + if use_projection_scheme: + raise ValueError( + ( + "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " + "Use reconstruction_method='GD' or set max_batch_size=None." 
+ ) ) else: - spec = GridSpec( - ncols=num_defocus_values, - nrows=num_angle_values, - hspace=0.15, - wspace=0.35, - ) - figsize = kwargs.get( - "figsize", (4 * num_defocus_values, 4 * num_angle_values) + warnings.warn( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}, " + f"in batches of max {max_batch_size} measurements." + ), + UserWarning, ) - fig = plt.figure(figsize=figsize) + else: + # named projection set method + if reconstruction_parameter is not None: + warnings.warn( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." + ), + UserWarning, + ) - progress_bar = kwargs.pop("progress_bar", False) - # run loop and plot along the way - self._verbose = False - for flat_index, (angle, defocus) in enumerate( - tqdmnd(angles, defocus_values, desc="Tuning angle and defocus") - ): - self._polar_parameters["C10"] = -defocus - self._probe = None - self._object = None - self.preprocess( - force_com_rotation=angle, - force_com_transpose=transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - - self.reconstruct( - reset=True, - store_iterations=True, - max_iter=max_iter, - progress_bar=progress_bar, - **kwargs, - ) - - if plot_reconstructions: - row_index, col_index = np.unravel_index( - flat_index, (num_angle_values, num_defocus_values) + # generalized projections (or the even more rare charge-flipping) + elif projection_a is not None: + warnings.warn( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and (a,b,c): " + f"{projection_a, projection_b, projection_c}." 
+ ), + UserWarning, ) - if plot_convergence: - object_ax = fig.add_subplot(spec[row_index * 2, col_index]) - convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=convergence_ax, - cbar=True, - ) - convergence_ax.yaxis.tick_right() - else: - object_ax = fig.add_subplot(spec[row_index, col_index]) - self._visualize_last_iteration_figax( - fig, - object_ax=object_ax, - convergence_ax=None, - cbar=True, - ) - - object_ax.set_title( - f" angle = {angle:.1f} °, defocus = {defocus:.1f} A \n error = {self.error:.3e}" + # gradient descent + else: + warnings.warn( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}." + ), + UserWarning, ) - object_ax.set_xticks([]) - object_ax.set_yticks([]) - - if return_values: - objects.append(self.object) - convergence.append(self.error_iterations.copy()) - - # initialize back to pre-tuning values - self._polar_parameters["C10"] = -current_defocus - self._probe = None - self._object = None - self.preprocess( - force_com_rotation=current_rotation_deg, - force_com_transpose=current_transpose, - plot_center_of_mass=False, - plot_rotation=False, - plot_probe_overlaps=False, - ) - self._verbose = current_verbose - if plot_reconstructions: - spec.tight_layout(fig) - - if return_values: - return objects, convergence - - def _position_correction( + def _constraints( self, - relevant_object, - relevant_probes, - relevant_overlap, - relevant_amplitudes, + current_object, + current_probe, current_positions, - positions_step_size, - constrain_position_distance, - ): - """ - Position correction using estimated intensity gradient. 
- - Parameters - -------- - relevant_object: np.ndarray - Current object estimate - relevant_probes:np.ndarray - fractionally-shifted probes - relevant_overlap: np.ndarray - object * probe overlap - relevant_amplitudes: np.ndarray - Measured amplitudes - current_positions: np.ndarray - Current positions estimate - positions_step_size: float - Positions step size - constrain_position_distance: float - Distance to constrain position correction within original - field of view in A - - Returns - -------- - updated_positions: np.ndarray - Updated positions estimate - """ - - xp = self._xp - - if self._object_type == "potential": - complex_object = xp.exp(1j * relevant_object) - else: - complex_object = relevant_object - - obj_rolled_x_patches = complex_object[ - (self._vectorized_patch_indices_row + 1) % self._object_shape[0], - self._vectorized_patch_indices_col, - ] - obj_rolled_y_patches = complex_object[ - self._vectorized_patch_indices_row, - (self._vectorized_patch_indices_col + 1) % self._object_shape[1], - ] - - overlap_fft = xp.fft.fft2(relevant_overlap) - - exit_waves_dx_fft = overlap_fft - xp.fft.fft2( - obj_rolled_x_patches * relevant_probes - ) - exit_waves_dy_fft = overlap_fft - xp.fft.fft2( - obj_rolled_y_patches * relevant_probes - ) - - overlap_fft_conj = xp.conj(overlap_fft) - estimated_intensity = xp.abs(overlap_fft) ** 2 - measured_intensity = relevant_amplitudes**2 - - flat_shape = (relevant_overlap.shape[0], -1) - difference_intensity = (measured_intensity - estimated_intensity).reshape( - flat_shape - ) - - partial_intensity_dx = 2 * xp.real( - exit_waves_dx_fft * overlap_fft_conj - ).reshape(flat_shape) - partial_intensity_dy = 2 * xp.real( - exit_waves_dy_fft * overlap_fft_conj - ).reshape(flat_shape) - - coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) - - # positions_update = xp.einsum( - # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity - # ) - - coefficients_matrix_T = 
coefficients_matrix.conj().swapaxes(-1, -2) - positions_update = ( - xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) - @ coefficients_matrix_T - @ difference_intensity[..., None] - ) - - if constrain_position_distance is not None: - constrain_position_distance /= xp.sqrt( - self.sampling[0] ** 2 + self.sampling[1] ** 2 - ) - x1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 0 - ] - y1 = (current_positions - positions_step_size * positions_update[..., 0])[ - :, 1 - ] - x0 = self._positions_px_initial[:, 0] - y0 = self._positions_px_initial[:, 1] - if self._rotation_best_transpose: - x0, y0 = xp.array([y0, x0]) - x1, y1 = xp.array([y1, x1]) - - if self._rotation_best_rad is not None: - rotation_angle = self._rotation_best_rad - x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( - -rotation_angle - ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) - x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( - -rotation_angle - ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) - - outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( - x1 < (xp.min(x0) - constrain_position_distance) - ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( - y1 < (xp.min(y0) - constrain_position_distance) - ) > 0 - - positions_update[..., 0][outlier_ind] = 0 - - current_positions -= positions_step_size * positions_update[..., 0] - return current_positions - - def plot_position_correction( - self, - scale_arrows=1, - plot_arrow_freq=1, - verbose=True, - **kwargs, - ): - """ - Function to plot changes to probe positions during ptychography reconstruciton - - Parameters - ---------- - scale_arrows: float, optional - scaling factor to be applied on vectors prior to plt.quiver call - verbose: bool, optional - if True, prints AffineTransformation if positions have been updated - """ - if verbose: - if hasattr(self, "_tf"): - print(self._tf) - - asnumpy = self._asnumpy - - extent = [ - 0, - self.sampling[1] * 
self._object_shape[1], - self.sampling[0] * self._object_shape[0], - 0, - ] - - initial_pos = asnumpy(self._positions_initial) - pos = self.positions - - figsize = kwargs.pop("figsize", (6, 6)) - color = kwargs.pop("color", (1, 0, 0, 1)) - - fig, ax = plt.subplots(figsize=figsize) - ax.quiver( - initial_pos[::plot_arrow_freq, 1], - initial_pos[::plot_arrow_freq, 0], - (pos[::plot_arrow_freq, 1] - initial_pos[::plot_arrow_freq, 1]) - * scale_arrows, - (pos[::plot_arrow_freq, 0] - initial_pos[::plot_arrow_freq, 0]) - * scale_arrows, - scale_units="xy", - scale=1, - color=color, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_xlim((extent[0], extent[1])) - ax.set_ylim((extent[2], extent[3])) - ax.set_aspect("equal") - ax.set_title("Probe positions correction") - - def _return_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - ): - """ - Returns complex fourier probe shifted to center of array from - corner-centered complex real space probe - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses self._probe - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - - Returns - ------- - fourier_probe: np.ndarray - Fourier-transformed and center-shifted probe. 
- """ - xp = self._xp - - if probe is None: - probe = self._probe - else: - probe = xp.asarray(probe, dtype=xp.complex64) - - fourier_probe = xp.fft.fft2(probe) - - if remove_initial_probe_aberrations: - fourier_probe *= xp.conjugate(self._known_aberrations_array) - - return xp.fft.fftshift(fourier_probe, axes=(-2, -1)) - - def _return_fourier_probe_from_centered_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - ): - """ - Returns complex fourier probe shifted to center of array from - centered complex real space probe - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses self._probe - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - - Returns - ------- - fourier_probe: np.ndarray - Fourier-transformed and center-shifted probe. - """ - xp = self._xp - return self._return_fourier_probe( - xp.fft.ifftshift(probe, axes=(-2, -1)), - remove_initial_probe_aberrations=remove_initial_probe_aberrations, - ) - - def _return_centered_probe( - self, - probe=None, - ): - """ - Returns complex probe centered in middle of the array. - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses self._probe - - Returns - ------- - centered_probe: np.ndarray - Center-shifted probe. - """ - xp = self._xp - - if probe is None: - probe = self._probe - else: - probe = xp.asarray(probe, dtype=xp.complex64) - - return xp.fft.fftshift(probe, axes=(-2, -1)) - - def _return_object_fft( - self, - obj=None, - ): - """ - Returns absolute value of obj fft shifted to center of array - - Parameters - ---------- - obj: array, optional - if None is specified, uses self._object - - Returns - ------- - object_fft_amplitude: np.ndarray - Amplitude of Fourier-transformed and center-shifted obj. 
- """ - asnumpy = self._asnumpy - - if obj is None: - obj = self._object - - obj = self._crop_rotate_object_fov(asnumpy(obj)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - - def _return_self_consistency_errors( - self, - max_batch_size=None, - ): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Batch-size - if max_batch_size is None: - max_batch_size = self._num_diffraction_patterns - - # Re-initialize fractional positions and vector patches - errors = np.array([]) - positions_px = self._positions_px.copy() - - for start, end in generate_batches( - self._num_diffraction_patterns, max_batch=max_batch_size - ): - # batch indices - self._positions_px = positions_px[start:end] - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - amplitudes = self._amplitudes[start:end] - - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - - # Normalized mean-squared errors - batch_errors = xp.sum( - xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) - ) - errors = np.hstack((errors, batch_errors)) - - self._positions_px = positions_px.copy() - errors /= self._mean_diffraction_intensity - - return asnumpy(errors) - - def _return_projected_cropped_potential( - self, - ): - """Utility function to accommodate multiple classes""" - if self._object_type == "complex": - projected_cropped_potential = np.angle(self.object_cropped) - else: - projected_cropped_potential = self.object_cropped - - return projected_cropped_potential - - def show_uncertainty_visualization( - self, - errors=None, - max_batch_size=None, - projected_cropped_potential=None, - kde_sigma=None, - plot_histogram=True, - plot_contours=False, + initial_positions, **kwargs, ): - """Plot 
uncertainty visualization using self-consistency errors""" - - if errors is None: - errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) - - if projected_cropped_potential is None: - projected_cropped_potential = self._return_projected_cropped_potential() - - if kde_sigma is None: - kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] - - xp = self._xp - asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter - - ## Kernel Density Estimation - - # rotated basis - angle = ( - self._rotation_best_rad - if self._rotation_best_transpose - else -self._rotation_best_rad - ) - - tf = AffineTransform(angle=angle) - rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp) - - padding = xp.min(rotated_points, axis=0).astype("int") - - # bilinear sampling - pixel_output = np.array(projected_cropped_potential.shape) + asnumpy( - 2 * padding - ) - pixel_size = pixel_output.prod() - - xa = rotated_points[:, 0] - ya = rotated_points[:, 1] - - # bilinear sampling - xF = xp.floor(xa).astype("int") - yF = xp.floor(ya).astype("int") - dx = xa - xF - dy = ya - yF - - # resampling - inds_1D = xp.ravel_multi_index( - xp.hstack( - [ - [xF, yF], - [xF + 1, yF], - [xF, yF + 1], - [xF + 1, yF + 1], - ] - ), - pixel_output, - mode=["wrap", "wrap"], - ) - - weights = xp.hstack( - ( - (1 - dx) * (1 - dy), - (dx) * (1 - dy), - (1 - dx) * (dy), - (dx) * (dy), - ) - ) - - pix_count = xp.reshape( - xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output - ) - - pix_output = xp.reshape( - xp.bincount( - inds_1D, - weights=weights * xp.tile(xp.asarray(errors), 4), - minlength=pixel_size, - ), - pixel_output, - ) - - # kernel density estimate - pix_count = gaussian_filter(pix_count, kde_sigma, mode="wrap") - pix_count[pix_count == 0.0] = np.inf - pix_output = gaussian_filter(pix_output, kde_sigma, mode="wrap") - pix_output /= pix_count - pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] - 
pix_output, _, _ = return_scaled_histogram_ordering( - pix_output.get(), normalize=True - ) - - ## Visualization - if plot_histogram: - spec = GridSpec( - ncols=1, - nrows=2, - height_ratios=[1, 4], - hspace=0.15, - ) - auto_figsize = (4, 5.25) - else: - spec = GridSpec( - ncols=1, - nrows=1, - ) - auto_figsize = (4, 4) - - figsize = kwargs.pop("figsize", auto_figsize) - - fig = plt.figure(figsize=figsize) - - if plot_histogram: - ax_hist = fig.add_subplot(spec[0]) - - counts, bins = np.histogram(errors, bins=50) - ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) - ax_hist.set_ylabel("Counts") - ax_hist.set_xlabel("Normalized Squared Error") - - ax = fig.add_subplot(spec[-1]) + """Wrapper function for all classes to inherit""" - cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - - projected_cropped_potential, vmin, vmax = return_scaled_histogram_ordering( - projected_cropped_potential, - vmin=vmin, - vmax=vmax, + current_object = self._object_constraints(current_object, **kwargs) + current_probe = self._probe_constraints(current_probe, **kwargs) + current_positions = self._positions_constraints( + current_positions, initial_positions, **kwargs ) - extent = [ - 0, - self.sampling[1] * projected_cropped_potential.shape[1], - self.sampling[0] * projected_cropped_potential.shape[0], - 0, - ] - - ax.imshow( - projected_cropped_potential, - vmin=vmin, - vmax=vmax, - extent=extent, - alpha=1 - pix_output, - cmap=cmap, - **kwargs, - ) - - if plot_contours: - aligned_points = asnumpy(rotated_points - padding) - aligned_points[:, 0] *= self.sampling[0] - aligned_points[:, 1] *= self.sampling[1] - - ax.tricontour( - aligned_points[:, 1], - aligned_points[:, 0], - errors, - colors="grey", - levels=5, - # linestyles='dashed', - linewidths=0.5, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_xlim((extent[0], extent[1])) - ax.set_ylim((extent[2], extent[3])) - 
ax.xaxis.set_ticks_position("bottom") - - spec.tight_layout(fig) - - def show_fourier_probe( - self, - probe=None, - remove_initial_probe_aberrations=False, - cbar=True, - scalebar=True, - pixelsize=None, - pixelunits=None, - **kwargs, - ): - """ - Plot probe in fourier space - - Parameters - ---------- - probe: complex array, optional - if None is specified, uses the `probe_fourier` property - remove_initial_probe_aberrations: bool, optional - If True, removes initial probe aberrations from Fourier probe - cbar: bool, optional - if True, adds colorbar - scalebar: bool, optional - if True, adds scalebar to probe - pixelunits: str, optional - units for scalebar, default is A^-1 - pixelsize: float, optional - default is probe reciprocal sampling - """ - asnumpy = self._asnumpy - - probe = asnumpy( - self._return_fourier_probe( - probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations - ) - ) - - if pixelsize is None: - pixelsize = self._reciprocal_sampling[1] - if pixelunits is None: - pixelunits = r"$\AA^{-1}$" - - figsize = kwargs.pop("figsize", (6, 6)) - chroma_boost = kwargs.pop("chroma_boost", 1) - - fig, ax = plt.subplots(figsize=figsize) - show_complex( - probe, - cbar=cbar, - figax=(fig, ax), - scalebar=scalebar, - pixelsize=pixelsize, - pixelunits=pixelunits, - ticks=False, - chroma_boost=chroma_boost, - **kwargs, - ) - - def show_object_fft(self, obj=None, **kwargs): - """ - Plot FFT of reconstructed object - - Parameters - ---------- - obj: complex array, optional - if None is specified, uses the `object_fft` property - """ - if obj is None: - object_fft = self.object_fft - else: - object_fft = self._return_object_fft(obj) - - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) - show( - object_fft, - figsize=figsize, - cmap=cmap, - scalebar=True, - pixelsize=pixelsize, - ticks=False, - pixelunits=r"$\AA^{-1}$", - **kwargs, - ) - - @property - def 
probe_fourier(self): - """Current probe estimate in Fourier space""" - if not hasattr(self, "_probe"): - return None - - asnumpy = self._asnumpy - return asnumpy(self._return_fourier_probe(self._probe)) - - @property - def probe_fourier_residual(self): - """Current probe estimate in Fourier space""" - if not hasattr(self, "_probe"): - return None - - asnumpy = self._asnumpy - return asnumpy( - self._return_fourier_probe( - self._probe, remove_initial_probe_aberrations=True - ) - ) - - @property - def probe_centered(self): - """Current probe estimate shifted to the center""" - if not hasattr(self, "_probe"): - return None - - asnumpy = self._asnumpy - return asnumpy(self._return_centered_probe(self._probe)) - - @property - def object_fft(self): - """Fourier transform of current object estimate""" - - if not hasattr(self, "_object"): - return None - - return self._return_object_fft(self._object) + return current_object, current_probe, current_positions @property def angular_sampling(self): @@ -2641,7 +2213,7 @@ def sampling(self): return tuple( electron_wavelength_angstrom(self._energy) * 1e3 / dk / n - for dk, n in zip(self.angular_sampling, self._region_of_interest_shape) + for dk, n in zip(self.angular_sampling, self._amplitudes_shape) ) @property @@ -2658,9 +2230,3 @@ def positions(self): positions[:, 1] *= self.sampling[1] return asnumpy(positions) - - @property - def object_cropped(self): - """Cropped and rotated object""" - - return self._crop_rotate_object_fov(self._object) diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py new file mode 100644 index 000000000..6da679067 --- /dev/null +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -0,0 +1,1377 @@ +import warnings + +import numpy as np +from py4DSTEM.process.phase.utils import ( + array_slice, + estimate_global_transformation_ransac, + fft_shift, + fit_aberration_surface, + regularize_probe_amplitude, +) +from 
py4DSTEM.process.utils import get_CoM + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + import os + + # make sure pylops doesn't try to use cupy + os.environ["CUPY_PYLOPS"] = "0" +import pylops # this must follow the exception + + +class ObjectNDConstraintsMixin: + """ + Mixin class for object constraints applicable to 2D,2.5D, and 3D objects. + """ + + def _object_threshold_constraint(self, current_object, pure_phase_object): + """ + Ptychographic threshold constraint. + Used for avoiding the scaling ambiguity between probe and object. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + pure_phase_object: bool + If True, object amplitude is set to unity + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + if self._object_type == "complex": + phase = xp.angle(current_object) + + if pure_phase_object: + amplitude = 1.0 + else: + amplitude = xp.minimum(xp.abs(current_object), 1.0) + + return amplitude * xp.exp(1.0j * phase) + else: + return current_object + + def _object_shrinkage_constraint(self, current_object, shrinkage_rad, object_mask): + """ + Ptychographic shrinkage constraint. + Used to ensure electrostatic potential is positive. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + if self._object_type == "complex": + phase = xp.angle(current_object) + amp = xp.abs(current_object) + + if object_mask is not None: + shrinkage_rad += phase[..., object_mask].mean() + + phase -= shrinkage_rad + + current_object = amp * xp.exp(1.0j * phase) + else: + if object_mask is not None: + shrinkage_rad += current_object[..., object_mask].mean() + + current_object -= shrinkage_rad + + return current_object + + def _object_positivity_constraint(self, current_object): + """ + Ptychographic positivity constraint. + Used to ensure potential is positive. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + if self._object_type == "complex": + return current_object + else: + return current_object.clip(0.0) + + def _object_gaussian_constraint( + self, current_object, gaussian_filter_sigma, pure_phase_object + ): + """ + Ptychographic smoothness constraint. + Used for blurring object. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + gaussian_filter_sigma: float + Standard deviation of gaussian kernel in A + pure_phase_object: bool + If True, gaussian blur performed on phase only + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + gaussian_filter = self._scipy.ndimage.gaussian_filter + gaussian_filter_sigma /= self.sampling[0] + + if not pure_phase_object or self._object_type == "potential": + current_object = gaussian_filter(current_object, gaussian_filter_sigma) + else: + phase = xp.angle(current_object) + phase = gaussian_filter(phase, gaussian_filter_sigma) + current_object = xp.exp(1.0j * phase) + + return current_object + + def _object_butterworth_constraint( + self, + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ): + """ + Ptychographic butterworth filter. + Used for low/high-pass filtering object. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. 
Smaller gives a smoother filter + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + qx = xp.fft.fftfreq(current_object.shape[-2], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[-1], self.sampling[1]) + + qya, qxa = xp.meshgrid(qy, qx) + qra = xp.sqrt(qxa**2 + qya**2) + + env = xp.ones_like(qra) + if q_highpass: + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) + if q_lowpass: + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) + + current_object_mean = xp.mean(current_object, axis=(-2, -1), keepdims=True) + current_object -= current_object_mean + current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env) + current_object += current_object_mean + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_denoise_tv_pylops(self, current_object, weight, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + if self._object_type == "complex": + current_object_tv = current_object + warnings.warn( + ( + "TV denoising is currently only supported for object_type=='potential'." 
+ ), + UserWarning, + ) + + else: + nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny) + xy_laplacian = pylops.Laplacian( + (nx, ny), axes=(0, 1), edge=False, kind="backward" + ) + + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weight], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + return current_object_tv + + def _object_denoise_tv_chambolle( + self, + current_object, + weight, + axis, + padding, + eps=2.0e-4, + max_num_iter=200, + scaling=None, + ): + """ + Perform total-variation denoising on n-dimensional images. + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float, optional + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + axis: int or tuple + Axis for denoising, if None uses all axes + pad_object: bool + if True, pads object with zeros along axes of blurring + eps : float, optional + Relative difference of the value of the cost function that determines + the stop criterion. The algorithm stops when: + + (E_(n-1) - E_n) < eps * E_0 + + max_num_iter : int, optional + Maximal number of iterations used for the optimization. + scaling : tuple, optional + Scale weight of tv denoise on different axes + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + Notes + ----- + Rudin, Osher and Fatemi algorithm. + Adapted skimage.restoration.denoise_tv_chambolle. + """ + xp = self._xp + + if self._object_type == "complex": + updated_object = current_object + warnings.warn( + ( + "TV denoising is currently only supported for object_type=='potential'." 
+ ), + UserWarning, + ) + + else: + current_object_sum = xp.sum(current_object) + + if axis is None: + ndim = xp.arange(current_object.ndim).tolist() + elif isinstance(axis, tuple): + ndim = list(axis) + else: + ndim = [axis] + + if padding is not None: + pad_width = ((0, 0),) * current_object.ndim + pad_width = list(pad_width) + + for ax in range(len(ndim)): + pad_width[ndim[ax]] = (padding, padding) + + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + p = xp.zeros( + (current_object.ndim,) + current_object.shape, + dtype=current_object.dtype, + ) + g = xp.zeros_like(p) + d = xp.zeros_like(current_object) + + i = 0 + while i < max_num_iter: + if i > 0: + # d will be the (negative) divergence of p + d = -p.sum(0) + slices_d = [ + slice(None), + ] * current_object.ndim + slices_p = [ + slice(None), + ] * (current_object.ndim + 1) + for ax in range(len(ndim)): + slices_d[ndim[ax]] = slice(1, None) + slices_p[ndim[ax] + 1] = slice(0, -1) + slices_p[0] = ndim[ax] + d[tuple(slices_d)] += p[tuple(slices_p)] + slices_d[ndim[ax]] = slice(None) + slices_p[ndim[ax] + 1] = slice(None) + updated_object = current_object + d + else: + updated_object = current_object + E = (d**2).sum() + + # g stores the gradients of updated_object along each axis + # e.g. g[0] is the first order finite difference along axis 0 + slices_g = [ + slice(None), + ] * (current_object.ndim + 1) + for ax in range(len(ndim)): + slices_g[ndim[ax] + 1] = slice(0, -1) + slices_g[0] = ndim[ax] + g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) + slices_g[ndim[ax] + 1] = slice(None) + if scaling is not None: + scaling /= xp.max(scaling) + g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] + norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] 
+ E += weight * norm.sum() + tau = 1.0 / (2.0 * len(ndim)) + norm *= tau / weight + norm += 1.0 + p -= tau * g + p /= norm + E /= float(current_object.size) + if i == 0: + E_init = E + E_previous = E + else: + if xp.abs(E_previous - E) < eps * E_init: + break + else: + E_previous = E + i += 1 + + if padding is not None: + for ax in range(len(ndim)): + slices = array_slice( + ndim[ax], current_object.ndim, padding, -padding + ) + updated_object = updated_object[slices] + + updated_object = ( + updated_object / xp.sum(updated_object) * current_object_sum + ) + + return updated_object + + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma, + pure_phase_object, + butterworth_filter, + butterworth_order, + q_lowpass, + q_highpass, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """ObjectNDConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + # amplitude threshold (complex) or positivity (potential) + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + return current_object + + +class Object2p5DConstraintsMixin: + """ + Mixin class for object constraints unique to 
2.5D objects. + Overwrites ObjectNDConstraintsMixin. + """ + + def _object_denoise_tv_pylops(self, current_object, weights, iterations, z_padding): + """ + Performs second order TV denoising along x and y, and first order along z + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z_weight, r_weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + z_padding: int + Symmetric padding around the first axis + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if self._object_type == "complex": + current_object_tv = current_object + warnings.warn( + ( + "TV denoising is currently only supported for object_type=='potential'." + ), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((z_padding, z_padding), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + 
z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[ + z_padding:-z_padding + ] + + return current_object_tv + + def _object_kz_regularization_constraint( + self, current_object, kz_regularization_gamma, z_padding + ): + """ + Arctan regularization filter + + Parameters + -------- + current_object: np.ndarray + Current object estimate + kz_regularization_gamma: float + Slice regularization strength + z_padding: int + Symmetric padding around the first axis + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + # zero pad at top and bottom slice + pad_width = ((z_padding, z_padding), (0, 0), (0, 0)) + current_object = xp.pad(current_object, pad_width=pad_width, mode="constant") + + qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + + kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] + + qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") + qz2 = qza**2 * kz_regularization_gamma**2 + qr2 = qxa**2 + qya**2 + + w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) + + current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) + current_object = current_object[z_padding:-z_padding] + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_identical_slices_constraint(self, 
current_object): + """ + Strong regularization forcing all slices to be identical + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + object_mean = current_object.mean(0, keepdims=True) + current_object[:] = object_mean + + return current_object + + def _object_constraints( + self, + current_object, + gaussian_filter, + gaussian_filter_sigma, + pure_phase_object, + butterworth_filter, + butterworth_order, + q_lowpass, + q_highpass, + identical_slices, + kz_regularization_filter, + kz_regularization_gamma, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, + object_positivity, + shrinkage_rad, + object_mask, + **kwargs, + ): + """Object2p5DConstraints wrapper function""" + + # smoothness + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + if identical_slices: + current_object = self._object_identical_slices_constraint(current_object) + elif kz_regularization_filter: + current_object = self._object_kz_regularization_constraint( + current_object, + kz_regularization_gamma, + z_padding=1, + ) + elif tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + z_padding=1, + ) + elif tv_denoise_chambolle: + current_object = self._object_denoise_tv_chambolle( + current_object, + tv_denoise_weight_chambolle, + axis=0, + padding=tv_denoise_pad_chambolle, + ) + + # L1-norm pushing vacuum to zero + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + # 
    def _object_denoise_tv_pylops(self, current_object, weight, iterations):
        """
        Performs second order TV denoising along z, x, and y.

        Only supported for object_type == 'potential'; complex objects are
        returned unchanged with a warning.

        Parameters
        ----------
        current_object: np.ndarray
            Current object estimate
        weight : float
            Denoising weight. The greater `weight`, the more denoising (at
            the expense of fidelity to `input`).
        iterations: float
            Number of iterations to run in denoising algorithm.
            `niter_out` in pylops

        Returns
        -------
        constrained_object: np.ndarray
            Constrained object estimate

        """
        if self._object_type == "complex":
            # complex objects are passed through untouched
            current_object_tv = current_object
            warnings.warn(
                (
                    "TV denoising is currently only supported for object_type=='potential'."
                ),
                UserWarning,
            )

        else:
            nz, nx, ny = current_object.shape
            niter_out = iterations
            niter_in = 1
            # identity data-fidelity operator over the flattened volume
            Iop = pylops.Identity(nx * ny * nz)
            # second-order (Laplacian) regularizer along all three axes
            xyz_laplacian = pylops.Laplacian(
                (nz, nx, ny), axes=(0, 1, 2), edge=False, kind="backward"
            )

            l1_regs = [xyz_laplacian]

            # split-Bregman solver: L2 fidelity + L1 penalty on the Laplacian
            current_object_tv = pylops.optimization.sparsity.splitbregman(
                Op=Iop,
                y=current_object.ravel(),
                RegsL1=l1_regs,
                niter_outer=niter_out,
                niter_inner=niter_in,
                epsRL1s=[weight],
                tol=1e-4,
                tau=1.0,
                show=False,
            )[0]

            current_object_tv = current_object_tv.reshape(current_object.shape)

        return current_object_tv

    def _object_butterworth_constraint(
        self, current_object, q_lowpass, q_highpass, butterworth_order
    ):
        """
        Butterworth filter, applied as a radially-symmetric band-pass
        envelope in 3D Fourier space.

        Parameters
        --------
        current_object: np.ndarray
            Current object estimate
        q_lowpass: float
            Cut-off frequency in A^-1 for low-pass butterworth filter
        q_highpass: float
            Cut-off frequency in A^-1 for high-pass butterworth filter
        butterworth_order: float
            Butterworth filter order. Smaller gives a smoother filter

        Returns
        --------
        constrained_object: np.ndarray
            Constrained object estimate
        """
        xp = self._xp
        # NOTE(review): the z-axis frequencies reuse sampling[1] (the
        # in-plane y sampling) -- presumably the zxy voxels share the
        # in-plane sampling; confirm against the 3D object conventions.
        qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1])
        qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0])
        qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1])
        qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij")
        qra = xp.sqrt(qza**2 + qxa**2 + qya**2)

        # radially-symmetric band-pass envelope
        env = xp.ones_like(qra)
        if q_highpass:
            env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order))
        if q_lowpass:
            env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order))

        # filter fluctuations about the mean, so the DC offset survives
        # even when a high-pass is applied
        current_object_mean = xp.mean(current_object)
        current_object -= current_object_mean
        current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * env)
        current_object += current_object_mean

        if self._object_type == "potential":
            # discard the (numerically tiny) imaginary part from ifftn
            current_object = xp.real(current_object)

        return current_object
    def _probe_center_of_mass_constraint(self, current_probe):
        """
        Ptychographic center of mass constraint.
        Used for centering corner-centered probe intensity.

        Parameters
        --------
        current_probe: np.ndarray
            Current probe estimate

        Returns
        --------
        constrained_probe: np.ndarray
            Constrained probe estimate
        """
        xp = self._xp

        probe_intensity = xp.abs(current_probe) ** 2

        # CoM of the intensity, measured in the corner-centered convention
        probe_x0, probe_y0 = get_CoM(
            probe_intensity, device=self._device, corner_centered=True
        )
        # sub-pixel shift (Fourier phase ramp) back onto the corner
        shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp)

        return shifted_probe

    def _probe_amplitude_constraint(
        self, current_probe, relative_radius, relative_width
    ):
        """
        Ptychographic top-hat filtering of probe.

        Parameters
        ----------
        current_probe: np.ndarray
            Current probe estimate
        relative_radius: float
            Relative location of top-hat inflection point, between 0 and 0.5
        relative_width: float
            Relative width of top-hat sigmoid, between 0 and 0.5

        Returns
        --------
        constrained_probe: np.ndarray
            Constrained probe estimate
        """
        xp = self._xp
        erf = self._scipy.special.erf

        probe_intensity = xp.abs(current_probe) ** 2
        current_probe_sum = xp.sum(probe_intensity)

        # radial coordinate in corner-centered fractional units, measured
        # from the requested top-hat inflection radius
        X = xp.fft.fftfreq(current_probe.shape[0])[:, None]
        Y = xp.fft.fftfreq(current_probe.shape[1])[None]
        r = xp.hypot(X, Y) - relative_radius

        # smooth (erf) edge, width set by relative_width
        sigma = np.sqrt(np.pi) / relative_width
        tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2)))

        # renormalize so total probe intensity is unchanged by the mask
        updated_probe = current_probe * tophat_mask
        updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2)
        normalization = xp.sqrt(current_probe_sum / updated_probe_sum)

        return updated_probe * normalization
enforce_constant_intensity, + ): + """ + Ptychographic top-hat filtering of Fourier probe. + + Parameters + ---------- + current_probe: np.ndarray + Current positions estimate + threshold: np.ndarray + Threshold value for current probe fourier mask. Value should + be between 0 and 1, where 1 uses the maximum amplitude to threshold. + relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + asnumpy = self._asnumpy + + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_fft = xp.fft.fft2(current_probe) + + updated_probe_fft, _, _, _ = regularize_probe_amplitude( + asnumpy(current_probe_fft), + width_max_pixels=width_max_pixels, + nearest_angular_neighbor_averaging=5, + enforce_constant_intensity=enforce_constant_intensity, + corner_centered=True, + ) + + updated_probe_fft = xp.asarray(updated_probe_fft) + updated_probe = xp.fft.ifft2(updated_probe_fft) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + + return updated_probe * normalization + + def _probe_aperture_constraint( + self, + current_probe, + initial_probe_aperture, + ): + """ + Ptychographic constraint to fix Fourier amplitude to initial aperture. 
    def _probe_aberration_fitting_constraint(
        self,
        current_probe,
        max_angular_order,
        max_radial_order,
        remove_initial_probe_aberrations,
        use_scikit_image,
    ):
        """
        Ptychographic probe smoothing constraint.
        Replaces the Fourier-space phase of the probe with a low-order
        aberration-surface fit, keeping the Fourier amplitude unchanged.

        Parameters
        ----------
        current_probe: np.ndarray
            Current probe estimate
        max_angular_order: int
            Max angular order of probe aberrations basis functions
        max_radial_order: int
            Max radial order of probe aberrations basis functions
        remove_initial_probe_aberrations: bool, optional
            If true, initial probe aberrations are removed before fitting
        use_scikit_image: bool
            Passed through to fit_aberration_surface; presumably selects a
            scikit-image based fitting path -- confirm against that helper

        Returns
        --------
        constrained_probe: np.ndarray
            Constrained probe estimate
        """

        xp = self._xp

        fourier_probe = xp.fft.fft2(current_probe)
        if remove_initial_probe_aberrations:
            # divide out the known initial aberrations (unit-amplitude
            # phase array, so conjugate multiplication inverts it)
            fourier_probe *= xp.conj(self._known_aberrations_array)

        fourier_probe_abs = xp.abs(fourier_probe)
        sampling = self.sampling
        energy = self._energy

        fitted_angle, _ = fit_aberration_surface(
            fourier_probe,
            sampling,
            energy,
            max_angular_order,
            max_radial_order,
            use_scikit_image,
            xp=xp,
        )

        # keep the measured amplitude, replace phase by the smooth fit
        fourier_probe = fourier_probe_abs * xp.exp(-1.0j * fitted_angle)
        if remove_initial_probe_aberrations:
            # re-apply the initial aberrations removed above
            fourier_probe *= self._known_aberrations_array

        current_probe = xp.fft.ifft2(fourier_probe)

        return current_probe
+ fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, + fix_probe_aperture, + initial_probe_aperture, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + **kwargs, + ): + """ProbeConstraints wrapper function""" + + # CoM corner-centering + if fix_probe_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # Fourier phase (aberrations) fitting + if fit_probe_aberrations: + current_probe = self._probe_aberration_fitting_constraint( + current_probe, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, + ) + + # Fourier amplitude (aperture) constraints + if fix_probe_aperture: + current_probe = self._probe_aperture_constraint( + current_probe, + initial_probe_aperture, + ) + elif constrain_probe_fourier_amplitude: + current_probe = self._probe_fourier_amplitude_constraint( + current_probe, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + ) + + # Real-space amplitude constraint + if constrain_probe_amplitude: + current_probe = self._probe_amplitude_constraint( + current_probe, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + ) + + return current_probe + + +class ProbeMixedConstraintsMixin: + """ + Mixin class for regularizations unique to mixed probes. + Overwrites ProbeConstraintsMixin. + """ + + def _probe_center_of_mass_constraint(self, current_probe): + """ + Ptychographic center of mass constraint. + Used for centering corner-centered probe intensity. 
+ + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + probe_intensity = xp.abs(current_probe[0]) ** 2 + + probe_x0, probe_y0 = get_CoM( + probe_intensity, device=self._device, corner_centered=True + ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) + + return shifted_probe + + def _probe_orthogonalization_constraint(self, current_probe): + """ + Ptychographic probe-orthogonalization constraint. + Used to ensure mixed states are orthogonal to each other. + Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Orthogonalized probe estimate + """ + xp = self._xp + n_probes = self._num_probes + + # compute upper half of P* @ P + pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) + + for i in range(n_probes): + for j in range(i, n_probes): + pairwise_dot_product[i, j] = xp.sum( + current_probe[i].conj() * current_probe[j] + ) + + # compute eigenvectors (effectively cheaper way of computing V* from SVD) + _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") + current_probe = xp.tensordot(evecs.T, current_probe, axes=1) + + # sort by real-space intensity + intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) + intensities_order = xp.argsort(intensities, axis=None)[::-1] + return current_probe[intensities_order] + + def _probe_constraints( + self, + current_probe, + fix_probe_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, + num_probes_fit_aberrations, + fix_probe_aperture, + initial_probe_aperture, + constrain_probe_fourier_amplitude, + 
constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + orthogonalize_probe, + **kwargs, + ): + """ProbeMixedConstraints wrapper function""" + + # CoM corner-centering + if fix_probe_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # Fourier phase (aberrations) fitting + if fit_probe_aberrations: + if num_probes_fit_aberrations > self._num_probes: + num_probes_fit_aberrations = self._num_probes + for probe_idx in range(num_probes_fit_aberrations): + current_probe[probe_idx] = self._probe_aberration_fitting_constraint( + current_probe[probe_idx], + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image, + ) + + # Fourier amplitude (aperture) constraints + if fix_probe_aperture: + current_probe[0] = self._probe_aperture_constraint( + current_probe[0], + initial_probe_aperture[0], + ) + elif constrain_probe_fourier_amplitude: + current_probe[0] = self._probe_fourier_amplitude_constraint( + current_probe[0], + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + ) + + # Real-space amplitude constraint + if constrain_probe_amplitude: + for probe_idx in range(self._num_probes): + current_probe[probe_idx] = self._probe_amplitude_constraint( + current_probe[probe_idx], + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + ) + + # Probe orthogonalization + if orthogonalize_probe: + current_probe = self._probe_orthogonalization_constraint(current_probe) + + return current_probe + + +class PositionsConstraintsMixin: + """ + Mixin class for probe positions constraints. 
+ """ + + def _positions_center_of_mass_constraint( + self, current_positions, initial_positions_com + ): + """ + Ptychographic position center of mass constraint. + Additionally updates vectorized indices used in _overlap_projection. + + Parameters + ---------- + current_positions: np.ndarray + Current positions estimate + + Returns + -------- + constrained_positions: np.ndarray + CoM constrained positions estimate + """ + current_positions -= current_positions.mean(0) - initial_positions_com + + return current_positions + + def _positions_affine_transformation_constraint( + self, initial_positions, current_positions + ): + """ + Constrains the updated positions to be an affine transformation of the initial scan positions, + composing of two scale factors, a shear, and a rotation angle. + + Uses RANSAC to estimate the global transformation robustly. + Stores the AffineTransformation in self._tf. + + Parameters + ---------- + initial_positions: np.ndarray + Initial scan positions + current_positions: np.ndarray + Current positions estimate + + Returns + ------- + constrained_positions: np.ndarray + Affine-transform constrained positions estimate + """ + + xp_storage = self._xp_storage + initial_positions_com = initial_positions.mean(0) + + tf, _ = estimate_global_transformation_ransac( + positions0=initial_positions, + positions1=current_positions, + origin=initial_positions_com, + translation_allowed=True, + min_sample=initial_positions.shape[0] // 10, + xp=xp_storage, + ) + + current_positions = tf( + initial_positions, origin=initial_positions_com, xp=xp_storage + ) + self._tf = tf + + return current_positions + + def _positions_constraints( + self, + current_positions, + initial_positions, + fix_positions, + fix_positions_com, + global_affine_transformation, + **kwargs, + ): + """PositionsConstraints wrapper function""" + + if not fix_positions: + if not fix_positions_com: + current_positions = self._positions_center_of_mass_constraint( + current_positions, 
initial_positions.mean(0) + ) + + if global_affine_transformation: + current_positions = self._positions_affine_transformation_constraint( + initial_positions, current_positions + ) + + return current_positions diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py new file mode 100644 index 000000000..fa0b1db9f --- /dev/null +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -0,0 +1,3525 @@ +import warnings +from typing import Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from emdfile import tqdmnd +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.process.phase.utils import ( + AffineTransform, + ComplexProbe, + bilinear_resample, + copy_to_device, + fft_shift, + generate_batches, + partition_list, + rotate_point, + spatial_frequencies, +) +from py4DSTEM.process.utils import ( + align_and_shift_images, + electron_wavelength_angstrom, + get_CoM, + get_shifted_ar, +) +from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex +from scipy.ndimage import gaussian_filter, rotate + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + +class ObjectNDMethodsMixin: + """ + Mixin class for object methods applicable to 2D,2.5D, and 3D objects. 
+ """ + + def _initialize_object( + self, + initial_object, + positions_px, + object_type, + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + + object_padding_px = self._object_padding_px + region_of_interest_shape = self._region_of_interest_shape + + if initial_object is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + if object_type == "potential": + _object = xp.zeros((p, q), dtype=xp.float32) + elif object_type == "complex": + _object = xp.ones((p, q), dtype=xp.complex64) + else: + if object_type == "potential": + _object = xp.asarray(initial_object, dtype=xp.float32) + elif object_type == "complex": + _object = xp.asarray(initial_object, dtype=xp.complex64) + + return _object + + def _crop_rotate_object_fov( + self, + array, + positions_px=None, + padding=0, + ): + """ + Crops and rotated object to FOV bounded by current pixel positions. + + Parameters + ---------- + array: np.ndarray + Object array to crop and rotate. Only operates on numpy arrays for compatibility. 
+ padding: int, optional + Optional padding outside pixel positions + + Returns + cropped_rotated_array: np.ndarray + Cropped and rotated object array + """ + + asnumpy = self._asnumpy + + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + if positions_px is None: + positions_px = asnumpy(self._positions_px) + else: + positions_px = asnumpy(positions_px) + + tf = AffineTransform(angle=angle) + rotated_points = tf(positions_px, origin=positions_px.mean(0), xp=np) + + min_x, min_y = np.floor(np.amin(rotated_points, axis=0) - padding).astype("int") + min_x = min_x if min_x > 0 else 0 + min_y = min_y if min_y > 0 else 0 + max_x, max_y = np.ceil(np.amax(rotated_points, axis=0) + padding).astype("int") + + rotated_array = rotate( + asnumpy(array), np.rad2deg(-angle), order=1, reshape=False, axes=(-2, -1) + )[..., min_x:max_x, min_y:max_y] + + if self._rotation_best_transpose: + rotated_array = rotated_array.swapaxes(-2, -1) + + return rotated_array + + def _return_projected_cropped_potential( + self, + obj=None, + return_kwargs=False, + **kwargs, + ): + """Utility function to accommodate multiple classes""" + if obj is None: + obj = self.object_cropped + else: + obj = self._crop_rotate_object_fov(obj) + + if np.iscomplexobj(obj): + obj = np.angle(obj) + + if return_kwargs: + return obj, kwargs + else: + return obj + + def _return_object_fft( + self, + obj=None, + apply_hanning_window=False, + **kwargs, + ): + """ + Returns absolute value of obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT + + Returns + ------- + object_fft_amplitude: np.ndarray + Amplitude of Fourier-transformed and center-shifted obj. 
    def show_object_fft(
        self,
        obj=None,
        apply_hanning_window=True,
        scalebar=True,
        pixelsize=None,
        pixelunits=None,
        **kwargs,
    ):
        """
        Plot FFT of reconstructed object

        Parameters
        ----------
        obj: complex array, optional
            If None is specified, uses the `object_fft` property
        apply_hanning_window: bool, optional
            If True, a 2D Hann window is applied to the object before FFT
        scalebar: bool, optional
            if True, adds scalebar to probe
        pixelunits: str, optional
            units for scalebar, default is A^-1
        pixelsize: float, optional
            default is object FFT sampling
        """

        object_fft = self._return_object_fft(
            obj, apply_hanning_window=apply_hanning_window, **kwargs
        )

        if pixelsize is None:
            # reciprocal-space sampling of the FFT along axis 1
            pixelsize = 1 / (object_fft.shape[1] * self.sampling[1])
        if pixelunits is None:
            pixelunits = r"$\AA^{-1}$"

        # pop display options out of kwargs so only residual keys reach show()
        figsize = kwargs.pop("figsize", (4, 4))
        cmap = kwargs.pop("cmap", "magma")
        ticks = kwargs.pop("ticks", False)
        # NOTE(review): vmin/vmax defaults look like histogram fractions
        # rather than absolute values -- confirm against show()'s semantics
        vmin = kwargs.pop("vmin", 0.001)
        vmax = kwargs.pop("vmax", 0.999)

        # remove additional 3D FFT parameters before passing to show
        kwargs.pop("orientation_matrix", None)
        kwargs.pop("vertical_lims", None)
        kwargs.pop("horizontal_lims", None)

        show(
            object_fft,
            figsize=figsize,
            cmap=cmap,
            scalebar=scalebar,
            pixelsize=pixelsize,
            ticks=ticks,
            pixelunits=pixelunits,
            vmin=vmin,
            vmax=vmax,
            aspect=object_fft.shape[1] / object_fft.shape[0],
            **kwargs,
        )
    @property
    def object_fft(self):
        """Fourier transform of current object estimate, or None if no
        reconstruction has populated self._object yet."""

        if not hasattr(self, "_object"):
            return None

        return self._return_object_fft(self._object)

    @property
    def object_cropped(self):
        """Cropped and rotated object"""

        return self._crop_rotate_object_fov(self._object)
+ + Parameters + ---------- + gpts: Tuple[int,int] + Wavefunction pixel dimensions + sampling: Tuple[float,float] + Wavefunction sampling in A + energy: float + The electron energy of the wave functions in eV + slice_thicknesses: Sequence[float] + Array of slice thicknesses in A + theta_x: float, optional + x tilt of propagator in mrad + theta_y: float, optional + y tilt of propagator in mrad + + Returns + ------- + propagator_arrays: np.ndarray + (T,Sx,Sy) shape array storing propagator arrays + """ + xp = self._xp + + # Frequencies + kx, ky = spatial_frequencies(gpts, sampling) + kx = xp.asarray(kx, dtype=xp.float32) + ky = xp.asarray(ky, dtype=xp.float32) + + # Propagators + wavelength = electron_wavelength_angstrom(energy) + num_slices = slice_thicknesses.shape[0] + propagators = xp.empty( + (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 + ) + + for i, dz in enumerate(slice_thicknesses): + propagators[i] = xp.exp( + 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) + ) + propagators[i] *= xp.exp(1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)) + + if theta_x is not None: + propagators[i] *= xp.exp( + 1.0j * (-2 * kx[:, None] * np.pi * dz * np.tan(theta_x / 1e3)) + ) + + if theta_y is not None: + propagators[i] *= xp.exp( + 1.0j * (-2 * ky[None] * np.pi * dz * np.tan(theta_y / 1e3)) + ) + + return propagators + + def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): + """ + Propagates array by Fourier convolving array with propagator_array. 
    def _initialize_object(
        self,
        initial_object,
        num_slices,
        positions_px,
        object_type,
    ):
        """
        Initializes the multi-slice (2.5D) object array.

        If initial_object is None, allocates a num_slices-deep stack sized
        to cover all scan positions plus padding (and at least the probe
        region-of-interest shape): zeros for 'potential', ones for
        'complex'. Otherwise casts initial_object to the matching dtype.
        """
        # explicit read-only self attributes up-front
        xp = self._xp

        object_padding_px = self._object_padding_px
        region_of_interest_shape = self._region_of_interest_shape

        if initial_object is None:
            pad_x = object_padding_px[0][1]
            pad_y = object_padding_px[1][1]
            p, q = np.round(np.max(positions_px, axis=0))
            # in-plane extent covers positions + padding, and is at least
            # the region-of-interest (probe) shape
            p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int")
            q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int")
            if object_type == "potential":
                _object = xp.zeros((num_slices, p, q), dtype=xp.float32)
            elif object_type == "complex":
                _object = xp.ones((num_slices, p, q), dtype=xp.complex64)
        else:
            if object_type == "potential":
                _object = xp.asarray(initial_object, dtype=xp.float32)
            elif object_type == "complex":
                _object = xp.asarray(initial_object, dtype=xp.complex64)

        return _object

    def _return_projected_cropped_potential(
        self,
        obj=None,
        return_kwargs=False,
        **kwargs,
    ):
        """
        Utility function to accommodate multiple classes.
        Crops/rotates to the scan FOV, then projects the multi-slice object
        along the slice axis (taking the phase first if complex).
        """

        if obj is None:
            obj = self.object_cropped
        else:
            obj = self._crop_rotate_object_fov(obj)

        if np.iscomplexobj(obj):
            # project the phase of each slice
            obj = np.angle(obj).sum(0)
        else:
            obj = obj.sum(0)

        if return_kwargs:
            return obj, kwargs
        else:
            return obj
True, a 2D Hann window is applied to the object before FFT + + Returns + ------- + object_fft_amplitude: np.ndarray + Amplitude of Fourier-transformed and center-shifted obj. + """ + xp = self._xp + + if obj is None: + obj = self._object + + if np.iscomplexobj(obj): + obj = xp.angle(obj) + + obj = self._crop_rotate_object_fov(obj.sum(axis=0)) + + if apply_hanning_window: + sx, sy = obj.shape + wx = np.hanning(sx) + wy = np.hanning(sy) + obj *= wx[:, None] * wy[None, :] + + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) + + def show_depth_section( + self, + ptA: Tuple[float, float], + ptB: Tuple[float, float], + aspect_ratio: float = "auto", + plot_line_profile: bool = False, + ms_object=None, + specify_calibrated: bool = True, + gaussian_filter_sigma: float = None, + cbar: bool = True, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + ---------- + ptA: Tuple[float,float] + Starting point (x1,y1) for line profile depth section + If either is None, assumed to be array start. + Specified in Angstroms unless specify_calibrated is False + ptB: Tuple[float,float] + End point (x2,y2) for line profile depth section + If either is None, assumed to be array end. + Specified in Angstroms unless specify_calibrated is False + aspect_ratio: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + ms_object: np.array + Object to plot slices of. 
If None, uses current object + specify_calibrated: bool (optional) + If False, ptA and ptB points specified in pixels instead of Angstroms + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + cbar: bool, optional + If True, displays a colorbar + """ + if ms_object is None: + ms_object = self.object_cropped + + if np.iscomplexobj(ms_object): + ms_object = np.angle(ms_object) + + x1, y1 = ptA + x2, y2 = ptB + + if x1 is None: + x1 = 0 + if y1 is None: + y1 = 0 + if x2 is None: + x2 = self.sampling[0] * ms_object.shape[1] + if y2 is None: + y2 = self.sampling[1] * ms_object.shape[2] + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + x1, x2 = np.array([x1, x2]).clip(0, ms_object.shape[1]) + y1, y2 = np.array([y1, y2]).clip(0, ms_object.shape[2]) + + angle = np.arctan2(x2 - x1, y2 - y1) + + x0 = ms_object.shape[1] / 2 + y0 = ms_object.shape[2] / 2 + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_object, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + -int(x1_0), + axis=1, + ) + + if gaussian_filter_sigma is not None: + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + y1_0, y2_0 = ( + np.array([y1_0, y2_0]).astype("int").clip(0, rotated_object.shape[2]) + ) + plot_im = rotated_object[:, 0, y1_0:y2_0] + + # Plotting + if plot_line_profile: + ncols = 2 + else: + ncols = 1 + col_index = 0 + + spec = GridSpec(ncols=ncols, nrows=1, wspace=0.15) + + figsize = kwargs.pop("figsize", (4 * ncols, 4)) + fig = plt.figure(figsize=figsize) + cmap = kwargs.pop("cmap", "magma") + + # Line profile + if plot_line_profile: + ax = fig.add_subplot(spec[0, col_index]) + + extent_line = [ + 0, + self.sampling[1] * ms_object.shape[2], + self.sampling[0] * ms_object.shape[1], + 0, + ] + + 
ax.imshow(ms_object.sum(0), cmap="gray", extent=extent_line) + + ax.plot( + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], + color="red", + ) + + ax.set_xlabel("y [A]") + ax.set_ylabel("x [A]") + ax.set_title("Multislice depth profile location") + col_index += 1 + + # Main visualization + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + + ax = fig.add_subplot(spec[0, col_index]) + im = ax.imshow(plot_im, cmap=cmap, extent=extent) + + if aspect_ratio is not None: + if aspect_ratio == "auto": + aspect_ratio = extent[1] / extent[2] + if plot_line_profile: + aspect_ratio *= extent_line[2] / extent_line[1] + + ax.set_aspect(aspect_ratio) + cbar = False + + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + spec.tight_layout(fig) + + def show_slices( + self, + ms_object=None, + cbar: bool = True, + common_color_scale: bool = True, + padding: int = 0, + num_cols: int = 3, + show_fft: bool = False, + **kwargs, + ): + """ + Displays reconstructed slices of object + + Parameters + -------- + ms_object: nd.array, optional + Object to plot slices of. 
If None, uses current object + cbar: bool, optional + If True, displays a colorbar + padding: int, optional + Padding to leave uncropped + num_cols: int, optional + Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices + """ + + if ms_object is None: + ms_object = self._object + + rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) + + rotated_shape = rotated_object.shape + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + + extent = [ + 0, + self.sampling[1] * rotated_shape[2], + self.sampling[0] * rotated_shape[1], + 0, + ] + + num_rows = np.ceil(self._num_slices / num_cols).astype("int") + wspace = 0.35 if cbar else 0.15 + + axsize = kwargs.pop("axsize", (3, 3)) + cmap = kwargs.pop("cmap", "magma") + + if common_color_scale: + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + rotated_object, vmin, vmax = return_scaled_histogram_ordering( + rotated_object, vmin=vmin, vmax=vmax + ) + else: + vmin = None + vmax = None + + spec = GridSpec( + ncols=num_cols, + nrows=num_rows, + hspace=0.15, + wspace=wspace, + ) + + figsize = (axsize[0] * num_cols, axsize[1] * num_rows) + fig = plt.figure(figsize=figsize) + + for flat_index, obj_slice in enumerate(rotated_object): + row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) + ax = fig.add_subplot(spec[row_index, col_index]) + im = ax.imshow( + obj_slice, + cmap=cmap, + vmin=vmin, + vmax=vmax, + extent=extent, + **kwargs, + ) + + ax.set_title(f"Slice index: {flat_index}") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if row_index < num_rows - 1: + ax.set_xticks([]) + else: + ax.set_xlabel("y [A]") + + if col_index > 0: + ax.set_yticks([]) + else: 
+ ax.set_ylabel("x [A]") + + spec.tight_layout(fig) + + +class Object3DMethodsMixin: + """ + Mixin class for object methods unique to 3D objects. + Overwrites ObjectNDMethodsMixin and Object2p5DMethodsMixin. + """ + + _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) + + def _project_sliced_object(self, array: np.ndarray, output_z): + """ + Projects voxel-sliced object. + + Parameters + ---------- + array: np.ndarray + 3D array to project + output_z: int + Output_dimension to project array to. + + Returns + ------- + projected_array: np.ndarray + projected array + """ + xp = self._xp + input_z = array.shape[0] + + voxels_per_slice = np.ceil(input_z / output_z).astype("int") + pad_size = voxels_per_slice * output_z - input_z + + padded_array = xp.pad(array, ((0, pad_size), (0, 0), (0, 0))) + + return xp.sum( + padded_array.reshape( + ( + -1, + voxels_per_slice, + ) + + array.shape[1:] + ), + axis=1, + ) + + def _expand_sliced_object(self, array: np.ndarray, output_z): + """ + Expands supersliced object. + + Parameters + ---------- + array: np.ndarray + 3D array to expand + output_z: int + Output_dimension to expand array to. 
+ + Returns + ------- + expanded_array: np.ndarray + expanded array + """ + xp = self._xp + input_z = array.shape[0] + + voxels_per_slice = np.ceil(output_z / input_z).astype("int") + remainder_size = voxels_per_slice - (voxels_per_slice * input_z - output_z) + + voxels_in_slice = xp.repeat(voxels_per_slice, input_z) + voxels_in_slice[-1] = remainder_size if remainder_size > 0 else voxels_per_slice + + normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] + return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] + + def _rotate_zxy_volume( + self, + volume_array, + rot_matrix, + order=3, + ): + """ """ + + xp = self._xp + affine_transform = self._scipy.ndimage.affine_transform + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + volume = volume_array.copy() + volume_shape = xp.asarray(volume.shape) + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) + + in_center = (volume_shape - 1) / 2 + out_center = tf @ in_center + offset = in_center - out_center + + volume = affine_transform(volume, tf, offset=offset, order=order) + + return volume + + def _initialize_object( + self, + initial_object, + positions_px, + object_type, + main_tilt_axis="vertical", + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + object_padding_px = self._object_padding_px + region_of_interest_shape = self._region_of_interest_shape + + if initial_object is None: + pad_x = object_padding_px[0][1] + pad_y = object_padding_px[1][1] + p, q = np.round(np.max(positions_px, axis=0)) + p = np.max([np.round(p + pad_x), region_of_interest_shape[0]]).astype("int") + q = np.max([np.round(q + pad_y), region_of_interest_shape[1]]).astype("int") + + if main_tilt_axis == "vertical": + _object = xp.zeros((q, p, q), dtype=xp.float32) + elif main_tilt_axis == "horizontal": + _object = xp.zeros((p, p, q), dtype=xp.float32) + else: + _object = xp.zeros((max(p, q), p, q), dtype=xp.float32) + else: + _object = xp.asarray(initial_object, 
dtype=xp.float32) + + return _object + + def _return_projected_cropped_potential( + self, + obj=None, + return_kwargs=False, + **kwargs, + ): + """Utility function to accommodate multiple classes""" + + asnumpy = self._asnumpy + + rot_matrix = kwargs.pop("orientation_matrix", None) + v_lims = kwargs.pop("vertical_lims", (None, None)) + h_lims = kwargs.pop("horizontal_lims", (None, None)) + + if obj is None: + obj = self._object + + if rot_matrix is not None: + obj = self._rotate_zxy_volume( + obj, + rot_matrix=rot_matrix, + ) + + start_v, end_v = v_lims + start_h, end_h = h_lims + obj = asnumpy(obj.sum(0)[start_v:end_v, start_h:end_h]) + + if return_kwargs: + return obj, kwargs + else: + return obj + + def _return_object_fft( + self, + obj=None, + apply_hanning_window=False, + orientation_matrix=None, + vertical_lims: Tuple[int, int] = (None, None), + horizontal_lims: Tuple[int, int] = (None, None), + **kwargs, + ): + """ + Returns obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + apply_hanning_window: bool, optional + If True, a 2D Hann window is applied to the object before FFT + orientation_matrix: np.ndarray, optional + orientation matrix to rotate zone-axis + vertical_lims: tuple(int,int), optional + min/max vertical indices + horizontal_lims: tuple(int,int), optional + min/max horizontal indices + + Returns + ------- + object_fft_amplitude: np.ndarray + Amplitude of Fourier-transformed and center-shifted obj. 
+ """ + + xp = self._xp + asnumpy = self._asnumpy + + if obj is None: + obj = self._object + else: + obj = xp.asarray(obj, dtype=xp.float32) + + if orientation_matrix is not None: + obj = self._rotate_zxy_volume( + obj, + rot_matrix=orientation_matrix, + ) + + start_v, end_v = vertical_lims + start_h, end_h = horizontal_lims + obj = asnumpy(obj.sum(0)[start_v:end_v, start_h:end_h]) + + if apply_hanning_window: + sx, sy = obj.shape + wx = np.hanning(sx) + wy = np.hanning(sy) + obj *= wx[:, None] * wy[None, :] + + return np.abs(np.fft.fftshift(np.fft.fft2(obj))) + + @property + def object_supersliced(self): + """Returns super-sliced object""" + return self._project_sliced_object(self._object, self._num_slices) + + +class ProbeMethodsMixin: + """ + Mixin class for probe methods applicable to a single probe. + """ + + def _initialize_probe( + self, + initial_probe, + vacuum_probe_intensity, + mean_diffraction_intensity, + semiangle_cutoff, + crop_patterns, + ): + """ """ + # explicit read-only self attributes up-front + xp = self._xp + device = self._device + + crop_mask = self._crop_mask + region_of_interest_shape = self._region_of_interest_shape + sampling = self.sampling + energy = self._energy + rolloff = self._rolloff + polar_parameters = self._polar_parameters + + if initial_probe is None: + if vacuum_probe_intensity is not None: + semiangle_cutoff = np.inf + vacuum_probe_intensity = xp.asarray( + vacuum_probe_intensity, dtype=xp.float32 + ) + + sx, sy = vacuum_probe_intensity.shape + tx, ty = region_of_interest_shape + if sx != tx or sy != ty: + vacuum_probe_intensity = bilinear_resample( + vacuum_probe_intensity, + output_size=(tx, ty), + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + probe_x0, probe_y0 = get_CoM( + vacuum_probe_intensity, + device=device, + ) + vacuum_probe_intensity = get_shifted_ar( + vacuum_probe_intensity, + -probe_x0, + -probe_y0, + bilinear=True, + device=device, + ) + + if crop_patterns: + vacuum_probe_intensity = 
vacuum_probe_intensity[crop_mask].reshape( + region_of_interest_shape + ) + + _probe = ( + ComplexProbe( + gpts=region_of_interest_shape, + sampling=sampling, + energy=energy, + semiangle_cutoff=semiangle_cutoff, + rolloff=rolloff, + vacuum_probe_intensity=vacuum_probe_intensity, + parameters=polar_parameters, + device=device, + ) + .build() + ._array + ) + + # Normalize probe to match mean diffraction intensity + probe_intensity = xp.sum(xp.abs(xp.fft.fft2(_probe)) ** 2) + _probe *= xp.sqrt(mean_diffraction_intensity / probe_intensity) + + else: + if isinstance(initial_probe, ComplexProbe): + if initial_probe._gpts != region_of_interest_shape: + raise ValueError() + if hasattr(initial_probe, "_array"): + _probe = initial_probe._array + else: + initial_probe._xp = xp + _probe = initial_probe.build()._array + + # Normalize probe to match mean diffraction intensity + probe_intensity = xp.sum(xp.abs(xp.fft.fft2(_probe)) ** 2) + _probe *= xp.sqrt(mean_diffraction_intensity / probe_intensity) + else: + _probe = xp.asarray(initial_probe, dtype=xp.complex64) + + return _probe, semiangle_cutoff + + def _return_fourier_probe( + self, + probe=None, + remove_initial_probe_aberrations=False, + ): + """ + Returns complex fourier probe shifted to center of array from + corner-centered complex real space probe + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses self._probe + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + + Returns + ------- + fourier_probe: np.ndarray + Fourier-transformed and center-shifted probe. 
+ """ + xp = self._xp + + if probe is None: + probe = self._probe + else: + probe = xp.asarray(probe, dtype=xp.complex64) + + fourier_probe = xp.fft.fft2(probe) + + if remove_initial_probe_aberrations: + fourier_probe *= xp.conjugate(self._known_aberrations_array) + + return xp.fft.fftshift(fourier_probe, axes=(-2, -1)) + + def _return_fourier_probe_from_centered_probe( + self, + probe=None, + remove_initial_probe_aberrations=False, + ): + """ + Returns complex fourier probe shifted to center of array from + centered complex real space probe + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses self._probe + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + + Returns + ------- + fourier_probe: np.ndarray + Fourier-transformed and center-shifted probe. + """ + xp = self._xp + return self._return_fourier_probe( + xp.fft.ifftshift(probe, axes=(-2, -1)), + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + + def _return_centered_probe( + self, + probe=None, + ): + """ + Returns complex probe centered in middle of the array. + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses self._probe + + Returns + ------- + centered_probe: np.ndarray + Center-shifted probe. + """ + xp = self._xp + + if probe is None: + probe = self._probe + else: + probe = xp.asarray(probe, dtype=xp.complex64) + + return xp.fft.fftshift(probe, axes=(-2, -1)) + + def _return_probe_intensities(self, probe): + """ + Returns probe intensities summing up to 1. 
+ """ + if probe is None: + probe = self.probe_centered + + intensity_arrays = np.abs(np.array(probe, ndmin=3)) ** 2 + probe_ratio = list(intensity_arrays.sum((-2, -1)) / intensity_arrays.sum()) + + return probe_ratio + + def show_probe( + self, + probe=None, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + W=6, + **kwargs, + ): + """ + Plot probe in real space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds scalebar to probe + pixelsize: float, optional + default is probe reciprocal sampling + pixelunits: str, optional + units for scalebar, default is A^-1 + W: int, optional + if not None, sets the width of the image grid + """ + asnumpy = self._asnumpy + + if pixelsize is None: + pixelsize = self.sampling[1] + if pixelunits is None: + pixelunits = r"$\AA$" + + intensities = self._return_probe_intensities(probe) + title = [ + f"Probe {iter} intensity: {ratio*100:.1f}%" + for iter, ratio in enumerate(intensities) + ] + + axsize = kwargs.pop("axsize", (4, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + ticks = kwargs.pop("ticks", False) + title = kwargs.pop("title", title if len(title) > 1 else title[0]) + + if probe is None: + probe = list(np.array(self.probe_centered, ndmin=3)) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [ + asnumpy( + self._return_centered_probe( + pr, + ) + ) + for pr in probe + ] + + probe = list(partition_list(probe, W)) + probe = probe if len(probe) > 1 else probe[0] + + show_complex( + probe, + cbar=cbar, + axsize=axsize, + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=ticks, + chroma_boost=chroma_boost, + title=title, + **kwargs, + ) + + def show_fourier_probe( + self, 
+ probe=None, + remove_initial_probe_aberrations=False, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + W=6, + **kwargs, + ): + """ + Plot probe in fourier space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds scalebar to probe + pixelsize: float, optional + default is probe reciprocal sampling + pixelunits: str, optional + units for scalebar, default is A^-1 + W: int, optional + if not None, sets the width of the image grid + """ + asnumpy = self._asnumpy + + if pixelsize is None: + pixelsize = self._reciprocal_sampling[1] + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + intensities = self._return_probe_intensities(probe) + title = [ + f"Probe {iter} intensity: {ratio*100:.1f}%" + for iter, ratio in enumerate(intensities) + ] + + axsize = kwargs.pop("axsize", (4, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + ticks = kwargs.pop("ticks", False) + title = kwargs.pop("title", title if len(title) > 1 else title[0]) + + if probe is None: + if remove_initial_probe_aberrations: + probe = self.probe_fourier_residual + else: + probe = self.probe_fourier + probe = list(np.array(probe, ndmin=3)) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [ + asnumpy( + self._return_fourier_probe( + pr, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for pr in probe + ] + + probe = list(partition_list(probe, W)) + probe = probe if len(probe) > 1 else probe[0] + + show_complex( + probe, + cbar=cbar, + axsize=axsize, + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=ticks, + chroma_boost=chroma_boost, + title=title, + **kwargs, + ) + + def _return_single_probe(self, probe=None): + 
"""Current probe estimate""" + xp = self._xp + + if probe is not None: + return xp.asarray(probe) + else: + if not hasattr(self, "_probe"): + return None + + return self._probe + + @property + def probe_fourier(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy(self._return_fourier_probe(self._probe)) + + @property + def probe_fourier_residual(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy( + self._return_fourier_probe( + self._probe, remove_initial_probe_aberrations=True + ) + ) + + @property + def probe_centered(self): + """Current probe estimate shifted to the center""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy(self._return_centered_probe(self._probe)) + + +class ProbeMixedMethodsMixin: + """ + Mixin class for probe methods unique to mixed probes. + Overwrites ProbeMethodsMixin. + """ + + def _initialize_probe( + self, + initial_probe, + vacuum_probe_intensity, + mean_diffraction_intensity, + semiangle_cutoff, + crop_patterns, + ): + """ """ + + # explicit read-only self attributes up-front + xp = self._xp + num_probes = self._num_probes + region_of_interest_shape = self._region_of_interest_shape + + if initial_probe is None or isinstance(initial_probe, ComplexProbe): + # calls ProbeMethodsMixin for first probe + # annoyingly can't use super() as Mixins are defined right->left + # but MRO is defined left->right.. 
+ _probe, semiangle_cutoff = ProbeMethodsMixin._initialize_probe( + self, + initial_probe, + vacuum_probe_intensity, + mean_diffraction_intensity, + semiangle_cutoff, + crop_patterns, + ) + + sx, sy = region_of_interest_shape + _probes = xp.zeros((num_probes, sx, sy), dtype=xp.complex64) + _probes[0] = _probe + + # Randomly shift phase of other probes + for i_probe in range(1, num_probes): + shift_x = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) + ) + shift_y = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) + ) + _probes[i_probe] = ( + _probes[i_probe - 1] * shift_x[:, None] * shift_y[None] + ) + else: + _probes = xp.asarray(initial_probe, dtype=xp.complex64) + + return _probes, semiangle_cutoff + + def _return_single_probe(self, probe=None): + """Current probe estimate""" + xp = self._xp + + if probe is not None: + return xp.asarray(probe[0]) + else: + if not hasattr(self, "_probe"): + return None + + return self._probe[0] + + +class ObjectNDProbeMethodsMixin: + """ + Mixin class for methods applicable to 2D, 2.5D, and 3D objects using a single probe. + """ + + def _return_shifted_probes(self, current_probe, positions_px_fractional): + """Simple utility to de-duplicate _overlap_projection""" + + xp = self._xp + shifted_probes = fft_shift(current_probe, positions_px_fractional, xp) + return shifted_probes + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ): + """ + Ptychographic overlap projection method. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + object_patches = current_object[ + vectorized_patch_indices_row, vectorized_patch_indices_col + ] + + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + + overlap = shifted_probes * object_patches + + return shifted_probes, object_patches, overlap + + def _return_farfield_amplitudes(self, fourier_overlap): + """Small utility to de-duplicate mixed-state Fourier projection.""" + + xp = self._xp + return xp.abs(fourier_overlap) + + def cross_correlate_amplitudes_to_probe_aperture( + self, upsample_factor=4, progress_bar=True, probe=None + ): + """ + Cross-correlates the measured amplitudes with the current probe aperture. + Modifies self._amplitudes in-place. + + Parameters + ---------- + upsample_factor: float + Upsampling factor used in cross-correlation. Must be larger than 2 + probe: np.ndarray, optional + Probe to use for centering. Passed to _return_single_probe(probe) + + Returns + ------- + self to accommodate chaining + """ + xp = self._xp + storage = self._storage + + num_dps = self._num_diffraction_patterns + + single_probe = self._return_single_probe(probe) + probe_aperture = copy_to_device(xp.abs(xp.fft.fft2(single_probe)), storage) + + for idx in tqdmnd( + num_dps, + desc="Cross-correlating amplitudes", + unit="DP", + disable=not progress_bar, + ): + self._amplitudes[idx] = align_and_shift_images( + probe_aperture, + self._amplitudes[idx], + upsample_factor=upsample_factor, + device=storage, + ) + + return self + + def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask): + """ + Ptychographic fourier projection method for GD method. 
+ + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + fourier_mask: np.ndarray + Mask to apply at the detector-plane for zeroing-out unreliable gradients + Useful when detector has artifacts such as dead-pixels + + Returns + -------- + exit_waves:np.ndarray + Difference between modified and estimated exit waves + error: float + Reconstruction error + """ + + xp = self._xp + + fourier_overlap = xp.fft.fft2(overlap) + + # resample to match data, note: this needs to happen in reciprocal-space + if self._resample_exit_waves: + fourier_overlap = bilinear_resample( + fourier_overlap, + output_size=self._amplitudes_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + fourier_overlap *= fourier_mask + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) + fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) + + fourier_modified_overlap = ( + fourier_modified_overlap - fourier_overlap + ) * fourier_mask + + # resample back to region_of_interest_shape, note: this needs to happen in reciprocal-space + if self._resample_exit_waves: + fourier_modified_overlap = bilinear_resample( + fourier_modified_overlap, + output_size=self._region_of_interest_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + exit_waves = xp.fft.ifft2(fourier_modified_overlap) + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, + amplitudes, + overlap, + exit_waves, + fourier_mask, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. 
+ Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + exit_waves: np.ndarray + previously estimated exit waves + fourier_mask: np.ndarray + Mask to apply at the detector-plane for zeroing-out unreliable gradients + Useful when detector has artifacts such as dead-pixels + Currently not implemented for projection-sets + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + if fourier_mask is not None: + raise NotImplementedError() + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = overlap.copy() + + factor_to_be_projected = projection_c * overlap + projection_y * exit_waves + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + + # resample to match data, note: this needs to happen in reciprocal-space + if self._resample_exit_waves: + fourier_projected_factor = bilinear_resample( + fourier_projected_factor, + output_size=self._amplitudes_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + farfield_amplitudes = self._return_farfield_amplitudes(fourier_projected_factor) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) + + fourier_projected_factor = amplitudes * xp.exp( + 1j * xp.angle(fourier_projected_factor) + ) + + # resample back to region_of_interest_shape, note: this needs to happen in reciprocal-space + if self._resample_exit_waves: + fourier_projected_factor = bilinear_resample( + fourier_projected_factor, + 
output_size=self._region_of_interest_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + exit_waves = ( + projection_x * exit_waves + + projection_a * overlap + + projection_b * projected_factor + ) + + return exit_waves, error + + def _forward( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + current_probe, + positions_px_fractional, + amplitudes, + exit_waves, + fourier_mask, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic forward operator. + Calls _overlap_projection() and the appropriate _fourier_projection(). + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + amplitudes: np.ndarray + Normalized measured amplitudes + exit_waves: np.ndarray + previously estimated exit waves + fourier_mask: np.ndarray + Mask to apply at the detector-plane for zeroing-out unreliable gradients + Useful when detector has artifacts such as dead-pixels + use_projection_scheme: bool, + If True, use generalized projection update + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + object * probe overlap + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + shifted_probes = self._return_shifted_probes( + current_probe, positions_px_fractional + ) + + shifted_probes, object_patches, overlap = self._overlap_projection( + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ) + + if use_projection_scheme: + exit_waves, error = self._projection_sets_fourier_projection( + amplitudes, + overlap, + exit_waves, + fourier_mask, + projection_a, + projection_b, + projection_c, + ) + + else: + 
exit_waves, error = self._gradient_descent_fourier_projection( + amplitudes, + overlap, + fourier_mask, + ) + + return shifted_probes, object_patches, overlap, exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + object_update: np.ndarray + Updated object estimate + probe_update: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes) + * exit_waves + ), + positions_px, + ) + * probe_normalization + ) + else: + current_object += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes) * exit_waves, positions_px + ) + * probe_normalization + ) + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * 
xp.max(object_normalization)) ** 2 + ) + + current_probe += step_size * ( + xp.sum( + xp.conj(object_patches) * exit_waves, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. + + Parameters + -------- + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object = ( + self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes) + * exit_waves + ), + positions_px, + ) + * probe_normalization + ) + else: + current_object = ( + self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes) * exit_waves, + positions_px, + ) + * probe_normalization + ) + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + 
xp.sum( + xp.conj(object_patches) * exit_waves, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def _adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + use_projection_scheme: bool, + step_size: float, + normalization_min: float, + fix_probe: bool, + ): + """ + Ptychographic adjoint operator. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + use_projection_scheme: bool, + If True, use generalized projection update + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + if use_projection_scheme: + current_object, current_probe = self._projection_sets_adjoint( + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + normalization_min, + fix_probe, + ) + else: + current_object, current_probe = self._gradient_descent_adjoint( + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ) + + return current_object, current_probe + + def _position_correction( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes, + current_positions, + current_positions_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ): + """ + Position 
correction using estimated intensity gradient. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + shifted_probes:np.ndarray + fractionally-shifted probes + overlap: np.ndarray + object * probe overlap + amplitudes: np.ndarray + Measured amplitudes + current_positions: np.ndarray + Current positions estimate + positions_step_size: float + Positions step size + max_position_update_distance: float + Maximum allowed distance for update in A + max_position_total_distance: float + Maximum allowed distance from initial probe positions + + Returns + -------- + updated_positions: np.ndarray + Updated positions estimate + """ + + xp = self._xp + storage = self._storage + + overlap_fft = xp.fft.fft2(overlap) + + # resample to match data, note: this needs to happen in reciprocal-space + if self._resample_exit_waves: + overlap_fft = bilinear_resample( + overlap_fft, + output_size=self._amplitudes_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + # unperturbed + overlap_fft_conj = xp.conj(overlap_fft) + + estimated_intensity = self._return_farfield_amplitudes(overlap_fft) ** 2 + measured_intensity = amplitudes**2 + + # book-keeping + flat_shape = (measured_intensity.shape[0], -1) + difference_intensity = (measured_intensity - estimated_intensity).reshape( + flat_shape + ) + + # dx overlap projection perturbation + _, _, overlap_dx = self._overlap_projection( + current_object, + (vectorized_patch_indices_row + 1) % self._object_shape[0], + vectorized_patch_indices_col, + shifted_probes, + ) + + # dy overlap projection perturbation + _, _, overlap_dy = self._overlap_projection( + current_object, + vectorized_patch_indices_row, + (vectorized_patch_indices_col + 1) % self._object_shape[1], + shifted_probes, + ) + + overlap_dx_fft = xp.fft.fft2(overlap_dx) + overlap_dy_fft = xp.fft.fft2(overlap_dy) + + # resample to match data, note: this needs to happen in reciprocal-space + if self._resample_exit_waves: + overlap_dx_fft = 
bilinear_resample( + overlap_dx_fft, + output_size=self._amplitudes_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + overlap_dy_fft = bilinear_resample( + overlap_dy_fft, + output_size=self._amplitudes_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + # partial intensities + overlap_dx_fft = overlap_fft - overlap_dx_fft + overlap_dy_fft = overlap_fft - overlap_dy_fft + partial_intensity_dx = 2 * xp.real(overlap_dx_fft * overlap_fft_conj) + partial_intensity_dy = 2 * xp.real(overlap_dy_fft * overlap_fft_conj) + + # handle mixed-state, is this correct? + if partial_intensity_dx.ndim == 4: + partial_intensity_dx = partial_intensity_dx.sum(1) + partial_intensity_dy = partial_intensity_dy.sum(1) + + partial_intensity_dx = partial_intensity_dx.reshape(flat_shape) + partial_intensity_dy = partial_intensity_dy.reshape(flat_shape) + + # least-squares fit + coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) + coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) + positions_update = ( + xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) + @ coefficients_matrix_T + @ difference_intensity[..., None] + ) + + positions_update = positions_update[..., 0] * positions_step_size + + if max_position_update_distance is not None: + max_position_update_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + update_norms = xp.linalg.norm(positions_update, axis=1) + outlier_ind = update_norms > max_position_update_distance + positions_update[outlier_ind] /= ( + update_norms[outlier_ind, None] / max_position_update_distance + ) + + if max_position_total_distance is not None: + max_position_total_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + deltas = ( + xp.asarray(current_positions - current_positions_initial) + - positions_update + ) + dsts = xp.linalg.norm(deltas, axis=1) + outlier_ind = dsts > max_position_total_distance + positions_update[outlier_ind] 
= 0 + + current_positions -= copy_to_device(positions_update, storage) + + return current_positions + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + errors = np.array([]) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device(self._amplitudes[start:end], device) + + # Overlaps + shifted_probes = self._return_shifted_probes( + self._probe, positions_px_fractional + ) + _, _, overlap = self._overlap_projection( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ) + + fourier_overlap = xp.fft.fft2(overlap) + + # resample to match data, note: this needs to happen in reciprocal-space + if self._resample_exit_waves: + fourier_overlap = bilinear_resample( + fourier_overlap, + output_size=self._amplitudes_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes_device - farfield_amplitudes) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + +class Object2p5DProbeMethodsMixin: + """ + Mixin class for methods unique to 2.5D objects using a single probe. + Overwrites ObjectNDProbeMethodsMixin. 
+ """ + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes_in, + ): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + object_patches = current_object[ + :, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ] + + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + + shifted_probes = xp.empty_like(object_patches) + shifted_probes[0] = shifted_probes_in + + for s in range(self._num_slices): + # transmit + overlap = object_patches[s] * shifted_probes[s] + + # propagate + if s + 1 < self._num_slices: + shifted_probes[s + 1] = self._propagate_array( + overlap, self._propagator_arrays[s] + ) + + return shifted_probes, object_patches, overlap + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(probe) ** 2, + positions_px, + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object[s] += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves), + positions_px, + ) + * probe_normalization + ) + else: + current_object[s] += step_size * ( + self._sum_overlapping_patches_bincounts( + xp.conj(probe) * exit_waves, positions_px + ) + * probe_normalization + ) + + # back-transmit + exit_waves *= xp.conj(obj) + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * xp.sum( + exit_waves, + axis=0, + ) + * 
object_normalization + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = self._sum_overlapping_patches_bincounts( + xp.abs(probe) ** 2, + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + if self._object_type == "potential": + current_object[s] = ( + self._sum_overlapping_patches_bincounts( + xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy), + positions_px, + ) + * probe_normalization + ) + else: + current_object[s] = ( + self._sum_overlapping_patches_bincounts( + xp.conj(probe) * exit_waves_copy, + positions_px, + ) + * probe_normalization + ) + + # back-transmit + exit_waves_copy *= xp.conj(obj) + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( 
+ exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization + ) + + return current_object, current_probe + + def show_transmitted_probe( + self, + max_batch_size=None, + plot_fourier_probe: bool = False, + remove_initial_probe_aberrations=False, + **kwargs, + ): + """ + Plots the min, max, and mean transmitted probe after propagation and transmission. + + Parameters + ---------- + max_batch_size: int, optional + Max number of probes to calculate at once + plot_fourier_probe: boolean, optional + If True, the transmitted probes are also plotted in Fourier space + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + kwargs: + Passed to show_complex + """ + + xp = self._xp + xp_storage = self._xp_storage + asnumpy = self._asnumpy + + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + mean_transmitted = xp.zeros_like(self._probe) + intensities_compare = [np.inf, 0] + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + # overlaps + shifted_probes = self._return_shifted_probes( + self._probe, positions_px_fractional + ) + _, _, overlap = self._overlap_projection( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ) + + # store relevant arrays + 
mean_transmitted += overlap.sum(0) + + intensities = xp.sum(xp.abs(overlap) ** 2, axis=(-2, -1)) + min_intensity = intensities.min() + max_intensity = intensities.max() + + if min_intensity < intensities_compare[0]: + min_intensity_transmitted = overlap[xp.argmin(intensities)] + intensities_compare[0] = min_intensity + + if max_intensity > intensities_compare[1]: + max_intensity_transmitted = overlap[xp.argmax(intensities)] + intensities_compare[1] = max_intensity + + mean_transmitted /= self._num_diffraction_patterns + + probes = [ + asnumpy(self._return_centered_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + title = [ + "Mean transmitted probe", + "Min-intensity transmitted probe", + "Max-intensity transmitted probe", + ] + + if plot_fourier_probe: + bottom_row = [ + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + probes = [probes, bottom_row] + + title += [ + "Mean transmitted Fourier probe", + "Min-intensity transmitted Fourier probe", + "Max-intensity transmitted Fourier probe", + ] + + title = kwargs.get("title", title) + ticks = kwargs.get("ticks", False) + axsize = kwargs.get("axsize", (4, 4)) + + show_complex( + probes, + title=title, + ticks=ticks, + axsize=axsize, + **kwargs, + ) + + self.clear_device_mem(self._device, self._clear_fft_cache) + + +class ObjectNDProbeMixedMethodsMixin: + """ + Mixin class for methods applicable to 2D, 2.5D, and 3D objects using mixed probes. + Overwrites ObjectNDProbeMethodsMixin. + """ + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + ): + """ + Ptychographic overlap projection method. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + shifted_probes:np.ndarray + fractionally-shifted probes + object_patches: np.ndarray + Patched object view + overlap: np.ndarray + shifted_probes * object_patches + """ + + xp = self._xp + + object_patches = current_object[ + vectorized_patch_indices_row, vectorized_patch_indices_col + ] + + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + + overlap = shifted_probes * xp.expand_dims(object_patches, axis=1) + + return shifted_probes, object_patches, overlap + + def _return_farfield_amplitudes(self, fourier_overlap): + """Small utility to de-duplicate mixed-state Fourier projection.""" + + xp = self._xp + return xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask): + """ + Ptychographic fourier projection method for GD method. 
+ + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + fourier_mask: np.ndarray + Mask to apply at the detector-plane for zeroing-out unreliable gradients + Useful when detector has artifacts such as dead-pixels + + Returns + -------- + exit_waves:np.ndarray + Difference between modified and estimated exit waves + error: float + Reconstruction error + """ + + xp = self._xp + + fourier_overlap = xp.fft.fft2(overlap) + + # resample to match data, note: this needs to happen in reciprocal-space + if self._resample_exit_waves: + fourier_overlap = bilinear_resample( + fourier_overlap, + output_size=self._amplitudes_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + fourier_overlap *= fourier_mask + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) + + farfield_amplitudes[farfield_amplitudes == 0.0] = np.inf + amplitude_modification = amplitudes / farfield_amplitudes + + fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap + + fourier_modified_overlap = ( + fourier_modified_overlap - fourier_overlap + ) * fourier_mask + + # resample back to region_of_interest_shape, note: this needs to happen in reciprocal-space + if self._resample_exit_waves: + fourier_modified_overlap = bilinear_resample( + fourier_modified_overlap, + output_size=self._region_of_interest_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + exit_waves = xp.fft.ifft2(fourier_modified_overlap) + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, + amplitudes, + overlap, + exit_waves, + fourier_mask, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. 
+ Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + overlap: np.ndarray + object * probe overlap + exit_waves: np.ndarray + previously estimated exit waves + fourier_mask: np.ndarray + Mask to apply at the detector-plane for zeroing-out unreliable gradients + Useful when detector has artifacts such as dead-pixels + Currently not implemented for projection sets + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + if fourier_mask is not None: + raise NotImplementedError() + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = overlap.copy() + + factor_to_be_projected = projection_c * overlap + projection_y * exit_waves + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + + # resample to match data, note: this needs to happen in reciprocal-space + if self._resample_exit_waves: + fourier_projected_factor = bilinear_resample( + fourier_projected_factor, + output_size=self._amplitudes_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + farfield_amplitudes = self._return_farfield_amplitudes(fourier_projected_factor) + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) + + farfield_amplitudes[farfield_amplitudes == 0.0] = np.inf + amplitude_modification = amplitudes / farfield_amplitudes + + fourier_projected_factor *= amplitude_modification[:, None] + + # resample back to region_of_interest_shape, note: this needs to happen in real-space + if self._resample_exit_waves: + 
fourier_projected_factor = bilinear_resample( + fourier_projected_factor, + output_size=self._region_of_interest_shape, + vectorized=True, + conserve_array_sums=True, + xp=xp, + ) + + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + exit_waves = ( + projection_x * exit_waves + + projection_a * overlap + + projection_b * projected_factor + ) + + return exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = xp.zeros_like(current_object) + object_update = xp.zeros_like(current_object) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes[:, i_probe]) ** 2, + positions_px, + ) + if self._object_type == "potential": + object_update += step_size * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes[:, i_probe]) + * exit_waves[:, i_probe] + ), + positions_px, + ) + else: + object_update += step_size * self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes[:, i_probe]) * 
exit_waves[:, i_probe], + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object += object_update * probe_normalization + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += step_size * ( + xp.sum( + xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + shifted_probes:np.ndarray + fractionally-shifted probes + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + probe_normalization = xp.zeros_like(current_object) + current_object = xp.zeros_like(current_object) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes[:, i_probe]) ** 2, + positions_px, + ) + if self._object_type == "potential": + current_object += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(object_patches) + * xp.conj(shifted_probes[:, i_probe]) + * exit_waves[:, i_probe] + ), + positions_px, + ) + else: + current_object += self._sum_overlapping_patches_bincounts( + xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe], + positions_px, + ) + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object *= probe_normalization + + if not fix_probe: + object_normalization = xp.sum( + (xp.abs(object_patches) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + +class Object2p5DProbeMixedMethodsMixin: + 
""" + Mixin class for methods unique to 2.5D objects using mixed probes. + Overwrites ObjectNDProbeMethodsMixin and ObjectNDProbeMixedMethodsMixin. + """ + + def _overlap_projection( + self, + current_object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes_in, + ): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + object_patches = current_object[ + :, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ] + + if self._object_type == "potential": + object_patches = xp.exp(1j * object_patches) + + num_probe_positions = object_patches.shape[1] + + shifted_shape = ( + self._num_slices, + num_probe_positions, + self._num_probes, + self._region_of_interest_shape[0], + self._region_of_interest_shape[1], + ) + + shifted_probes = xp.empty(shifted_shape, dtype=object_patches.dtype) + shifted_probes[0] = shifted_probes_in + + for s in range(self._num_slices): + # transmit + overlap = xp.expand_dims(object_patches[s], axis=1) * shifted_probes[s] + + # propagate + if s + 1 < self._num_slices: + shifted_probes[s + 1] = self._propagate_array( + overlap, self._propagator_arrays[s] + ) + + return shifted_probes, object_patches, overlap + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2, + positions_px, + ) + + if self._object_type == "potential": + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves[:, i_probe] + ), + positions_px, + ) + ) + else: + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe], + positions_px, + ) + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] += object_update * probe_normalization + + # back-transmit + exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + 
axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * xp.sum( + exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + shifted_probes, + positions_px, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + for s in reversed(range(self._num_slices)): + probe = shifted_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2, + positions_px, + ) + + if self._object_type == "potential": + object_update += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves_copy[:, i_probe] + ), + positions_px, + ) + else: + object_update 
+= self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe], + positions_px, + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] = object_update * probe_normalization + + # back-transmit + exit_waves_copy *= xp.expand_dims(xp.conj(obj), axis=1) + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( + exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def show_transmitted_probe( + self, + **kwargs, + ): + raise NotImplementedError() + + +class MultipleMeasurementsMethodsMixin: + """ + Mixin class for methods unique to classes with multiple measurements. + Overwrites various Mixins. 
+ """ + + def _reset_reconstruction( + self, + store_iterations, + reset, + use_projection_scheme, + ): + """ """ + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + # reset can be True, False, or None (default) + if reset is True: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probes_all = [pr.copy() for pr in self._probes_all_initial] + self._positions_px_all = self._positions_px_initial_all.copy() + self._object_type = self._object_type_initial + + if use_projection_scheme: + self._exit_waves = [None] * len(self._probes_all) + else: + self._exit_waves = None + + # delete positions affine transform + if hasattr(self, "_tf"): + del self._tf + + elif reset is None: + # continued run + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + + # first start + else: + self.error_iterations = [] + if use_projection_scheme: + self._exit_waves = [None] * len(self._probes_all) + else: + self._exit_waves = None + + def _return_single_probe(self, probe=None): + """Current probe estimate""" + xp = self._xp + + if probe is not None: + _probes = [xp.asarray(pr) for pr in probe] + else: + if not hasattr(self, "_probes_all"): + return None + _probes = self._probes_all + + probe = xp.zeros(self._region_of_interest_shape, dtype=np.complex64) + + for pr in _probes: + probe += pr + + return probe / len(_probes) + + def _return_average_positions( + self, positions=None, cum_probes_per_measurement=None + ): + """Average positions estimate""" + xp_storage = self._xp_storage + + if positions is not None: + _pos = xp_storage.asarray(positions) + else: + if not hasattr(self, "_positions_px_all"): + return None + _pos = self._positions_px_all + + if cum_probes_per_measurement is None: + cum_probes_per_measurement = self._cum_probes_per_measurement + + 
num_probes_per_measurement = np.diff(cum_probes_per_measurement) + num_measurements = len(num_probes_per_measurement) + + if np.any(num_probes_per_measurement != num_probes_per_measurement[0]): + return None + + avg_positions = xp_storage.zeros( + (num_probes_per_measurement[0], 2), dtype=xp_storage.float32 + ) + + for index in range(num_measurements): + start_idx = cum_probes_per_measurement[index] + end_idx = cum_probes_per_measurement[index + 1] + avg_positions += _pos[start_idx:end_idx] + + return avg_positions / num_measurements + + def _return_self_consistency_errors( + self, + **kwargs, + ): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() + + @property + def probe_fourier(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probes_all"): + return None + + asnumpy = self._asnumpy + return [asnumpy(self._return_fourier_probe(pr)) for pr in self._probes_all] + + @property + def probe_fourier_residual(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probes_all"): + return None + + asnumpy = self._asnumpy + return [ + asnumpy( + self._return_fourier_probe(pr, remove_initial_probe_aberrations=True) + ) + for pr in self._probes_all + ] + + @property + def probe_centered(self): + """Current probe estimate shifted to the center""" + if not hasattr(self, "_probes_all"): + return None + + asnumpy = self._asnumpy + return [asnumpy(self._return_centered_probe(pr)) for pr in self._probes_all] + + @property + def positions(self): + """Probe positions [A]""" + + if self.angular_sampling is None: + return None + + asnumpy = self._asnumpy + positions_all = [] + + for index in range(self._num_measurements): + start_idx = self._cum_probes_per_measurement[index] + end_idx = self._cum_probes_per_measurement[index + 1] + positions = self._positions_px_all[start_idx:end_idx].copy() + positions[:, 0] *= self.sampling[0] + positions[:, 1] *= self.sampling[1] + 
positions_all.append(asnumpy(positions)) + + return np.asarray(positions_all) diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py new file mode 100644 index 000000000..f3b2991ab --- /dev/null +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -0,0 +1,1256 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely joint ptychographic tomography. +""" + +import warnings +from typing import Mapping, Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + Object2p5DConstraintsMixin, + Object3DConstraintsMixin, + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + MultipleMeasurementsMethodsMixin, + Object2p5DMethodsMixin, + Object2p5DProbeMethodsMixin, + Object3DMethodsMixin, + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class PtychographicTomography( + VisualizationsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, + Object3DConstraintsMixin, + Object2p5DConstraintsMixin, + ObjectNDConstraintsMixin, + MultipleMeasurementsMethodsMixin, + Object2p5DProbeMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + Object3DMethodsMixin, + 
Object2p5DMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Ptychographic Tomography Reconstruction Class. + + List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed object dimensions : (Px,Py,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py,Py) is the padded-object electrostatic potential volume, + where x-axis is the tilt. + + Parameters + ---------- + datacube: List of DataCubes + Input list of 4D diffraction pattern intensities + energy: float + The electron energy of the wave functions in eV + num_slices: int + Number of super-slices to use in the forward model + tilt_orientation_matrices: Sequence[np.ndarray] + List of orientation matrices for each tilt + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py,Py) + If None, initialized to 1.0 + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: list of np.ndarray, optional + Probe positions in Å for each diffraction intensity per tilt + If None, initialized to a grid scan centered along tilt axis + positions_offset_ang: list of np.ndarray, optional + Offset of positions in A + verbose: bool, optional + If True, class methods will inherit this and print additional information + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions to ignore in reconstruction + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. 
+ """ + + # Class-specific Metadata + _class_specific_metadata = ( + "_num_slices", + "_tilt_orientation_matrices", + "_num_measurements", + ) + + def __init__( + self, + energy: float, + num_slices: int, + tilt_orientation_matrices: Sequence[np.ndarray], + datacube: Sequence[DataCube] = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + object_type: str = "potential", + positions_mask: np.ndarray = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: Sequence[np.ndarray] = None, + positions_offset_ang: Sequence[np.ndarray] = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "ptychographic-tomography_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + num_tilts = len(tilt_orientation_matrices) + if initial_scan_positions is None: + initial_scan_positions = [None] * num_tilts + + if object_type != "potential": + raise NotImplementedError() + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe_init = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + 
self._positions_offset_ang = positions_offset_ang + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + self._num_slices = num_slices + self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) + self._num_measurements = num_tilts + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "bilinear", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_probe_overlaps: bool = True, + rotation_real_space_degrees: float = None, + diffraction_patterns_rotate_degrees: float = None, + diffraction_patterns_transpose: bool = None, + force_com_shifts: Sequence[float] = None, + force_com_measured: Sequence[np.ndarray] = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + progress_bar: bool = True, + object_fov_mask: np.ndarray = True, + crop_patterns: bool = False, + main_tilt_axis: str = "vertical", + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + + Additionally, it initializes an (Px,Py, Py) array of 1.0 + and a complex probe using the specified polar parameters. 
+ + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + rotation_real_space_degrees: float (degrees), optional + In plane rotation around z axis between x axis and tilt axis in + real space (forced to be in xy plane) + diffraction_patterns_rotate_degrees: float, optional + Relative rotation angle between real and reciprocal space + diffraction_patterns_transpose: bool, optional + Whether diffraction intensities need to be transposed. + force_com_shifts: list of tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. One tuple per tilt. 
+ force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + Force CoM measured shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + If True, crop patterns to avoid wrap around of patterns when centering + main_tilt_axis: str + The default, 'vertical' (first scan dimension), results in object size (q,p,q), + 'horizontal' (second scan dimension) results in object size (p,p,q), + any other value (e.g. None) results in object size (max(p,q),p,q). + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: OverlapTomographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_measurements, 1, 1) + ) + + num_probes_per_measurement = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) + + else: + self._positions_mask = [None] * self._num_measurements + num_probes_per_measurement = [0] + [dc.R_N for dc in self._datacube] + num_probes_per_measurement = np.array(num_probes_per_measurement) + + # prepopulate relevant arrays + self._mean_diffraction_intensity = [] + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + # calculate roi_shape + roi_shape = self._datacube[0].Qshape + if diffraction_intensities_shape is not None: + 
roi_shape = diffraction_intensities_shape + if padded_diffraction_intensities_shape is not None: + roi_shape = tuple( + max(q, s) + for q, s in zip(roi_shape, padded_diffraction_intensities_shape) + ) + + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + roi_shape + ) + + self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # TO-DO: generalize this + if force_com_shifts is None: + force_com_shifts = [None] * self._num_measurements + + if force_com_measured is None: + force_com_measured = [None] * self._num_measurements + + if self._positions_offset_ang is None: + self._positions_offset_ang = [None] * self._num_measurements + + self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees) + self._rotation_best_transpose = diffraction_patterns_transpose + + if progress_bar: + # turn off verbosity to play nice with tqdm + verbose = self._verbose + self._verbose = False + + # loop over DPs for preprocessing + for index in tqdmnd( + self._num_measurements, + desc="Preprocessing data", + unit="tilt", + disable=not progress_bar, + ): + # preprocess datacube, vacuum and masks only for first tilt + if index == 0: + ( + self._datacube[index], + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts[index], + 
com_measured=force_com_measured[index], + ) + + else: + ( + self._datacube[index], + _, + _, + force_com_shifts[index], + force_com_measured[index], + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube[index], + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=None, + dp_mask=None, + com_shifts=force_com_shifts[index], + com_measured=force_com_measured[index], + ) + + # calibrations + intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube[index], + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # calculate CoM + ( + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts[index], + vectorized_calculation=vectorized_com_calculation, + com_measured=force_com_measured[index], + ) + + # corner-center amplitudes + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + ( + amplitudes, + mean_diffraction_intensity_temp, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + intensities, + com_fitted_x, + com_fitted_y, + self._positions_mask[index], + crop_patterns, + ) + + self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) + + # explicitly transfer arrays to storage + self._amplitudes[idx_start:idx_end] = copy_to_device(amplitudes, storage) + + del ( + intensities, + amplitudes, + com_measured_x, + com_measured_y, + com_fitted_x, + com_fitted_y, + com_normalized_x, + com_normalized_y, + ) + + # initialize probe positions + ( + 
self._positions_px_all[idx_start:idx_end], + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions[index], + self._positions_mask[index], + self._object_padding_px, + self._positions_offset_ang[index], + ) + + if progress_bar: + # reset verbosity + self._verbose = verbose + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # initialize object + self._object = self._initialize_object( + self._object, + self._positions_px_all, + self._object_type, + main_tilt_axis, + ) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + self._num_voxels = self._object.shape[0] + + # center probe positions + self._positions_px_all = xp_storage.asarray( + self._positions_px_all, dtype=xp_storage.float32 + ) + + for index in range(self._num_measurements): + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + + positions_px = self._positions_px_all[idx_start:idx_end] + positions_px_com = positions_px.mean(0) + positions_px -= positions_px_com - xp_storage.array(self._object_shape) / 2 + self._positions_px_all[idx_start:idx_end] = positions_px.copy() + + self._positions_px_initial_all = self._positions_px_all.copy() + self._positions_initial_all = self._positions_px_initial_all.copy() + self._positions_initial_all[:, 0] *= self.sampling[0] + self._positions_initial_all[:, 1] *= self.sampling[1] + + self._positions_initial = self._return_average_positions() + if self._positions_initial is not None: + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probes_all = [] + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + list_Q = isinstance(self._probe_init, (list, tuple)) + + for 
index in range(self._num_measurements): + _probe, self._semiangle_cutoff = self._initialize_probe( + self._probe_init[index] if list_Q else self._probe_init, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity[index], + self._semiangle_cutoff, + crop_patterns, + ) + + self._probes_all.append(_probe) + self._probes_all_initial.append(_probe.copy()) + self._probes_all_initial_aperture.append(xp.abs(xp.fft.fft2(_probe))) + + del self._probe_init + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + # Precomputed propagator arrays + if main_tilt_axis == "vertical": + thickness = self._object_shape[1] * self.sampling[1] + elif main_tilt_axis == "horizontal": + thickness = self._object_shape[0] * self.sampling[0] + else: + thickness_h = self._object_shape[1] * self.sampling[1] + thickness_v = self._object_shape[0] * self.sampling[0] + thickness = max(thickness_h, thickness_v) + + self._slice_thicknesses = np.tile( + thickness / self._num_slices, self._num_slices - 1 + ) + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + ) + + if object_fov_mask is not True: + raise NotImplementedError() + else: + self._object_fov_mask = np.full(self._object_shape, True) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + # plot probe overlaps + if plot_probe_overlaps: + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._cum_probes_per_measurement[1], max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px_all[start:end] + positions_px_fractional = positions_px - 
xp_storage.round(positions_px) + + shifted_probes = fft_shift( + self._probes_all[0], positions_px_fractional, xp + ) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + probe_overlap = asnumpy(probe_overlap) + + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=power, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probes_all[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, + chroma_boost=chroma_boost, + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax2, + chroma_boost=chroma_boost, + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe intensity") + + ax3.imshow( + probe_overlap, + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[0, :, 1], + 
self.positions[0, :, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_probe_com: bool = True, + fix_probe: bool = False, + fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + detector_fourier_mask: 
np.ndarray = None, + tv_denoise: bool = True, + tv_denoise_weights: float = None, + tv_denoise_inner_iter=40, + collective_measurement_updates: bool = True, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. 
+ max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. 
+ fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. + If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. 
Smaller gives a smoother filter + object_positivity: bool, optional + If True, forces object to be positive + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_measurement_updates: bool + if True perform collective measurement updates (i.e. one per tilt) + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + detector_fourier_mask: np.ndarray + Corner-centered mask to apply at the detector-plane for zeroing-out unreliable gradients. + Useful when detector has artifacts such as dead-pixels. Usually binary. + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + + Returns + -------- + self: OverlapTomographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probes_all", + "_probes_all_initial", + "_probes_all_initial_aperture", + "_propagator_arrays", + ] + self.copy_attributes_to_device(attrs, device) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + # initialization + self._reset_reconstruction(store_iterations, reset, use_projection_scheme) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + if detector_fourier_mask is None: + detector_fourier_mask = xp.ones(self._amplitudes[0].shape) + else: + detector_fourier_mask = xp.asarray(detector_fourier_mask) + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if collective_measurement_updates: + collective_object = xp.zeros_like(self._object) + + indices = np.arange(self._num_measurements) + np.random.shuffle(indices) + + old_rot_matrix = 
np.eye(3) # identity + + for index in indices: + self._active_measurement_index = index + + measurement_error = 0.0 + + rot_matrix = self._tilt_orientation_matrices[ + self._active_measurement_index + ] + self._object = self._rotate_zxy_volume( + self._object, + rot_matrix @ old_rot_matrix.T, + ) + + object_sliced = self._project_sliced_object( + self._object, self._num_slices + ) + + _probe = self._probes_all[self._active_measurement_index] + _probe_initial_aperture = self._probes_all_initial_aperture[ + self._active_measurement_index + ] + + if not use_projection_scheme: + object_sliced_old = object_sliced.copy() + + start_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + ] + end_idx = self._cum_probes_per_measurement[ + self._active_measurement_index + 1 + ] + + num_diffraction_patterns = end_idx - start_idx + shuffled_indices = np.arange(start_idx, end_idx) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px_all[batch_indices] + positions_px_initial = self._positions_px_initial_all[batch_indices] + positions_px_fractional = positions_px - xp_storage.round( + positions_px + ) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + _probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + detector_fourier_mask, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + object_sliced, 
_probe = self._adjoint( + object_sliced, + _probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions: + self._positions_px_all[batch_indices] = ( + self._position_correction( + object_sliced, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + ) + + measurement_error += batch_error + + if not use_projection_scheme: + object_sliced -= object_sliced_old + + object_update = self._expand_sliced_object( + object_sliced, self._num_voxels + ) + + if collective_measurement_updates: + collective_object += self._rotate_zxy_volume( + object_update, rot_matrix.T + ) + else: + self._object += object_update + + old_rot_matrix = rot_matrix + + # Normalize Error + measurement_error /= ( + self._mean_diffraction_intensity[self._active_measurement_index] + * num_diffraction_patterns + ) + error += measurement_error + + # constraints + + if collective_measurement_updates: + # probe and positions + _probe = self._probe_constraints( + _probe, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not 
fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + ) + + self._positions_px_all[batch_indices] = self._positions_constraints( + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + ) + + else: + # object, probe, and positions + ( + self._object, + _probe, + self._positions_px_all[batch_indices], + ) = self._constraints( + self._object, + _probe, + self._positions_px_all[batch_indices], + self._positions_px_initial_all[batch_indices], + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude + and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + 
fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=_probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=( + self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None + ), + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + + self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T) + + # Normalize Error Over Tilts + error /= self._num_measurements + + if collective_measurement_updates: + self._object += collective_object / self._num_measurements + + # object only + self._object = self._object_constraints( + self._object, + gaussian_filter=gaussian_filter + and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=( + self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None + ), + tv_denoise=tv_denoise and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + ) + 
+ self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self diff --git a/py4DSTEM/process/phase/ptychographic_visualizations.py b/py4DSTEM/process/phase/ptychographic_visualizations.py new file mode 100644 index 000000000..0fbc7f2ce --- /dev/null +++ b/py4DSTEM/process/phase/ptychographic_visualizations.py @@ -0,0 +1,846 @@ +from typing import Tuple + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from py4DSTEM.process.phase.utils import AffineTransform, copy_to_device +from py4DSTEM.visualize.vis_special import ( + Complex2RGB, + add_colorbar_arg, + return_scaled_histogram_ordering, +) + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + +class VisualizationsMixin: + """ + Mixin class for various visualization methods. + """ + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. 
+ + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool, optional + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + + asnumpy = self._asnumpy + + figsize = kwargs.pop("figsize", (8, 5)) + cmap = kwargs.pop("cmap", "magma") + chroma_boost = kwargs.pop("chroma_boost", 1) + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + # get scaled arrays + obj, kwargs = self._return_projected_cropped_potential( + return_kwargs=True, **kwargs + ) + probe = self._return_single_probe() + + obj, vmin, vmax = return_scaled_histogram_ordering(obj, vmin, vmax) + + extent = [ + 0, + self.sampling[1] * obj.shape[1], + self.sampling[0] * obj.shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) + + else: + if plot_probe or plot_fourier_probe: + spec 
= GridSpec( + ncols=2, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + obj, + extent=extent, + cmap=cmap, + vmin=vmin, + vmax=vmax, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + ax = fig.add_subplot(spec[0, 1]) + if plot_fourier_probe: + probe = asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB( + probe, + chroma_boost=chroma_boost, + ) + + ax.set_title("Reconstructed Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + asnumpy(self._return_centered_probe(probe)), + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title("Reconstructed probe intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + else: + # Object + ax = fig.add_subplot(spec[0]) + im = ax.imshow( + obj, + extent=extent, + cmap=cmap, + vmin=vmin, + vmax=vmax, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + 
ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + + if plot_probe: + ax = fig.add_subplot(spec[1, :]) + else: + ax = fig.add_subplot(spec[1]) + + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + def _visualize_all_iterations( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + iterations_grid: Tuple[int, int], + **kwargs, + ): + """ + Displays all reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed complex probe is displayed + plot_fourier_probe: bool + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + """ + asnumpy = self._asnumpy + + if not hasattr(self, "object_iterations"): + raise ValueError( + ( + "Object and probe iterations were not saved during reconstruction. " + "Please re-run using store_iterations=True." 
+ ) + ) + + num_iter = len(self.object_iterations) + + if iterations_grid == "auto": + if num_iter == 1: + return self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + cbar=cbar, + **kwargs, + ) + + elif plot_probe or plot_fourier_probe: + iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) + + else: + iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + + else: + if plot_probe or plot_fourier_probe: + if iterations_grid[0] != 2: + raise ValueError() + else: + if iterations_grid[0] * iterations_grid[1] > num_iter: + raise ValueError() + + auto_figsize = ( + (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) + if plot_convergence + else (3 * iterations_grid[1], 3 * iterations_grid[0]) + ) + + figsize = kwargs.pop("figsize", auto_figsize) + cmap = kwargs.pop("cmap", "magma") + chroma_boost = kwargs.pop("chroma_boost", 1) + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + # most recent errors + errors = np.array(self.error_iterations)[-num_iter:] + + max_iter = num_iter - 1 + if plot_probe or plot_fourier_probe: + total_grids = (np.prod(iterations_grid) / 2).astype("int") + grid_range = np.arange(0, max_iter + 1, max_iter // (total_grids - 1)) + probes = [ + self._return_single_probe(self.probe_iterations[idx]) + for idx in grid_range + ] + else: + total_grids = np.prod(iterations_grid) + grid_range = np.arange(0, max_iter + 1, max_iter // (total_grids - 1)) + + objects = [] + + for idx in grid_range: + if idx < grid_range[-1]: + obj = self._return_projected_cropped_potential( + obj=self.object_iterations[idx], + return_kwargs=False, + **kwargs, + ) + else: + obj, kwargs = self._return_projected_cropped_potential( + obj=self.object_iterations[idx], return_kwargs=True, **kwargs + ) + + objects.append(obj) + + extent = [ + 0, + self.sampling[1] * 
objects[0].shape[1], + self.sampling[0] * objects[0].shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) + + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=2) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + grid = ImageGrid( + fig, + spec[0], + nrows_ncols=( + (1, iterations_grid[1]) + if (plot_probe or plot_fourier_probe) + else iterations_grid + ), + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + obj, vmin_n, vmax_n = return_scaled_histogram_ordering( + objects[n], vmin=vmin, vmax=vmax + ) + im = ax.imshow( + obj, + extent=extent, + cmap=cmap, + vmin=vmin_n, + vmax=vmax_n, + **kwargs, + ) + ax.set_title(f"Iter: {grid_range[n]} potential") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if cbar: + grid.cbar_axes[n].colorbar(im) + + if plot_probe or plot_fourier_probe: + grid = ImageGrid( + fig, + spec[1], + nrows_ncols=(1, iterations_grid[1]), + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + if plot_fourier_probe: + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[n], + 
remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + + else: + probe_array = Complex2RGB( + asnumpy(probes[n]), + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title(f"Iter: {grid_range[n]} probe intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + add_colorbar_arg( + grid.cbar_axes[n], + chroma_boost=chroma_boost, + ) + + if plot_convergence: + if plot_probe: + ax2 = fig.add_subplot(spec[2]) + else: + ax2 = fig.add_subplot(spec[1]) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax2.set_ylabel("NMSE") + ax2.set_xlabel("Iteration number") + ax2.yaxis.tick_right() + + spec.tight_layout(fig) + + def visualize( + self, + fig=None, + iterations_grid: Tuple[int, int] = None, + plot_convergence: bool = True, + plot_probe: bool = True, + plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, + cbar: bool = True, + **kwargs, + ): + """ + Displays reconstructed object and probe. 
+ + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + padding : int, optional + Pixels to pad by post rotating-cropping object + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + + if iterations_grid is None: + self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + cbar=cbar, + **kwargs, + ) + + else: + self._visualize_all_iterations( + fig=fig, + plot_convergence=plot_convergence, + iterations_grid=iterations_grid, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + cbar=cbar, + **kwargs, + ) + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def show_updated_positions( + self, + pos=None, + initial_pos=None, + scale_arrows=1, + plot_arrow_freq=None, + plot_cropped_rotated_fov=True, + cbar=True, + verbose=True, + **kwargs, + ): + """ + Function to plot changes to probe positions during ptychography reconstruciton + + Parameters + ---------- + scale_arrows: float, optional + scaling factor to be applied on vectors prior to plt.quiver call + plot_arrow_freq: int, optional + thinning parameter to only plot a subset of probe positions + assumes grid position + 
verbose: bool, optional + if True, prints AffineTransformation if positions have been updated + """ + + if verbose: + if hasattr(self, "_tf"): + print(self._tf) + + asnumpy = self._asnumpy + + if pos is None: + pos = self.positions + + # handle multiple measurements + if pos.ndim == 3: + pos = pos.mean(0) + + if initial_pos is None: + initial_pos = asnumpy(self._positions_initial) + + if plot_cropped_rotated_fov: + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + initial_pos = tf(initial_pos, origin=np.mean(pos, axis=0)) + pos = tf(pos, origin=np.mean(pos, axis=0)) + + obj_shape = self.object_cropped.shape[-2:] + initial_pos_com = np.mean(initial_pos, axis=0) + center_shift = initial_pos_com - ( + np.array(obj_shape) / 2 * np.array(self.sampling) + ) + initial_pos -= center_shift + pos -= center_shift + + else: + obj_shape = self._object_shape + + if plot_arrow_freq is not None: + rshape = self._datacube.Rshape + (2,) + freq = plot_arrow_freq + + initial_pos = initial_pos.reshape(rshape)[::freq, ::freq].reshape(-1, 2) + pos = pos.reshape(rshape)[::freq, ::freq].reshape(-1, 2) + + deltas = pos - initial_pos + norms = np.linalg.norm(deltas, axis=1) + + extent = [ + 0, + self.sampling[1] * obj_shape[1], + self.sampling[0] * obj_shape[0], + 0, + ] + + figsize = kwargs.pop("figsize", (4, 4)) + cmap = kwargs.pop("cmap", "Reds") + + fig, ax = plt.subplots(figsize=figsize) + + im = ax.quiver( + initial_pos[:, 1], + initial_pos[:, 0], + deltas[:, 1] * scale_arrows, + deltas[:, 0] * scale_arrows, + norms, + scale_units="xy", + scale=1, + cmap=cmap, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + cb = fig.colorbar(im, cax=ax_cb) + cb.set_label("Δ [A]", rotation=0, ha="left", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y 
[A]") + ax.set_xlim((extent[0], extent[1])) + ax.set_ylim((extent[2], extent[3])) + ax.set_aspect("equal") + ax.set_title("Updated probe positions") + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + + xp = self._xp + device = self._device + asnumpy = self._asnumpy + gaussian_filter = self._scipy.ndimage.gaussian_filter + + if errors is None: + errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) + errors_xp = xp.asarray(errors) + + if projected_cropped_potential is None: + projected_cropped_potential = self._return_projected_cropped_potential() + + if kde_sigma is None: + kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] + + ## Kernel Density Estimation + + # rotated basis + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + positions_px = copy_to_device(self._positions_px, device) + rotated_points = tf(positions_px, origin=positions_px.mean(0), xp=xp) + + padding = xp.min(rotated_points, axis=0).astype("int") + + # bilinear sampling + pixel_output = np.array(projected_cropped_potential.shape) + asnumpy( + 2 * padding + ) + pixel_size = pixel_output.prod() + + xa = rotated_points[:, 0] + ya = rotated_points[:, 1] + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + all_inds = [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + + all_weights = [ + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ] + + pix_count = xp.zeros(pixel_size, dtype=xp.float32) + pix_output = xp.zeros(pixel_size, dtype=xp.float32) + + for inds, weights in zip(all_inds, all_weights): + inds_1D = 
xp.ravel_multi_index( + inds, + pixel_output, + mode=["wrap", "wrap"], + ) + + pix_count += xp.bincount( + inds_1D, + weights=weights, + minlength=pixel_size, + ) + pix_output += xp.bincount( + inds_1D, + weights=weights * errors_xp, + minlength=pixel_size, + ) + + # reshape 1D arrays to 2D + pix_count = xp.reshape( + pix_count, + pixel_output, + ) + pix_output = xp.reshape( + pix_output, + pixel_output, + ) + + # kernel density estimate + pix_count = gaussian_filter(pix_count, kde_sigma) + pix_output = gaussian_filter(pix_output, kde_sigma) + sub = pix_count > 1e-3 + pix_output[sub] /= pix_count[sub] + pix_output[np.logical_not(sub)] = 1 + pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] + pix_output, _, _ = return_scaled_histogram_ordering( + asnumpy(pix_output), normalize=True + ) + + ## Visualization + if plot_histogram: + spec = GridSpec( + ncols=1, + nrows=2, + height_ratios=[1, 4], + hspace=0.15, + ) + auto_figsize = (4, 5) + else: + spec = GridSpec( + ncols=1, + nrows=1, + ) + auto_figsize = (4, 4) + + figsize = kwargs.pop("figsize", auto_figsize) + + fig = plt.figure(figsize=figsize) + + if plot_histogram: + ax_hist = fig.add_subplot(spec[0]) + + counts, bins = np.histogram(errors, bins=50) + ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) + ax_hist.set_ylabel("Counts") + ax_hist.set_xlabel("Normalized squared error") + + ax = fig.add_subplot(spec[-1]) + + cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + projected_cropped_potential, vmin, vmax = return_scaled_histogram_ordering( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + ) + + extent = [ + 0, + self.sampling[1] * projected_cropped_potential.shape[1], + self.sampling[0] * projected_cropped_potential.shape[0], + 0, + ] + + ax.imshow( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + extent=extent, + alpha=1 - pix_output, + cmap=cmap, + **kwargs, + ) + + if plot_contours: + 
aligned_points = asnumpy(rotated_points - padding) + aligned_points[:, 0] *= self.sampling[0] + aligned_points[:, 1] *= self.sampling[1] + + ax.tricontour( + aligned_points[:, 1], + aligned_points[:, 0], + errors, + colors="grey", + levels=5, + # linestyles='dashed', + linewidths=0.5, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_xlim((extent[0], extent[1])) + ax.set_ylim((extent[2], extent[3])) + ax.xaxis.set_ticks_position("bottom") + + spec.tight_layout(fig) + + self.clear_device_mem(self._device, self._clear_fft_cache) diff --git a/py4DSTEM/process/phase/singleslice_ptychography.py b/py4DSTEM/process/phase/singleslice_ptychography.py new file mode 100644 index 000000000..e1f14a90f --- /dev/null +++ b/py4DSTEM/process/phase/singleslice_ptychography.py @@ -0,0 +1,971 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely (single-slice) ptychography. +""" + +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg + +try: + import cupy as cp +except (ImportError, ModuleNotFoundError): + cp = np + +from emdfile import Custom, tqdmnd +from py4DSTEM.datacube import DataCube +from py4DSTEM.process.phase.phase_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.ptychographic_constraints import ( + ObjectNDConstraintsMixin, + PositionsConstraintsMixin, + ProbeConstraintsMixin, +) +from py4DSTEM.process.phase.ptychographic_methods import ( + ObjectNDMethodsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, +) +from py4DSTEM.process.phase.ptychographic_visualizations import VisualizationsMixin +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + copy_to_device, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, +) + + +class SingleslicePtychography( + VisualizationsMixin, + 
PositionsConstraintsMixin, + ProbeConstraintsMixin, + ObjectNDConstraintsMixin, + ObjectNDProbeMethodsMixin, + ProbeMethodsMixin, + ObjectNDMethodsMixin, + PtychographicReconstruction, +): + """ + Iterative Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (Sx,Sy) + Reconstructed object dimensions : (Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our probe + and (Px,Py) is the padded-object size we position our ROI around in. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + datacube: DataCube + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + positions_offset_ang: np.ndarray, optional + Offset of positions in A + verbose: bool, optional + If True, class methods will inherit this and print additional information + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + device: str, optional + Device calculation will be perfomed on. Must be 'cpu' or 'gpu' + storage: str, optional + Device non-frequent arrays will be stored on. Must be 'cpu' or 'gpu' + clear_fft_cache: bool, optional + If True, and device = 'gpu', clears the cached fft plan at the end of function calls + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. 
+ """ + + # Class-specific Metadata + _class_specific_metadata = () + + def __init__( + self, + energy: float, + datacube: DataCube = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + positions_offset_ang: np.ndarray = None, + object_padding_px: Tuple[int, int] = None, + object_type: str = "complex", + positions_mask: np.ndarray = None, + verbose: bool = True, + device: str = "cpu", + storage: str = None, + clear_fft_cache: bool = True, + name: str = "ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if storage is None: + storage = device + + self.set_device(device, clear_fft_cache) + self.set_storage(storage) + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._positions_offset_ang = positions_offset_ang + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + 
self._object_padding_px = object_padding_px + self._positions_mask = positions_mask + self._verbose = verbose + self._preprocessed = False + + # Class-specific Metadata + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "bilinear", + padded_diffraction_intensities_shape: Tuple[int, int] = None, + region_of_interest_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = None, + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: Union[Sequence[np.ndarray], Sequence[float]] = None, + force_com_measured: Sequence[np.ndarray] = None, + vectorized_com_calculation: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + device: str = None, + clear_fft_cache: bool = None, + max_batch_size: int = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + padded_diffraction_intensities_shape: (int,int), optional + Padded diffraction intensities shape. 
+ If None, no padding is performed + region_of_interest_shape: (int,int), optional + If not None, explicitly sets region_of_interest_shape and resamples exit_waves + at the diffraction plane to allow comparison with experimental data + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. 
+ force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + Force CoM measured shifts + vectorized_com_calculation: bool, optional + If True (default), the memory-intensive CoM calculation is vectorized + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + device: str, optional + if not none, overwrites self._device to set device preprocess will be perfomed on. + clear_fft_cache: bool, optional + if true, and device = 'gpu', clears the cached fft plan at the end of function calls + max_batch_size: int, optional + Max number of probes to use at once in computing probe overlaps + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + + # handle device/storage + self.set_device(device, clear_fft_cache) + + xp = self._xp + device = self._device + xp_storage = self._xp_storage + storage = self._storage + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._padded_diffraction_intensities_shape = ( + padded_diffraction_intensities_shape + ) + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." 
+ ) + ) + + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + # preprocess datacube + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + force_com_measured, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + padded_diffraction_intensities_shape=self._padded_diffraction_intensities_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + com_measured=force_com_measured, + ) + + # calibrations + _intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # calculate CoM + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + _intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + vectorized_calculation=vectorized_com_calculation, + com_measured=force_com_measured, + ) + + # estimate rotation / transpose + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + 
force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + # explicitly transfer arrays to storage + attrs = [ + "_com_measured_x", + "_com_measured_y", + "_com_fitted_x", + "_com_fitted_y", + "_com_normalized_x", + "_com_normalized_y", + "_com_x", + "_com_y", + ] + self.copy_attributes_to_device(attrs, storage) + + # corner-center amplitudes + ( + self._amplitudes, + self._mean_diffraction_intensity, + self._crop_mask, + ) = self._normalize_diffraction_intensities( + _intensities, + self._com_fitted_x, + self._com_fitted_y, + self._positions_mask, + crop_patterns, + ) + + # explicitly transfer arrays to storage + self._amplitudes = copy_to_device(self._amplitudes, storage) + del _intensities + + self._num_diffraction_patterns = self._amplitudes.shape[0] + self._amplitudes_shape = np.array(self._amplitudes.shape[-2:]) + + if region_of_interest_shape is not None: + self._resample_exit_waves = True + self._region_of_interest_shape = np.array(region_of_interest_shape) + else: + self._resample_exit_waves = False + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + + # initialize probe positions + ( + self._positions_px, + self._object_padding_px, + ) = self._calculate_scan_positions_in_pixels( + self._scan_positions, + self._positions_mask, + self._object_padding_px, + self._positions_offset_ang, + ) + + # initialize object + self._object = self._initialize_object( + self._object, + self._positions_px, + self._object_type, + ) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape + + # center probe positions + self._positions_px = xp_storage.asarray( + self._positions_px, dtype=xp_storage.float32 + ) + self._positions_px_initial_com = self._positions_px.mean(0) + self._positions_px -= ( + self._positions_px_initial_com - xp_storage.array(self._object_shape) / 2 + ) + self._positions_px_initial_com = 
self._positions_px.mean(0) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # initialize probe + self._probe, self._semiangle_cutoff = self._initialize_probe( + self._probe, + self._vacuum_probe_intensity, + self._mean_diffraction_intensity, + self._semiangle_cutoff, + crop_patterns, + ) + + # initialize aberrations + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=device, + )._evaluate_ctf() + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) + + if object_fov_mask is None or plot_probe_overlaps: + # overlaps + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + probe_overlap = xp.zeros(self._object_shape, dtype=xp.float32) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + positions_px = self._positions_px[start:end] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + shifted_probes = fft_shift(self._probe, positions_px_fractional, xp) + probe_overlap += self._sum_overlapping_patches_bincounts( + xp.abs(shifted_probes) ** 2, positions_px + ) + + del shifted_probes + + # initialize object_fov_mask + if object_fov_mask is None: + gaussian_filter = self._scipy.ndimage.gaussian_filter + probe_overlap_blurred = gaussian_filter(probe_overlap, 1.0) + self._object_fov_mask = asnumpy( + probe_overlap_blurred > 0.25 * probe_overlap_blurred.max() + ) + del probe_overlap_blurred + elif object_fov_mask is True: + self._object_fov_mask = np.full(self._object_shape, True) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = 
np.invert(self._object_fov_mask) + + # plot probe overlaps + if plot_probe_overlaps: + probe_overlap = asnumpy(probe_overlap) + + figsize = kwargs.pop("figsize", (9, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + power = kwargs.pop("power", 2) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered, + power=power, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax1, chroma_boost=chroma_boost) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe intensity") + + ax2.imshow( + probe_overlap, + extent=extent, + cmap="gray", + ) + ax2.scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_xlim((extent[0], extent[1])) + ax2.set_ylim((extent[2], extent[3])) + ax2.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self + + def reconstruct( + self, + num_iter: int = 8, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.5, + pure_phase_object: bool = False, + fix_probe_com: bool = True, + fix_probe: bool = False, + 
fix_probe_aperture: bool = False, + constrain_probe_amplitude: bool = False, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude: bool = False, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions: bool = True, + fix_positions_com: bool = True, + max_position_update_distance: float = None, + max_position_total_distance: float = None, + global_affine_transformation: bool = False, + gaussian_filter_sigma: float = None, + gaussian_filter: bool = True, + fit_probe_aberrations: bool = False, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + fit_probe_aberrations_remove_initial: bool = False, + fit_probe_aberrations_using_scikit_image: bool = True, + butterworth_filter: bool = True, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + tv_denoise: bool = True, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + detector_fourier_mask: np.ndarray = None, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + device: str = None, + clear_fft_cache: bool = None, + object_type: str = None, + ): + """ + Ptychographic reconstruction main method. 
+ + Parameters + -------- + num_iter: int, optional + Number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + pure_phase_object: bool, optional + If True, object amplitude is set to unity + fix_probe_com: bool, optional + If True, fixes center of mass of probe + fix_probe: bool, optional + If True, probe is fixed + fix_probe_aperture: bool, optional + If True, vaccum probe is used to fix Fourier amplitude + constrain_probe_amplitude: bool, optional + If True, real-space probe is constrained with a top-hat support. 
+ constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool, optional + If True, Fourier-probe is constrained by fitting a sigmoid for each angular frequency + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions: bool, optional + If True, probe-positions are fixed + fix_positions_com: bool, optional + If True, fixes the positions CoM to the middle of the fov + max_position_update_distance: float, optional + Maximum allowed distance for update in A + max_position_total_distance: float, optional + Maximum allowed distance from initial positions + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter: bool, optional + If True and gaussian_filter_sigma is not None, object is smoothed using gaussian filtering + fit_probe_aberrations: bool, optional + If True, probe aberrations are fitted to a low-order expansion + fit_probe_aberrations_max_angular_order: int + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: int + Max radial order of probe aberrations basis functions + fit_probe_aberrations_remove_initial: bool + If true, initial probe aberrations are removed before fitting + fit_probe_aberrations_using_scikit_image: bool + If true, the necessary phase unwrapping is performed using scikit-image. This is more stable, but occasionally leads + to a documented bug where the kernel hangs.. 
+ If false, a poisson-based solver is used for phase unwrapping. This won't hang, but tends to underestimate aberrations. + butterworth_filter: bool, optional + If True and q_lowpass or q_highpass is not None, object is smoothed using butterworth filtering + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool, optional + If True and tv_denoise_weight is not None, object is smoothed using TV denoising + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + detector_fourier_mask: np.ndarray + Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels). + Useful when detector has artifacts such as dead-pixels. Usually binary. + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + device: str, optional + If not none, overwrites self._device to set device preprocess will be perfomed on. 
+ clear_fft_cache: bool, optional + If true, and device = 'gpu', clears the cached fft plan at the end of function calls + object_type: str, optional + Overwrites self._object_type + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + # handle device/storage + self.set_device(device, clear_fft_cache) + + if device is not None: + attrs = [ + "_known_aberrations_array", + "_object", + "_object_initial", + "_probe", + "_probe_initial", + "_probe_initial_aperture", + ] + self.copy_attributes_to_device(attrs, device) + + # initialization + self._reset_reconstruction(store_iterations, reset) + + if object_type is not None: + self._switch_object_type(object_type) + + xp = self._xp + xp_storage = self._xp_storage + device = self._device + asnumpy = self._asnumpy + + # set and report reconstruction method + ( + use_projection_scheme, + projection_a, + projection_b, + projection_c, + reconstruction_parameter, + step_size, + ) = self._set_reconstruction_method_parameters( + reconstruction_method, + reconstruction_parameter, + reconstruction_parameter_a, + reconstruction_parameter_b, + reconstruction_parameter_c, + step_size, + ) + + if self._verbose: + self._report_reconstruction_summary( + num_iter, + use_projection_scheme, + reconstruction_method, + reconstruction_parameter, + projection_a, + projection_b, + projection_c, + normalization_min, + max_batch_size, + step_size, + ) + + # batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + + if max_batch_size is not None: + np.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + if detector_fourier_mask is None: + detector_fourier_mask = xp.ones(self._amplitudes[0].shape) + else: + detector_fourier_mask = xp.asarray(detector_fourier_mask) + + # main loop + for a0 in tqdmnd( + num_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + # randomize + if not 
use_projection_scheme: + np.random.shuffle(shuffled_indices) + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + batch_indices = shuffled_indices[start:end] + positions_px = self._positions_px[batch_indices] + positions_px_initial = self._positions_px_initial[batch_indices] + positions_px_fractional = positions_px - xp_storage.round(positions_px) + + ( + vectorized_patch_indices_row, + vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices(positions_px) + + amplitudes_device = copy_to_device( + self._amplitudes[batch_indices], device + ) + + # forward operator + ( + shifted_probes, + object_patches, + overlap, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + self._probe, + positions_px_fractional, + amplitudes_device, + self._exit_waves, + detector_fourier_mask, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + shifted_probes, + positions_px, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=fix_probe, + ) + + # position correction + if not fix_positions: + self._positions_px[batch_indices] = self._position_correction( + self._object, + vectorized_patch_indices_row, + vectorized_patch_indices_col, + shifted_probes, + overlap, + amplitudes_device, + positions_px, + positions_px_initial, + positions_step_size, + max_position_update_distance, + max_position_total_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + 
self._positions_px_initial, + fix_probe_com=fix_probe_com and not fix_probe, + constrain_probe_amplitude=constrain_probe_amplitude and not fix_probe, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=constrain_probe_fourier_amplitude + and not fix_probe, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=fit_probe_aberrations and not fix_probe, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fit_probe_aberrations_remove_initial=fit_probe_aberrations_remove_initial, + fit_probe_aberrations_using_scikit_image=fit_probe_aberrations_using_scikit_image, + fix_probe_aperture=fix_probe_aperture and not fix_probe, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=fix_positions, + fix_positions_com=fix_positions_com and not fix_positions, + global_affine_transformation=global_affine_transformation, + gaussian_filter=gaussian_filter and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=butterworth_filter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + tv_denoise=tv_denoise and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=( + self._object_fov_mask_inverse + if fix_potential_baseline + and self._object_fov_mask_inverse.sum() > 0 + else None + ), + pure_phase_object=pure_phase_object and 
self._object_type == "complex", + ) + + self.error_iterations.append(error.item()) + + if store_iterations: + self.object_iterations.append(asnumpy(self._object).copy()) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + # remove _exit_waves attr from self for GD + if not use_projection_scheme: + self._exit_waves = None + + self.clear_device_mem(self._device, self._clear_fft_cache) + + return self diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index fc1b59a07..a5a541795 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -3,19 +3,25 @@ import matplotlib.pyplot as plt import numpy as np +from scipy.fft import dctn, idctn +from scipy.ndimage import gaussian_filter, uniform_filter1d, zoom from scipy.optimize import curve_fit try: import cupy as cp - from cupyx.scipy.fft import rfft + from cupyx.scipy.ndimage import zoom as zoom_cp + + get_array_module = cp.get_array_module except (ImportError, ModuleNotFoundError): cp = None - from scipy.fft import dstn, idstn + + def get_array_module(*args): + return np + from py4DSTEM.process.utils import get_CoM from py4DSTEM.process.utils.cross_correlate import align_and_shift_images from py4DSTEM.process.utils.utils import electron_wavelength_angstrom -from scipy.ndimage import gaussian_filter, uniform_filter1d from skimage.restoration import unwrap_phase # fmt: off @@ -197,9 +203,7 @@ def evaluate_gaussian_envelope( self, alpha: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: xp = self._xp - return xp.exp( - -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2 - ) + return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2) def evaluate_spatial_envelope( self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] @@ -404,16 +408,14 @@ def get_scattering_angles(self): def 
get_spatial_frequencies(self): xp = self._xp - kx, ky = spatial_frequencies(self._gpts, self._sampling) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) + kx, ky = spatial_frequencies(self._gpts, self._sampling, xp) return kx, ky def polar_coordinates(self, x, y): """Calculate a polar grid for a given Cartesian grid.""" xp = self._xp - alpha = xp.sqrt(x.reshape((-1, 1)) ** 2 + y.reshape((1, -1)) ** 2) - phi = xp.arctan2(x.reshape((-1, 1)), y.reshape((1, -1))) + alpha = xp.sqrt(x[:, None] ** 2 + y[None, :] ** 2) + phi = xp.arctan2(y[None, :], x[:, None]) return alpha, phi def build(self): @@ -440,7 +442,7 @@ def visualize(self, **kwargs): return self -def spatial_frequencies(gpts: Tuple[int, int], sampling: Tuple[float, float]): +def spatial_frequencies(gpts: Tuple[int, int], sampling: Tuple[float, float], xp=np): """ Calculate spatial frequencies of a grid. @@ -457,7 +459,7 @@ def spatial_frequencies(gpts: Tuple[int, int], sampling: Tuple[float, float]): """ return tuple( - np.fft.fftfreq(n, d).astype(np.float32) for n, d in zip(gpts, sampling) + xp.fft.fftfreq(n, d).astype(xp.float32) for n, d in zip(gpts, sampling) ) @@ -489,16 +491,14 @@ def fourier_translation_operator( if len(positions_shape) == 1: positions = positions[None] - kx, ky = spatial_frequencies(shape, (1.0, 1.0)) - kx = kx.reshape((1, -1, 1)) - ky = ky.reshape((1, 1, -1)) - kx = xp.asarray(kx, dtype=xp.float32) - ky = xp.asarray(ky, dtype=xp.float32) + kx, ky = spatial_frequencies(shape, (1.0, 1.0), xp=xp) positions = xp.asarray(positions, dtype=xp.float32) - x = positions[:, 0].reshape((-1,) + (1, 1)) - y = positions[:, 1].reshape((-1,) + (1, 1)) + x = positions[:, 0].ravel()[:, None, None] + y = positions[:, 1].ravel()[:, None, None] - result = xp.exp(-2.0j * np.pi * kx * x) * xp.exp(-2.0j * np.pi * ky * y) + result = xp.exp(-2.0j * np.pi * kx[None, :, None] * x) * xp.exp( + -2.0j * np.pi * ky[None, None, :] * y + ) if len(positions_shape) == 1: return result[0] @@ 
-1150,114 +1150,79 @@ def fourier_rotate_real_volume(array, angle, axes=(0, 1), xp=np): return output_arr +def array_slice(axis, ndim, start, end, step=1): + """Returns array slice along dynamic axis""" + return (slice(None),) * (axis % ndim) + (slice(start, end, step),) + + ### Divergence Projection Functions -def compute_divergence(vector_field, spacings, xp=np): +def periodic_centered_difference(array, spacing, axis, xp=np): + """Computes second-order centered difference with periodic BCs""" + return (xp.roll(array, -1, axis=axis) - xp.roll(array, 1, axis=axis)) / ( + 2 * spacing + ) + + +def compute_divergence_periodic(vector_field, spacings, xp=np): """Computes divergence of vector_field""" num_dims = len(spacings) div = xp.zeros_like(vector_field[0]) for i in range(num_dims): - div += xp.gradient(vector_field[i], spacings[i], axis=i) + div += periodic_centered_difference(vector_field[i], spacings[i], axis=i, xp=xp) return div -def compute_gradient(scalar_field, spacings, xp=np): +def compute_gradient_periodic(scalar_field, spacings, xp=np): """Computes gradient of scalar_field""" num_dims = len(spacings) grad = xp.zeros((num_dims,) + scalar_field.shape) for i in range(num_dims): - grad[i] = xp.gradient(scalar_field, spacings[i], axis=i) + grad[i] = periodic_centered_difference(scalar_field, spacings[i], axis=i, xp=xp) return grad -def array_slice(axis, ndim, start, end, step=1): - """Returns array slice along dynamic axis""" - return (slice(None),) * (axis % ndim) + (slice(start, end, step),) - - -def make_array_rfft_compatible(array_nd, axis=0, xp=np): - """Expand array to be rfft compatible""" - array_shape = np.array(array_nd.shape) - d = array_nd.ndim - n = array_shape[axis] - array_shape[axis] = (n + 1) * 2 +def preconditioned_laplacian_periodic_3D(shape, xp=np): + """FFT eigenvalues""" + n, m, p = shape + i, j, k = xp.ogrid[0:n, 0:m, 0:p] - dtype = array_nd.dtype - padded_array = xp.zeros(array_shape, dtype=dtype) - - padded_array[array_slice(axis, d, 
1, n + 1)] = -array_nd - padded_array[array_slice(axis, d, None, -n - 1, -1)] = array_nd - - return padded_array - - -def dst_I(array_nd, xp=np): - """1D rfft-based DST-I""" - d = array_nd.ndim - for axis in range(d): - crop_slice = array_slice(axis, d, 1, -1) - array_nd = rfft( - make_array_rfft_compatible(array_nd, axis=axis, xp=xp), axis=axis - )[crop_slice].imag - - return array_nd - - -def idst_I(array_nd, xp=np): - """1D rfft-based iDST-I""" - scaling = np.prod((np.array(array_nd.shape) + 1) * 2) - return dst_I(array_nd, xp=xp) / scaling - - -def preconditioned_laplacian(num_exterior, spacing=1, xp=np): - """DST-I eigenvalues""" - n = num_exterior - 1 - evals_1d = 2 - 2 * xp.cos(np.pi * xp.arange(1, num_exterior) / num_exterior) - - op = ( - xp.repeat(evals_1d, n**2) - + xp.tile(evals_1d, n**2) - + xp.tile(xp.repeat(evals_1d, n), n) + op = 6 - 2 * xp.cos(2 * np.pi * i / n) * xp.cos(2 * np.pi * j / m) * xp.cos( + 2 * np.pi * k / p ) + op[0, 0, 0] = 1 # gauge invariance + return -op - return -op / spacing**2 +def preconditioned_poisson_solver_periodic_3D(rhs, gauge=None, xp=np): + """FFT based poisson solver""" + op = preconditioned_laplacian_periodic_3D(rhs.shape, xp=xp) -def preconditioned_poisson_solver(rhs_interior, spacing=1, xp=np): - """DST-I based poisson solver""" - nx, ny, nz = rhs_interior.shape - if nx != ny or nx != nz: - raise ValueError() - - op = preconditioned_laplacian(nx + 1, spacing=spacing, xp=xp) - if xp is np: - dst_rhs = dstn(rhs_interior, type=1).ravel() - dst_u = (dst_rhs / op).reshape((nx, ny, nz)) - sol = idstn(dst_u, type=1) - else: - dst_rhs = dst_I(rhs_interior, xp=xp).ravel() - dst_u = (dst_rhs / op).reshape((nx, ny, nz)) - sol = idst_I(dst_u, xp=xp) + if gauge is None: + gauge = xp.mean(rhs) + fft_rhs = xp.fft.fftn(rhs) + fft_rhs[0, 0, 0] = gauge # gauge invariance + sol = xp.fft.ifftn(fft_rhs / op).real return sol -def project_vector_field_divergence(vector_field, spacings=(1, 1, 1), xp=np): +def 
project_vector_field_divergence_periodic_3D(vector_field, xp=np): """ Returns solenoidal part of vector field using projection: f - \\grad{p} s.t. \\laplacian{p} = \\div{f} """ - - div_v = compute_divergence(vector_field, spacings, xp=xp) - p = preconditioned_poisson_solver(div_v, spacings[0], xp=xp) - grad_p = compute_gradient(p, spacings, xp=xp) + spacings = (1, 1, 1) + div_v = compute_divergence_periodic(vector_field, spacings, xp=xp) + p = preconditioned_poisson_solver_periodic_3D(div_v, xp=xp) + grad_p = compute_gradient_periodic(p, spacings, xp=xp) return vector_field - grad_p @@ -1610,25 +1575,226 @@ def aberrations_basis_function( return aberrations_basis, aberrations_mn +def interleave_ndarray_symmetrically(array_nd, axis, xp=np): + """[a,b,c,d,e,f] -> [a,c,e,f,d,b]""" + array_shape = np.array(array_nd.shape) + d = array_nd.ndim + n = array_shape[axis] + + array = xp.empty_like(array_nd) + array[array_slice(axis, d, None, (n - 1) // 2 + 1)] = array_nd[ + array_slice(axis, d, None, None, 2) + ] + + if n % 2: # odd + array[array_slice(axis, d, (n - 1) // 2 + 1, None)] = array_nd[ + array_slice(axis, d, -2, None, -2) + ] + else: # even + array[array_slice(axis, d, (n - 1) // 2 + 1, None)] = array_nd[ + array_slice(axis, d, None, None, -2) + ] + + return array + + +def return_exp_factors(size, ndim, axis, xp=np): + none_axes = [None] * ndim + none_axes[axis] = slice(None) + exp_factors = 2 * xp.exp(-1j * np.pi * xp.arange(size) / (2 * size)) + return exp_factors[tuple(none_axes)] + + +def dct_II_using_FFT_base(array_nd, xp=np): + """FFT-based DCT-II""" + d = array_nd.ndim + + for axis in range(d): + n = array_nd.shape[axis] + interleaved_array = interleave_ndarray_symmetrically(array_nd, axis=axis, xp=xp) + exp_factors = return_exp_factors(n, d, axis, xp) + interleaved_array = xp.fft.fft(interleaved_array, axis=axis) + interleaved_array *= exp_factors + array_nd = interleaved_array.real + + return array_nd + + +def dct_II_using_FFT(array_nd, xp=np): + if 
xp.iscomplexobj(array_nd): + real = dct_II_using_FFT_base(array_nd.real, xp=xp) + imag = dct_II_using_FFT_base(array_nd.imag, xp=xp) + return real + 1j * imag + else: + return dct_II_using_FFT_base(array_nd, xp=xp) + + +def interleave_ndarray_symmetrically_inverse(array_nd, axis, xp=np): + """[a,c,e,f,d,b] -> [a,b,c,d,e,f]""" + array_shape = np.array(array_nd.shape) + d = array_nd.ndim + n = array_shape[axis] + + array = xp.empty_like(array_nd) + array[array_slice(axis, d, None, None, 2)] = array_nd[ + array_slice(axis, d, None, (n - 1) // 2 + 1) + ] + + if n % 2: # odd + array[array_slice(axis, d, -2, None, -2)] = array_nd[ + array_slice(axis, d, (n - 1) // 2 + 1, None) + ] + else: # even + array[array_slice(axis, d, None, None, -2)] = array_nd[ + array_slice(axis, d, (n - 1) // 2 + 1, None) + ] + + return array + + +def return_exp_factors_inverse(size, ndim, axis, xp=np): + none_axes = [None] * ndim + none_axes[axis] = slice(None) + exp_factors = xp.exp(1j * np.pi * xp.arange(size) / (2 * size)) / 2 + return exp_factors[tuple(none_axes)] + + +def idct_II_using_FFT_base(array_nd, xp=np): + """FFT-based IDCT-II""" + d = array_nd.ndim + + for axis in range(d): + n = array_nd.shape[axis] + reversed_array = xp.roll( + array_nd[array_slice(axis, d, None, None, -1)], 1, axis=axis + ) # C(N-k) + reversed_array[array_slice(axis, d, 0, 1)] = 0 # set C(N) = 0 + + interleaved_array = array_nd - 1j * reversed_array + exp_factors = return_exp_factors_inverse(n, d, axis, xp) + interleaved_array *= exp_factors + + array_nd = xp.fft.ifft(interleaved_array, axis=axis).real + array_nd = interleave_ndarray_symmetrically_inverse(array_nd, axis=axis, xp=xp) + + return array_nd + + +def idct_II_using_FFT(array_nd, xp=np): + """FFT-based IDCT-II""" + if xp.iscomplexobj(array_nd): + real = idct_II_using_FFT_base(array_nd.real, xp=xp) + imag = idct_II_using_FFT_base(array_nd.imag, xp=xp) + return real + 1j * imag + else: + return idct_II_using_FFT_base(array_nd, xp=xp) + + +def 
preconditioned_laplacian_neumann_2D(shape, xp=np): + """DCT eigenvalues""" + n, m = shape + i, j = xp.ogrid[0:n, 0:m] + + op = 4 - 2 * xp.cos(np.pi * i / n) - 2 * xp.cos(np.pi * j / m) + op[0, 0] = 1 # gauge invariance + return -op + + +def preconditioned_poisson_solver_neumann_2D(rhs, gauge=None, xp=np): + """DCT based poisson solver""" + op = preconditioned_laplacian_neumann_2D(rhs.shape, xp=xp) + + if gauge is None: + gauge = xp.mean(rhs) + + if xp is np: + fft_rhs = dctn(rhs, type=2) + fft_rhs[0, 0] = gauge # gauge invariance + sol = idctn(fft_rhs / op, type=2).real + else: + fft_rhs = dct_II_using_FFT(rhs, xp) + fft_rhs[0, 0] = gauge # gauge invariance + sol = idct_II_using_FFT(fft_rhs / op, xp) + + return sol + + +def unwrap_phase_2d(array, weights=None, gauge=None, corner_centered=True, xp=np): + """Weigted phase unwrapping using DCT-based poisson solver""" + + if np.iscomplexobj(array): + raise ValueError() + + if corner_centered: + array = xp.fft.fftshift(array) + if weights is not None: + weights = xp.fft.fftshift(weights) + + dx = xp.mod(xp.diff(array, axis=0) + np.pi, 2 * np.pi) - np.pi + dy = xp.mod(xp.diff(array, axis=1) + np.pi, 2 * np.pi) - np.pi + + if weights is not None: + # normalize weights + weights -= weights.min() + weights /= weights.max() + + ww = weights**2 + dx *= xp.minimum(ww[:-1, :], ww[1:, :]) + dy *= xp.minimum(ww[:, :-1], ww[:, 1:]) + + rho = xp.diff(dx, axis=0, prepend=0, append=0) + rho += xp.diff(dy, axis=1, prepend=0, append=0) + + unwrapped_array = preconditioned_poisson_solver_neumann_2D(rho, gauge=gauge, xp=xp) + unwrapped_array -= unwrapped_array.min() + + if corner_centered: + unwrapped_array = xp.fft.ifftshift(unwrapped_array) + + return unwrapped_array + + +def unwrap_phase_2d_skimage(array, corner_centered=True, xp=np): + if xp is np: + array = array.astype(np.float64) + unwrapped_array = unwrap_phase(array, wrap_around=corner_centered).astype( + xp.float32 + ) + else: + array = xp.asnumpy(array).astype(np.float64) + 
unwrapped_array = unwrap_phase(array, wrap_around=corner_centered) + unwrapped_array = xp.asarray(unwrapped_array).astype(xp.float32) + + return unwrapped_array + + def fit_aberration_surface( complex_probe, probe_sampling, energy, max_angular_order, max_radial_order, + use_scikit_image, xp=np, ): """ """ probe_amp = xp.abs(complex_probe) probe_angle = -xp.angle(complex_probe) - if xp is np: - probe_angle = probe_angle.astype(np.float64) - unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True).astype(xp.float32) + if use_scikit_image: + unwrapped_angle = unwrap_phase_2d_skimage( + probe_angle, + corner_centered=True, + xp=xp, + ) + else: - probe_angle = xp.asnumpy(probe_angle).astype(np.float64) - unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True) - unwrapped_angle = xp.asarray(unwrapped_angle).astype(xp.float32) + unwrapped_angle = unwrap_phase_2d( + probe_angle, + weights=probe_amp, + corner_centered=True, + xp=xp, + ) raveled_basis, _ = aberrations_basis_function( complex_probe.shape, @@ -1646,6 +1812,8 @@ def fit_aberration_surface( coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0] fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape) + angle_offset = fitted_angle[0, 0] - probe_angle[0, 0] + fitted_angle -= angle_offset return fitted_angle, coeff @@ -1674,3 +1842,685 @@ def rotate_point(origin, point, angle): qx = ox + np.cos(angle) * (px - ox) - np.sin(angle) * (py - oy) qy = oy + np.sin(angle) * (px - ox) + np.cos(angle) * (py - oy) return qx, qy + + +def bilinearly_interpolate_array( + image, + xa, + ya, + xp=np, +): + """ + Bilinear sampling of intensities from an image array and pixel positions. 
+ + Parameters + ---------- + image: np.ndarray + Image array to sample from + xa: np.ndarray + Vertical interpolation sampling positions of image array in pixels + ya: np.ndarray + Horizontal interpolation sampling positions of image array in pixels + + Returns + ------- + intensities: np.ndarray + Bilinearly-sampled intensities of array at (xa,ya) positions + + """ + + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + all_inds = [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + + all_weights = [ + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ] + + raveled_image = image.ravel() + intensities = xp.zeros(xa.shape, dtype=xp.float32) + # filter_weights = xp.zeros(xa.shape, dtype=xp.float32) + + for inds, weights in zip(all_inds, all_weights): + intensities += ( + raveled_image[ + xp.ravel_multi_index( + inds, + image.shape, + mode=["wrap", "wrap"], + ) + ] + * weights + ) + # filter_weights += weights + + return intensities # / filter_weights # unnecessary, sums up to unity + + +def lanczos_interpolate_array( + image, + xa, + ya, + alpha, + xp=np, +): + """ + Lanczos sampling of intensities from an image array and pixel positions. 
+ + Parameters + ---------- + image: np.ndarray + Image array to sample from + xa: np.ndarray + Vertical Interpolation sampling positions of image array in pixels + ya: np.ndarray + Horizontal interpolation sampling positions of image array in pixels + alpha: int + Lanczos kernel order + + Returns + ------- + intensities: np.ndarray + Lanczos-sampled intensities of array at (xa,ya) positions + + """ + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + all_inds = [] + all_weights = [] + + for i in range(-alpha + 1, alpha + 1): + for j in range(-alpha + 1, alpha + 1): + all_inds.append([xF + i, yF + j]) + all_weights.append( + (xp.sinc(i - dx) * xp.sinc((i - dx) / alpha)) + * (xp.sinc(j - dy) * xp.sinc((i - dy) / alpha)) + ) + + raveled_image = image.ravel() + intensities = xp.zeros(xa.shape, dtype=xp.float32) + filter_weights = xp.zeros(xa.shape, dtype=xp.float32) + + for inds, weights in zip(all_inds, all_weights): + intensities += ( + raveled_image[ + xp.ravel_multi_index( + inds, + image.shape, + mode=["wrap", "wrap"], + ) + ] + * weights + ) + filter_weights += weights + + return intensities / filter_weights + + +def pixel_rolling_kernel_density_estimate( + stack, + shifts, + upsampling_factor, + kde_sigma, + lowpass_filter=False, + xp=np, + gaussian_filter=gaussian_filter, +): + """ + kernel density estimate from a set coordinates (xa,ya) and intensity weights. 
+ + Parameters + ---------- + stack: np.ndarray + Unshifted image stack, shape (N,P,S) + shifts: np.ndarray + Shifts for each image in stack, shape: (N,2) + upsampling_factor: int + Upsampling factor + kde_sigma: float + KDE gaussian kernel bandwidth in upsampled pixels + lowpass_filter: bool, optional + If True, the resulting KDE upsampled image is lowpass-filtered using a sinc-function + + Returns + ------- + pix_output: np.ndarray + Upsampled intensity image + """ + upsampled_shape = np.array(stack.shape) + upsampled_shape *= (1, upsampling_factor, upsampling_factor) + + upsampled_shifts = shifts * upsampling_factor + upsampled_shifts_int = xp.modf(upsampled_shifts)[-1].astype("int") + + upsampled_stack = xp.zeros(upsampled_shape, dtype=xp.float32) + upsampled_stack[..., ::upsampling_factor, ::upsampling_factor] = stack + pix_output = xp.zeros(upsampled_shape[-2:], dtype=xp.float32) + + for BF_index in range(upsampled_stack.shape[0]): + shift = upsampled_shifts_int[BF_index] + pix_output += xp.roll(upsampled_stack[BF_index], shift, axis=(0, 1)) + + upsampled_stack[..., ::upsampling_factor, ::upsampling_factor] = 1 + pix_count = xp.zeros(upsampled_shape[-2:], dtype=xp.float32) + + # sequential looping for memory reasons + for BF_index in range(upsampled_stack.shape[0]): + shift = upsampled_shifts_int[BF_index] + pix_count += xp.roll(upsampled_stack[BF_index], shift, axis=(0, 1)) + + # kernel density estimate + pix_count = gaussian_filter(pix_count, kde_sigma) + pix_output = gaussian_filter(pix_output, kde_sigma) + + sub = pix_count > 1e-3 + pix_output[sub] /= pix_count[sub] + pix_output[np.logical_not(sub)] = 1 + + if lowpass_filter: + pix_fft = xp.fft.fft2(pix_output) + pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[0], d=1.0))[:, None] + pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[1], d=1.0))[None] + pix_output = xp.real(xp.fft.ifft2(pix_fft)) + + return pix_output + + +def bilinear_kernel_density_estimate( + xa, + ya, + intensities, + output_shape, 
+ kde_sigma, + lowpass_filter=False, + xp=np, + gaussian_filter=gaussian_filter, +): + """ + kernel density estimate from a set coordinates (xa,ya) and intensity weights. + + Parameters + ---------- + xa: np.ndarray + Vertical positions of intensity array in pixels + ya: np.ndarray + Horizontal positions of intensity array in pixels + intensities: np.ndarray + Intensity array weights + output_shape: (int,int) + Upsampled intensities shape + kde_sigma: float + KDE gaussian kernel bandwidth in upsampled pixels + lowpass_filter: bool, optional + If True, the resulting KDE upsampled image is lowpass-filtered using a sinc-function + + Returns + ------- + pix_output: np.ndarray + Upsampled intensity image + """ + + # interpolation + xF = xp.floor(xa.ravel()).astype("int") + yF = xp.floor(ya.ravel()).astype("int") + dx = xa.ravel() - xF + dy = ya.ravel() - yF + + all_inds = [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + + all_weights = [ + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ] + + raveled_intensities = intensities.ravel() + pix_count = xp.zeros(np.prod(output_shape), dtype=xp.float32) + pix_output = xp.zeros(np.prod(output_shape), dtype=xp.float32) + + for inds, weights in zip(all_inds, all_weights): + inds_1D = xp.ravel_multi_index( + inds, + output_shape, + mode=["wrap", "wrap"], + ) + + pix_count += xp.bincount( + inds_1D, + weights=weights, + minlength=np.prod(output_shape), + ) + pix_output += xp.bincount( + inds_1D, + weights=weights * raveled_intensities, + minlength=np.prod(output_shape), + ) + + # reshape 1D arrays to 2D + pix_count = xp.reshape( + pix_count, + output_shape, + ) + pix_output = xp.reshape( + pix_output, + output_shape, + ) + + # kernel density estimate + pix_count = gaussian_filter(pix_count, kde_sigma) + pix_output = gaussian_filter(pix_output, kde_sigma) + sub = pix_count > 1e-3 + pix_output[sub] /= pix_count[sub] + pix_output[np.logical_not(sub)] = 1 + + if lowpass_filter: + 
pix_fft = xp.fft.fft2(pix_output) + pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[0], d=1.0))[:, None] + pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[1], d=1.0))[None] + pix_output = xp.real(xp.fft.ifft2(pix_fft)) + + return pix_output + + +def lanczos_kernel_density_estimate( + xa, + ya, + intensities, + output_shape, + kde_sigma, + alpha, + lowpass_filter=False, + xp=np, + gaussian_filter=gaussian_filter, +): + """ + kernel density estimate from a set coordinates (xa,ya) and intensity weights. + + Parameters + ---------- + xa: np.ndarray + Vertical positions of intensity array in pixels + ya: np.ndarray + Horizontal positions of intensity array in pixels + intensities: np.ndarray + Intensity array weights + output_shape: (int,int) + Upsampled intensities shape + kde_sigma: float + KDE gaussian kernel bandwidth in upsampled pixels + alpha: int + Lanczos kernel order + lowpass_filter: bool, optional + If True, the resulting KDE upsampled image is lowpass-filtered using a sinc-function + + Returns + ------- + pix_output: np.ndarray + Upsampled intensity image + """ + + # interpolation + xF = xp.floor(xa.ravel()).astype("int") + yF = xp.floor(ya.ravel()).astype("int") + dx = xa.ravel() - xF + dy = ya.ravel() - yF + + all_inds = [] + all_weights = [] + + for i in range(-alpha + 1, alpha + 1): + for j in range(-alpha + 1, alpha + 1): + all_inds.append([xF + i, yF + j]) + all_weights.append( + (xp.sinc(i - dx) * xp.sinc((i - dx) / alpha)) + * (xp.sinc(j - dy) * xp.sinc((i - dy) / alpha)) + ) + + raveled_intensities = intensities.ravel() + pix_count = xp.zeros(np.prod(output_shape), dtype=xp.float32) + pix_output = xp.zeros(np.prod(output_shape), dtype=xp.float32) + + for inds, weights in zip(all_inds, all_weights): + inds_1D = xp.ravel_multi_index( + inds, + output_shape, + mode=["wrap", "wrap"], + ) + + pix_count += xp.bincount( + inds_1D, + weights=weights, + minlength=np.prod(output_shape), + ) + pix_output += xp.bincount( + inds_1D, + weights=weights * 
raveled_intensities, + minlength=np.prod(output_shape), + ) + + # reshape 1D arrays to 2D + pix_count = xp.reshape( + pix_count, + output_shape, + ) + pix_output = xp.reshape( + pix_output, + output_shape, + ) + + # kernel density estimate + pix_count = gaussian_filter(pix_count, kde_sigma) + pix_output = gaussian_filter(pix_output, kde_sigma) + sub = pix_count > 1e-3 + pix_output[sub] /= pix_count[sub] + pix_output[np.logical_not(sub)] = 1 + + if lowpass_filter: + pix_fft = xp.fft.fft2(pix_output) + pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[0], d=1.0))[:, None] + pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[1], d=1.0))[None] + pix_output = xp.real(xp.fft.ifft2(pix_fft)) + + return pix_output + + +def bilinear_resample( + array, + scale=None, + output_size=None, + mode="grid-wrap", + grid_mode=True, + vectorized=True, + conserve_array_sums=False, + xp=np, +): + """ + Resize an array along its final two axes. + Note, this is vectorized by default and thus very memory-intensive. + + The scaling of the array can be specified by passing either `scale`, which sets + the scaling factor along both axes to be scaled; or by passing `output_size`, + which specifies the final dimensions of the scaled axes. 
+ + Parameters + ---------- + array: np.ndarray + Input array to be resampled + scale: float + Scalar value giving the scaling factor for all dimensions + output_size: (int,int) + Tuple of two values giving the output size for the final two axes + xp: Callable + Array computing module + + Returns + ------- + resampled_array: np.ndarray + Resampled array + """ + + array_size = np.array(array.shape) + input_size = array_size[-2:].copy() + + if scale is not None: + scale = np.array(scale) + if scale.size == 1: + scale = np.tile(scale, 2) + + output_size = (input_size * scale).astype("int") + else: + if output_size is None: + raise ValueError("One of `scale` or `output_size` must be provided.") + output_size = np.array(output_size) + if output_size.size != 2: + raise ValueError("`output_size` must contain exactly two values.") + output_size = np.array(output_size) + + scale_output = tuple(output_size / input_size) + scale_output = (1,) * (array_size.size - input_size.size) + scale_output + + if xp is np: + zoom_xp = zoom + else: + zoom_xp = zoom_cp + + if vectorized: + array = zoom_xp(array, scale_output, order=1, mode=mode, grid_mode=grid_mode) + else: + flat_array = array.reshape((-1,) + tuple(input_size)) + out_array = xp.zeros( + (flat_array.shape[0],) + tuple(output_size), flat_array.dtype + ) + for idx in range(flat_array.shape[0]): + out_array[idx] = zoom_xp( + flat_array[idx], + scale_output[-2:], + order=1, + mode=mode, + grid_mode=grid_mode, + ) + + array = out_array.reshape(tuple(array_size[:-2]) + tuple(output_size)) + + if conserve_array_sums: + array = array / np.array(scale_output).prod() + + return array + + +def vectorized_fourier_resample( + array, + scale=None, + output_size=None, + conserve_array_sums=False, + xp=np, +): + """ + Resize a 2D array along any dimension, using Fourier interpolation. + For 4D input arrays, only the final two axes can be resized. + Note, this is vectorized and thus very memory-intensive. 
+ + The scaling of the array can be specified by passing either `scale`, which sets + the scaling factor along both axes to be scaled; or by passing `output_size`, + which specifies the final dimensions of the scaled axes (and allows for different + scaling along the x,y or kx,ky axes.) + + Parameters + ---------- + array: np.ndarray + Input 2D/4D array to be resampled + scale: float + Scalar value giving the scaling factor for all dimensions + output_size: (int,int) + Tuple of two values giving eith the (x,y) or (kx,ky) output size for 2D and 4D respectively. + xp: Callable + Array computing module + + Returns + ------- + resampled_array: np.ndarray + Resampled 2D/4D array + """ + + array_size = np.array(array.shape) + input_size = array_size[-2:].copy() + + if scale is not None: + scale = np.array(scale) + if scale.size == 1: + scale = np.tile(scale, 2) + + output_size = (input_size * scale).astype("int") + else: + if output_size is None: + raise ValueError("One of `scale` or `output_size` must be provided.") + output_size = np.array(output_size) + if output_size.size != 2: + raise ValueError("`output_size` must contain exactly two values.") + output_size = np.array(output_size) + + scale_output = np.prod(output_size) / np.prod(input_size) + + # x slices + if output_size[0] > input_size[0]: + # x dimension increases + x0 = (input_size[0] + 1) // 2 + x1 = input_size[0] // 2 + + x_ul_out = slice(0, x0) + x_ul_in_ = slice(0, x0) + + x_ll_out = slice(0 - x1 + output_size[0], output_size[0]) + x_ll_in_ = slice(0 - x1 + input_size[0], input_size[0]) + + x_ur_out = slice(0, x0) + x_ur_in_ = slice(0, x0) + + x_lr_out = slice(0 - x1 + output_size[0], output_size[0]) + x_lr_in_ = slice(0 - x1 + input_size[0], input_size[0]) + + elif output_size[0] < input_size[0]: + # x dimension decreases + x0 = (output_size[0] + 1) // 2 + x1 = output_size[0] // 2 + + x_ul_out = slice(0, x0) + x_ul_in_ = slice(0, x0) + + x_ll_out = slice(0 - x1 + output_size[0], output_size[0]) + x_ll_in_ 
= slice(0 - x1 + input_size[0], input_size[0]) + + x_ur_out = slice(0, x0) + x_ur_in_ = slice(0, x0) + + x_lr_out = slice(0 - x1 + output_size[0], output_size[0]) + x_lr_in_ = slice(0 - x1 + input_size[0], input_size[0]) + + else: + # x dimension does not change + x_ul_out = slice(None) + x_ul_in_ = slice(None) + + x_ll_out = slice(None) + x_ll_in_ = slice(None) + + x_ur_out = slice(None) + x_ur_in_ = slice(None) + + x_lr_out = slice(None) + x_lr_in_ = slice(None) + + # y slices + if output_size[1] > input_size[1]: + # y increases + y0 = (input_size[1] + 1) // 2 + y1 = input_size[1] // 2 + + y_ul_out = slice(0, y0) + y_ul_in_ = slice(0, y0) + + y_ll_out = slice(0, y0) + y_ll_in_ = slice(0, y0) + + y_ur_out = slice(0 - y1 + output_size[1], output_size[1]) + y_ur_in_ = slice(0 - y1 + input_size[1], input_size[1]) + + y_lr_out = slice(0 - y1 + output_size[1], output_size[1]) + y_lr_in_ = slice(0 - y1 + input_size[1], input_size[1]) + + elif output_size[1] < input_size[1]: + # y decreases + y0 = (output_size[1] + 1) // 2 + y1 = output_size[1] // 2 + + y_ul_out = slice(0, y0) + y_ul_in_ = slice(0, y0) + + y_ll_out = slice(0, y0) + y_ll_in_ = slice(0, y0) + + y_ur_out = slice(0 - y1 + output_size[1], output_size[1]) + y_ur_in_ = slice(0 - y1 + input_size[1], input_size[1]) + + y_lr_out = slice(0 - y1 + output_size[1], output_size[1]) + y_lr_in_ = slice(0 - y1 + input_size[1], input_size[1]) + + else: + # y dimension does not change + y_ul_out = slice(None) + y_ul_in_ = slice(None) + + y_ll_out = slice(None) + y_ll_in_ = slice(None) + + y_ur_out = slice(None) + y_ur_in_ = slice(None) + + y_lr_out = slice(None) + y_lr_in_ = slice(None) + + # image array + array_size[-2:] = output_size + array_resize = xp.zeros(array_size, dtype=xp.complex64) + array_fft = xp.fft.fft2(array) + + # copy each quadrant into the resize array + array_resize[..., x_ul_out, y_ul_out] = array_fft[..., x_ul_in_, y_ul_in_] + array_resize[..., x_ll_out, y_ll_out] = array_fft[..., x_ll_in_, y_ll_in_] + 
array_resize[..., x_ur_out, y_ur_out] = array_fft[..., x_ur_in_, y_ur_in_] + array_resize[..., x_lr_out, y_lr_out] = array_fft[..., x_lr_in_, y_lr_in_] + + # Back to real space + array_resize = xp.real(xp.fft.ifft2(array_resize)).astype(xp.float32) + + # Normalization + if not conserve_array_sums: + array_resize = array_resize * scale_output + + return array_resize + + +def partition_list(lst, size): + """Partitions lst into chunks of size. Returns a generator.""" + for i in range(0, len(lst), size): + yield lst[i : i + size] + + +def copy_to_device(array, device="cpu"): + """Copies array to device. Default allows one to use this as asnumpy()""" + xp = get_array_module(array) + + if xp is np: + if device == "cpu": + return np.asarray(array) + elif device == "gpu": + return cp.asarray(array) + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + else: + if device == "cpu": + return cp.asnumpy(array) + elif device == "gpu": + return cp.asarray(array) + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 7df00a235..c62d1258d 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -584,8 +584,12 @@ def plot_pdf( def calculate_FEM_local( self, - figsize=(8, 6), + use_median=False, + plot_normalized_variance=True, + figsize=(8, 4), + return_values=False, returnfig=False, + progress_bar=True, ): """ Calculate fluctuation electron microscopy (FEM) statistics, including radial mean, @@ -596,18 +600,293 @@ def calculate_FEM_local( -------- self: PolarDatacube Polar datacube used for measuring FEM properties. + use_median: Bool + Use median instead of mean for statistics. 
Returns -------- - radial_avg: np.array - Average radial intensity - radial_var: np.array - Variance in the radial dimension + local_radial_mean: np.array + Average radial intensity of each probe position + local_radial_var: np.array + Variance in the radial dimension of each probe position + """ + # init radial data arrays + self.local_radial_mean = np.zeros( + ( + self._datacube.shape[0], + self._datacube.shape[1], + self.polar_shape[1], + ) + ) + self.local_radial_var = np.zeros( + ( + self._datacube.shape[0], + self._datacube.shape[1], + self.polar_shape[1], + ) + ) + + # Compute the radial mean and standard deviation for each probe position + for rx, ry in tqdmnd( + self._datacube.shape[0], + self._datacube.shape[1], + desc="Radial statistics", + unit=" probe positions", + disable=not progress_bar, + ): + im = self.data[rx, ry] + + if use_median: + im_mean = np.ma.median(im, axis=0) + im_var = np.ma.median((im - im_mean) ** 2, axis=0) + else: + im_mean = np.ma.mean(im, axis=0) + im_var = np.ma.mean((im - im_mean) ** 2, axis=0) + + self.local_radial_mean[rx, ry] = im_mean + self.local_radial_var[rx, ry] = im_var + + if plot_normalized_variance: + fig, ax = plt.subplots(figsize=figsize) + + sig = self.local_radial_var / self.local_radial_mean**2 + if use_median: + sig_plot = np.median(sig, axis=(0, 1)) + else: + sig_plot = np.mean(sig, axis=(0, 1)) + + ax.plot( + self.qq, + sig_plot, + ) + ax.set_xlabel( + "Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")" + ) + ax.set_ylabel("Normalized Variance") + ax.set_xlim((self.qq[0], self.qq[-1])) + + if return_values: + if returnfig: + return self.local_radial_mean, self.local_radial_var, fig, ax + else: + return self.local_radial_mean, self.local_radial_var + else: + if returnfig: + return fig, ax + + +def calculate_annular_symmetry( + self, + max_symmetry=12, + mask_realspace=None, + plot_result=False, + figsize=(8, 4), + return_symmetry_map=False, + progress_bar=True, +): """ + This function 
calculates radial symmetry of diffraction patterns, typically applied + to amorphous scattering, but it can also be used for crystalline Bragg diffraction. - pass + Parameters + -------- + self: PolarDatacube + Polar transformed datacube + max_symmetry: int + Symmetry orders will be computed from 1 to max_symmetry for n-fold symmetry orders. + mask_realspace: np.array + Boolean mask, symmetries will only be computed at probe positions where mask is True. + plot_result: bool + Plot the resulting array + figsize: (float, float) + Size of the plot. + return_symmetry_map: bool + Set to true to return the symmetry array. + progress_bar: bool + Show progress bar during calculation. + + Returns + -------- + annular_symmetry: np.array + Array with annular symmetry magnitudes, with shape [max_symmetry, num_radial_bins] + + """ + + # Initialize outputs + self.annular_symmetry_max = max_symmetry + self.annular_symmetry = np.zeros( + ( + self.data_raw.shape[0], + self.data_raw.shape[1], + max_symmetry, + self.polar_shape[1], + ) + ) + + # Loop over all probe positions + for rx, ry in tqdmnd( + self._datacube.shape[0], + self._datacube.shape[1], + desc="Annular symmetry", + unit=" probe positions", + disable=not progress_bar, + ): + # Get polar transformed image + im = self.transform( + self.data_raw.data[rx, ry], + ) + polar_im = np.ma.getdata(im) + polar_mask = np.ma.getmask(im) + polar_im[polar_mask] = 0 + polar_mask = np.logical_not(polar_mask) + + # Calculate normalized correlation of polar image along angular direction (axis = 0) + polar_corr = np.real( + np.fft.ifft( + np.abs( + np.fft.fft( + polar_im, + axis=0, + ) + ) + ** 2, + axis=0, + ), + ) + polar_corr_norm = ( + np.sum( + polar_im, + axis=0, + ) + ** 2 + ) + sub = polar_corr_norm > 0 + polar_corr[:, sub] /= polar_corr_norm[ + sub + ] # gets rid of divide by 0 (False near center) + polar_corr[:, sub] -= 1 + + # Calculate normalized correlation of polar mask along angular direction (axis = 0) + mask_corr = np.real( 
+ np.fft.ifft( + np.abs( + np.fft.fft( + polar_mask.astype("float"), + axis=0, + ) + ) + ** 2, + axis=0, + ), + ) + mask_corr_norm = ( + np.sum( + polar_mask.astype("float"), + axis=0, + ) + ** 2 + ) + sub = mask_corr_norm > 0 + mask_corr[:, sub] /= mask_corr_norm[ + sub + ] # gets rid of divide by 0 (False near center) + mask_corr[:, sub] -= 1 + + # Normalize polar correlation by mask correlation (beam stop removal) + sub = np.abs(mask_corr) > 0 + polar_corr[sub] -= mask_corr[sub] + + # Measure symmetry + self.annular_symmetry[rx, ry, :, :] = np.abs(np.fft.fft(polar_corr, axis=0))[ + 1 : max_symmetry + 1 + ] + + if plot_result: + fig, ax = plt.subplots(figsize=figsize) + ax.imshow( + np.mean(self.annular_symmetry, axis=(0, 1)), + aspect="auto", + extent=[ + self.qq[0], + self.qq[-1], + max_symmetry, + 0, + ], + ) + ax.set_yticks( + np.arange(max_symmetry) + 0.5, + range(1, max_symmetry + 1), + ) + ax.set_xlabel("Scattering angle (1/Å)") + ax.set_ylabel("Symmetry Order") + + if return_symmetry_map: + return self.annular_symmetry + + +def plot_annular_symmetry( + self, + symmetry_orders=None, + plot_std=False, + normalize_by_mean=False, + cmap="turbo", + vmin=0.01, + vmax=0.99, + figsize=(8, 4), +): + """ + Plot the symmetry orders + """ + + if symmetry_orders is None: + symmetry_orders = np.arange(1, self.annular_symmetry_max + 1) + else: + symmetry_orders = np.array(symmetry_orders) + + # plotting image + if plot_std: + im_plot = np.std( + self.annular_symmetry, + axis=(0, 1), + )[symmetry_orders - 1, :] + else: + im_plot = np.mean( + self.annular_symmetry, + axis=(0, 1), + )[symmetry_orders - 1, :] + if normalize_by_mean: + im_plot /= self.radial_mean[None, :] + + # plotting range + int_vals = np.sort(im_plot.ravel()) + ind0 = np.clip(np.round(im_plot.size * vmin).astype("int"), 0, im_plot.size - 1) + ind1 = np.clip(np.round(im_plot.size * vmax).astype("int"), 0, im_plot.size - 1) + vmin = int_vals[ind0] + vmax = int_vals[ind1] + + # plot image + fig, ax = 
plt.subplots(figsize=figsize) + ax.imshow( + im_plot, + aspect="auto", + extent=[ + self.qq[0], + self.qq[-1], + np.max(symmetry_orders), + np.min(symmetry_orders) - 1, + ], + cmap=cmap, + vmin=vmin, + vmax=vmax, + ) + ax.set_yticks( + symmetry_orders - 0.5, + symmetry_orders, + ) + ax.set_xlabel("Scattering angle (1/A)") + ax.set_ylabel("Symmetry Order") def scattering_model(k2, *coefs): diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py index 4c06d4c09..0ba284ada 100644 --- a/py4DSTEM/process/polar/polar_datacube.py +++ b/py4DSTEM/process/polar/polar_datacube.py @@ -4,7 +4,6 @@ class PolarDatacube: - """ An interface to a 4D-STEM datacube under polar-elliptical transformation. """ @@ -97,8 +96,10 @@ def __init__( calculate_radial_statistics, calculate_pair_dist_function, calculate_FEM_local, + calculate_annular_symmetry, plot_radial_mean, plot_radial_var_norm, + plot_annular_symmetry, plot_background_fits, plot_sf_estimate, plot_reduced_pdf, diff --git a/py4DSTEM/process/polar/polar_fits.py b/py4DSTEM/process/polar/polar_fits.py index 3e39c5584..18f059b61 100644 --- a/py4DSTEM/process/polar/polar_fits.py +++ b/py4DSTEM/process/polar/polar_fits.py @@ -3,15 +3,18 @@ # from scipy.optimize import leastsq from scipy.optimize import curve_fit +from emdfile import tqdmnd def fit_amorphous_ring( - im, + im=None, + datacube=None, center=None, radial_range=None, coefs=None, mask_dp=None, show_fit_mask=False, + fit_all_images=False, maxfev=None, verbose=False, plot_result=True, @@ -19,6 +22,7 @@ def fit_amorphous_ring( plot_int_scale=(-3, 3), figsize=(8, 8), return_all_coefs=True, + progress_bar=None, ): """ Fit an amorphous halo with a two-sided Gaussian model, plus a background @@ -28,6 +32,8 @@ def fit_amorphous_ring( -------- im: np.array 2D image array to perform fitting on + datacube: py4DSTEM.DataCube + datacube to perform the fitting on center: np.array (x,y) center coordinates for fitting mask. 
If not specified by the user, we will assume the center coordinate is (im.shape-1)/2. @@ -40,6 +46,8 @@ def fit_amorphous_ring( Dark field mask for fitting, in addition to the radial range specified above. show_fit_mask: bool Set to true to preview the fitting mask and initial guess for the ellipse params + fit_all_images: bool + Fit the elliptic parameters to all images maxfev: int Max number of fitting evaluations for curve_fit. verbose: bool @@ -63,6 +71,12 @@ def fit_amorphous_ring( 11 parameter elliptic fit coefficients """ + # If passing in a DataCube, use mean diffraction pattern for initial guess + if im is None: + im = datacube.get_dp_mean() + if progress_bar is None: + progress_bar = True + # Default values if center is None: center = np.array(((im.shape[0] - 1) / 2, (im.shape[1] - 1) / 2)) @@ -193,7 +207,44 @@ def fit_amorphous_ring( )[0] coefs[4] = np.mod(coefs[4], 2 * np.pi) coefs[5:8] *= int_mean - # bounds=bounds + + # Perform the fit on each individual diffration pattern + if fit_all_images: + coefs_all = np.zeros((datacube.shape[0], datacube.shape[1], coefs.size)) + + for rx, ry in tqdmnd( + datacube.shape[0], + datacube.shape[1], + desc="Radial statistics", + unit=" probe positions", + disable=not progress_bar, + ): + vals = datacube.data[rx, ry][mask] + int_mean = np.mean(vals) + + if maxfev is None: + coefs_single = curve_fit( + amorphous_model, + basis, + vals / int_mean, + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + )[0] + else: + coefs_single = curve_fit( + amorphous_model, + basis, + vals / int_mean, + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + maxfev=maxfev, + )[0] + coefs_single[4] = np.mod(coefs_single[4], 2 * np.pi) + coefs_single[5:8] *= int_mean + + coefs_all[rx, ry] = coefs_single if verbose: print("x0 = " + str(np.round(coefs[0], 3)) + " px") @@ -214,9 +265,15 @@ def fit_amorphous_ring( # Return fit parameters if return_all_coefs: - return coefs + if fit_all_images: + return coefs_all + else: + return coefs else: - return coefs[:5] 
+ if fit_all_images: + return coefs_all[:, :, :5] + else: + return coefs[:5] def plot_amorphous_ring( diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py index 4064fccaf..e6040edfb 100644 --- a/py4DSTEM/process/polar/polar_peaks.py +++ b/py4DSTEM/process/polar/polar_peaks.py @@ -125,6 +125,7 @@ def find_peaks_single_pattern( self._datacube.data[x, y], mask=mask, returnval="all_zeros", + origin=(self.calibration.qx0[x, y], self.calibration.qy0[x, y]), ) # Change sign convention of mask mask_bool = np.logical_not(mask_bool) diff --git a/py4DSTEM/process/strain/latticevectors.py b/py4DSTEM/process/strain/latticevectors.py index dcff91709..ff5f6faf0 100644 --- a/py4DSTEM/process/strain/latticevectors.py +++ b/py4DSTEM/process/strain/latticevectors.py @@ -33,7 +33,7 @@ def index_bragg_directions(x0, y0, gx, gy, g1, g2): * **k**: *(ndarray of ints)* second index of the bragg directions * **bragg_directions**: *(PointList)* a 4-coordinate PointList with the indexed bragg directions; coords 'qx' and 'qy' contain bragg_x and bragg_y - coords 'h' and 'k' contain h and k. + coords 'g1_ind' and 'g2_ind' contain g1_ind and g2_ind. 
""" # Get beta, the matrix of lattice vectors beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]]) @@ -45,39 +45,39 @@ def index_bragg_directions(x0, y0, gx, gy, g1, g2): M = lstsq(beta, alpha, rcond=None)[0].T M = np.round(M).astype(int) - # Get h,k - h = M[:, 0] - k = M[:, 1] + # Get g1_ind,g2_ind + g1_ind = M[:, 0] + g2_ind = M[:, 1] # Store in a PointList - coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] + coords = [("qx", float), ("qy", float), ("g1_ind", int), ("g2_ind", int)] temp_array = np.zeros([], dtype=coords) bragg_directions = PointList(data=temp_array) - bragg_directions.add_data_by_field((gx, gy, h, k)) + bragg_directions.add_data_by_field((gx, gy, g1_ind, g2_ind)) mask = np.zeros(bragg_directions["qx"].shape[0]) mask[0] = 1 bragg_directions.remove(mask) - return h, k, bragg_directions + return g1_ind, g2_ind, bragg_directions def add_indices_to_braggvectors( braggpeaks, lattice, maxPeakSpacing, qx_shift=0, qy_shift=0, mask=None ): """ - Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, + Using the peak positions (qx,qy) and indices (g1_ind,g2_ind) in the PointList lattice, identify the indices for each peak in the PointListArray braggpeaks. Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus - three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak - indices with the ints (h,k) and indicating whether the peak was successfully indexed + three additional data columns -- 'g1_ind','g2_ind', and 'index_mask' -- specifying the peak + indices with the ints (g1_ind,g2_ind) and indicating whether the peak was successfully indexed or not with the bool index_mask. If `mask` is specified, only the locations where mask is True are indexed. Args: braggpeaks (PointListArray): the braggpeaks to index. Must contain the coordinates 'qx', 'qy', and 'intensity' - lattice (PointList): the positions (qx,qy) of the (h,k) lattice points. 
- Must contain the coordinates 'qx', 'qy', 'h', and 'k' + lattice (PointList): the positions (qx,qy) of the (g1_ind,g2_ind) lattice points. + Must contain the coordinates 'qx', 'qy', 'g1_ind', and 'g2_ind' maxPeakSpacing (float): Maximum distance from the ideal lattice points to include a peak for indexing qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList @@ -88,7 +88,7 @@ def add_indices_to_braggvectors( Returns: (PointListArray): The original braggpeaks pointlistarray, with new coordinates - 'h', 'k', containing the indices of each indexable peak. + 'g1_ind', 'g2_ind', containing the indices of each indexable peak. """ # assert isinstance(braggpeaks,BraggVectors) @@ -107,8 +107,8 @@ def add_indices_to_braggvectors( ("qx", float), ("qy", float), ("intensity", float), - ("h", int), - ("k", int), + ("g1_ind", int), + ("g2_ind", int), ] indexed_braggpeaks = PointListArray( @@ -140,8 +140,8 @@ def add_indices_to_braggvectors( pl.data["qx"][i], pl.data["qy"][i], pl.data["intensity"][i], - lattice.data["h"][ind], - lattice.data["k"][ind], + lattice.data["g1_ind"][ind], + lattice.data["g2_ind"][ind], ) ) @@ -150,12 +150,12 @@ def add_indices_to_braggvectors( def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): """ - Fits lattice vectors g1,g2 to braggpeaks given some known (h,k) indexing. + Fits lattice vectors g1,g2 to braggpeaks given some known (g1_ind,g2_ind) indexing. Args: braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also + weighting factor when fitting), 'g1_ind','g2_ind' (indexing). May optionally also contain 'index_mask' (bool), indicating which peaks have been successfully indixed and should be used. 
x0 (float): x-coord of the origin @@ -176,7 +176,10 @@ def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): """ assert isinstance(braggpeaks, PointList) assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] + [ + name in braggpeaks.dtype.names + for name in ("qx", "qy", "intensity", "g1_ind", "g2_ind") + ] ) braggpeaks = braggpeaks.copy() @@ -189,9 +192,9 @@ def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): if braggpeaks.length < minNumPeaks: return None, None, None, None, None, None, None - # Get M, the matrix of (h,k) indices - h, k = braggpeaks.data["h"], braggpeaks.data["k"] - M = np.vstack((np.ones_like(h, dtype=int), h, k)).T + # Get M, the matrix of (g1_ind,g2_ind) indices + g1_ind, g2_ind = braggpeaks.data["g1_ind"], braggpeaks.data["g2_ind"] + M = np.vstack((np.ones_like(g1_ind, dtype=int), g1_ind, g2_ind)).T # Get alpha, the matrix of measured Bragg peak positions alpha = np.vstack((braggpeaks.data["qx"] - x0, braggpeaks.data["qy"] - y0)).T @@ -223,7 +226,7 @@ def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): Args: braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also + weighting factor when fitting), 'g1_ind','g2_ind' (indexing). May optionally also contain 'index_mask' (bool), indicating which peaks have been successfully indixed and should be used. 
x0 (float): x-coord of the origin @@ -246,7 +249,10 @@ def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): """ assert isinstance(braggpeaks, PointListArray) assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] + [ + name in braggpeaks.dtype.names + for name in ("qx", "qy", "intensity", "g1_ind", "g2_ind") + ] ) # Make RealSlice to contain outputs @@ -272,7 +278,8 @@ def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): # Store data if g1x is not None: g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 + # Assume this is a correct change + g1g2_map.get_slice("y0").data[Rx, Ry] = qy0 g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index 25162180f..516820e04 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -1,7 +1,7 @@ # Defines the Strain class import warnings -from typing import Optional +from typing import Optional, List, Tuple, Union import matplotlib.pyplot as plt from matplotlib.patches import Circle @@ -386,6 +386,78 @@ def choose_basis_vectors( else: return + def set_hkl( + self, + g1_hkl: Union[List[int], Tuple[int, int, int], np.ndarray[np.int64]], + g2_hkl: Union[List[int], Tuple[int, int, int], np.ndarray[np.int64]], + ): + """ + calculate the [h,k,l] reflections from the `g1_ind`,`g2_ind` from known 'g1_hkl` and 'g2_hkl' reflections. 
+ Creates 'bragg_vectors_indexed_hkl' attribute + Args: + g1_hkl (list[int] | tuple[int,int,int] | np.ndarray[int]): known [h,k,l] reflection for g1_vector + g2_hkl (list[int] | tuple[int,int,int] | np.ndarray[int]): known [h,k,l] reflection for g1_vector + """ + + g1_hkl = np.array(g1_hkl) + g2_hkl = np.array(g2_hkl) + + # Initialize a PLA + bvs_hkl = PointListArray( + shape=self.shape, + dtype=[ + ("qx", float), + ("qy", float), + ("intensity", float), + ("h", int), + ("k", int), + ("l", int), + ], + ) + # loop over the probe posistions + for Rx, Ry in tqdmnd( + self.shape[0], + self.shape[1], + desc="Converting (g1_ind,g2_ind) to (h,k,l)", + unit="DP", + unit_scale=True, + ): + # get a single indexed + braggvectors_indexed_dp = self.bragg_vectors_indexed[Rx, Ry] + + # make a Pointlsit + bvs_hkl_curr = PointList( + data=np.empty(len(braggvectors_indexed_dp), dtype=bvs_hkl.dtype) + ) + # populate qx, qy and intensity fields + bvs_hkl_curr.data["qx"] = braggvectors_indexed_dp["qx"] + bvs_hkl_curr.data["qy"] = braggvectors_indexed_dp["qy"] + bvs_hkl_curr.data["intensity"] = braggvectors_indexed_dp["intensity"] + + # calcuate the hkl vectors + vectors_hkl = ( + g1_hkl[:, np.newaxis] * braggvectors_indexed_dp["g1_ind"] + + g2_hkl[:, np.newaxis] * braggvectors_indexed_dp["g2_ind"] + ) + # self.vectors_hkl = vectors_hkl + + # populate h,k,l fields + # print(vectors_hkl.shape) + # bvs_hkl_curr.data['h'] = vectors_hkl[0,:] + # bvs_hkl_curr.data['k'] = vectors_hkl[1,:] + # bvs_hkl_curr.data['l'] = vectors_hkl[2,:] + ( + bvs_hkl_curr.data["h"], + bvs_hkl_curr.data["k"], + bvs_hkl_curr.data["l"], + ) = np.vsplit(vectors_hkl, 3) + + # add to the PLA + bvs_hkl[Rx, Ry] += bvs_hkl_curr + + # add the PLA to the Strainmap object + self.bragg_vectors_indexed_hkl = bvs_hkl + def set_max_peak_spacing( self, max_peak_spacing, @@ -501,8 +573,8 @@ def fit_basis_vectors( ("qx", float), ("qy", float), ("intensity", float), - ("h", int), - ("k", int), + ("g1_ind", int), + ("g2_ind", int), 
], shape=self.braggvectors.Rshape, ) @@ -538,8 +610,8 @@ def fit_basis_vectors( pl.data["qx"][i], pl.data["qy"][i], pl.data["intensity"][i], - self.braggdirections.data["h"][ind], - self.braggdirections.data["k"][ind], + self.braggdirections.data["g1_ind"][ind], + self.braggdirections.data["g2_ind"][ind], ) ) self.bragg_vectors_indexed = indexed_braggpeaks @@ -1178,11 +1250,16 @@ def show_bragg_indexing( The display image bragg_directions : PointList The Bragg scattering directions. Must have coordinates - 'qx','qy','h', and 'k'. Optionally may also have 'l'. + ('qx','qy','h', and 'k') or ('qx','qy','g1_ind', and 'g2_ind'. Optionally may also have 'l'. """ assert isinstance(bragg_directions, PointList) - for k in ("qx", "qy", "h", "k"): - assert k in bragg_directions.data.dtype.fields + # checking if it has h, k or g1_ind, g2_ind + assert all( + key in bragg_directions.data.dtype.names for key in ("qx", "qy", "h", "k") + ) or all( + key in bragg_directions.data.dtype.names + for key in ("qx", "qy", "g1_ind", "g2_ind") + ), 'pointlist must contain ("qx", "qy", "h", "k") or ("qx", "qy", "g1_ind", "g2_ind") fields' if figax is None: fig, ax = show(ar, returnfig=True, **kwargs) @@ -1201,6 +1278,7 @@ def show_bragg_indexing( "pointsize": pointsize, "pointcolor": pointcolor, } + # this can take ("qx", "qy", "h", "k") or ("qx", "qy", "g1_ind", "g2_ind") fields add_bragg_index_labels(ax, d) if returnfig: diff --git a/py4DSTEM/process/utils/single_atom_scatter.py b/py4DSTEM/process/utils/single_atom_scatter.py index 8d6e2a891..54443b68f 100644 --- a/py4DSTEM/process/utils/single_atom_scatter.py +++ b/py4DSTEM/process/utils/single_atom_scatter.py @@ -1,5 +1,6 @@ import numpy as np import os +from scipy.special import kn class single_atom_scatter(object): @@ -46,6 +47,41 @@ def electron_scattering_factor(self, Z, gsq, units="A"): elif units == "A": return fe + def projected_potential(self, Z, R): + ai = self.e_scattering_factors[Z - 1, 0:10:2] + bi = 
self.e_scattering_factors[Z - 1, 1:10:2] + + # Planck's constant in Js + h = 6.62607004e-34 + # Electron rest mass in kg + me = 9.10938356e-31 + # Electron charge in Coulomb + qe = 1.60217662e-19 + # Electron charge in V-Angstroms + # qe = 14.4 + # Permittivity of vacuum + eps_0 = 8.85418782e-12 + # Bohr's constant + a_0 = 5.29177210903e-11 + + fe = np.zeros_like(R) + for i in range(5): + pre = 2 * np.pi / bi[i] ** 0.5 + fe += (ai[i] / bi[i] ** 1.5) * (kn(0, pre * R) + R * kn(1, pre * R)) + + # Scale output units + # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*qe) + # fe *= 2*np.pi**2 / kappa + # # # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) + + # # kappa = (4*np.pi*eps_0) / (2*np.pi*a_0*me) + # # return fe * 2 * np.pi**2 # / kappa + # # if units == "VA": + # return h**2 / (2 * np.pi * me * qe) * 1e18 * fe + # # elif units == "A": + # # return fe * 2 * np.pi**2 / kappa + return fe + def get_scattering_factor( self, elements=None, composition=None, q_coords=None, units=None ): diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index e2bf4307c..ddeeb2c36 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -11,16 +11,6 @@ import matplotlib.font_manager as fm from emdfile import tqdmnd -from py4DSTEM.process.utils.multicorr import upsampled_correlation -from py4DSTEM.preprocess.utils import make_Fourier_coords2D - -try: - from IPython.display import clear_output -except ImportError: - - def clear_output(wait=True): - pass - try: import cupy as cp @@ -103,12 +93,7 @@ def electron_wavelength_angstrom(E_eV): c = 299792458 h = 6.62607 * 10**-34 - lam = ( - h - / ma.sqrt(2 * m * e * E_eV) - / ma.sqrt(1 + e * E_eV / 2 / m / c**2) - * 10**10 - ) + lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 return lam @@ -117,15 +102,8 @@ def electron_interaction_parameter(E_eV): e = 1.602177 * 10**-19 c = 299792458 h = 6.62607 * 10**-34 - lam = ( - h - / ma.sqrt(2 * m * e * E_eV) - / ma.sqrt(1 + 
e * E_eV / 2 / m / c**2) - * 10**10 - ) - sigma = ( - (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) - ) + lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10 + sigma = (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) return sigma @@ -185,25 +163,6 @@ def get_qx_qy_1d(M, dx=[1, 1], fft_shifted=False): return qxa, qya -def make_Fourier_coords2D(Nx, Ny, pixelSize=1): - """ - Generates Fourier coordinates for a (Nx,Ny)-shaped 2D array. - Specifying the pixelSize argument sets a unit size. - """ - if hasattr(pixelSize, "__len__"): - assert len(pixelSize) == 2, "pixelSize must either be a scalar or have length 2" - pixelSize_x = pixelSize[0] - pixelSize_y = pixelSize[1] - else: - pixelSize_x = pixelSize - pixelSize_y = pixelSize - - qx = np.fft.fftfreq(Nx, pixelSize_x) - qy = np.fft.fftfreq(Ny, pixelSize_y) - qy, qx = np.meshgrid(qy, qx) - return qx, qy - - def get_CoM(ar, device="cpu", corner_centered=False): """ Finds and returns the center of mass of array ar. @@ -458,6 +417,7 @@ def fourier_resample( bandlimit_nyquist=None, bandlimit_power=2, dtype=np.float32, + conserve_array_sums=False, ): """ Resize a 2D array along any dimension, using Fourier interpolation / extrapolation. @@ -476,6 +436,7 @@ def fourier_resample( bandlimit_nyquist (float): Gaussian filter information limit in Nyquist units (0.5 max in both directions) bandlimit_power (float): Gaussian filter power law scaling (higher is sharper) dtype (numpy dtype): datatype for binned array. 
default is single precision float + conserve_arrray_sums (bool): If True, the sums of the array are conserved Returns: the resized array (2D/4D numpy array) @@ -673,7 +634,8 @@ def fourier_resample( array_resize = np.maximum(array_resize, 0) # Normalization - array_resize = array_resize * scale_output + if not conserve_array_sums: + array_resize = array_resize * scale_output return array_resize diff --git a/py4DSTEM/process/wholepatternfit/wp_models.py b/py4DSTEM/process/wholepatternfit/wp_models.py index c0907a11e..e437916eb 100644 --- a/py4DSTEM/process/wholepatternfit/wp_models.py +++ b/py4DSTEM/process/wholepatternfit/wp_models.py @@ -289,12 +289,16 @@ def __init__( "radius": Parameter(radius), "sigma": Parameter(sigma), "intensity": Parameter(intensity), - "x center": WPF.coordinate_model.params["x center"] - if global_center - else Parameter(x0), - "y center": WPF.coordinate_model.params["y center"] - if global_center - else Parameter(y0), + "x center": ( + WPF.coordinate_model.params["x center"] + if global_center + else Parameter(x0) + ), + "y center": ( + WPF.coordinate_model.params["y center"] + if global_center + else Parameter(y0) + ), } super().__init__(name, params, model_type=WPFModelType.AMORPHOUS) diff --git a/py4DSTEM/utils/configuration_checker.py b/py4DSTEM/utils/configuration_checker.py index 904dceb29..b50a21de2 100644 --- a/py4DSTEM/utils/configuration_checker.py +++ b/py4DSTEM/utils/configuration_checker.py @@ -1,61 +1,96 @@ #### this file contains a function/s that will check if various # libaries/compute options are available import importlib -from operator import mod - -# list of modules we expect/may expect to be installed -# as part of a standard py4DSTEM installation -# this needs to be the import name e.g. 
import mp_api not mp-api -modules = [ - "crystal4D", - "cupy", - "dask", - "dill", - "distributed", - "gdown", - "h5py", - "ipyparallel", - "jax", - "matplotlib", - "mp_api", - "ncempy", - "numba", - "numpy", - "pymatgen", - "skimage", - "sklearn", - "scipy", - "tensorflow", - "tensorflow-addons", - "tqdm", -] - -# currently this was copy and pasted from setup.py, -# hopefully there's a programatic way to do this. -module_depenencies = { - "base": [ - "numpy", - "scipy", - "h5py", - "ncempy", - "matplotlib", - "skimage", - "sklearn", - "tqdm", - "dill", - "gdown", - "dask", - "distributed", - ], - "ipyparallel": ["ipyparallel", "dill"], - "cuda": ["cupy"], - "acom": ["pymatgen", "mp_api"], - "aiml": ["tensorflow", "tensorflow-addons", "crystal4D"], - "aiml-cuda": ["tensorflow", "tensorflow-addons", "crystal4D", "cupy"], - "numba": ["numba"], +from importlib.metadata import requires +import re +from importlib.util import find_spec + +# need a mapping of pypi/conda names to import names +import_mapping_dict = { + "scikit-image": "skimage", + "scikit-learn": "sklearn", + "scikit-optimize": "skopt", + "mp-api": "mp_api", } +# programatically get all possible requirements in the import name style +def get_modules_list(): + # Get the dependencies from the installed distribution + dependencies = requires("py4DSTEM") + + # Define a regular expression pattern for splitting on '>', '>=', '=' + delimiter_pattern = re.compile(r">=|>|==|<|<=") + + # Extract only the module names without versions + module_names = [ + delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip() + for dependency in dependencies + ] + + # translate pypi names to import names e.g. 
scikit-image->skimage, mp-api->mp_api + for index, module in enumerate(module_names): + if module in import_mapping_dict.keys(): + module_names[index] = import_mapping_dict[module] + + return module_names + + +# programatically get all possible requirements in the import name style, +# split into a dict where optional import names are keys +def get_modules_dict(): + package_name = "py4DSTEM" + # Get the dependencies from the installed distribution + dependencies = requires(package_name) + + # set the dictionary for modules and packages to go into + # optional dependencies will be added as they are discovered + modules_dict = { + "base": [], + } + # loop over the dependencies + for depend in dependencies: + # all the optional have extra in the name + # if its not there append it to base + if "extra" not in depend: + # String looks like: 'numpy>=1.19' + modules_dict["base"].append(depend) + + # if it has extra in the string + else: + # get the name of the optional name + # depend looks like this 'numba>=0.49.1; extra == "numba"' + # grab whatever is in the double quotes i.e. numba + optional_name = re.search(r'"(.*?)"', depend).group(1) + # if the optional name is not in the dict as a key i.e. first requirement of hte optional dependency + if optional_name not in modules_dict: + modules_dict[optional_name] = [depend] + # if the optional_name is already in the dict then just append it to the list + else: + modules_dict[optional_name].append(depend) + # STRIP all the versioning and semi-colons + # Define a regular expression pattern for splitting on '>', '>=', '=' + delimiter_pattern = re.compile(r">=|>|==|<|<=") + for key, val in modules_dict.items(): + # modules_dict[key] = [dependency.split(';')[0].split(' ')[0] for dependency in val] + modules_dict[key] = [ + delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip() + for dependency in val + ] + + # translate pypi names to import names e.g. 
scikit-image->skimage, mp-api->mp_api + for key, val in modules_dict.items(): + for index, module in enumerate(val): + if module in import_mapping_dict.keys(): + val[index] = import_mapping_dict[module] + + return modules_dict + + +# module_depenencies = get_modules_dict() +modules = get_modules_list() + + #### Class and Functions to Create Coloured Strings #### class colours: CEND = "\x1b[0m" @@ -140,6 +175,7 @@ def create_underline(s: str) -> str: ### here I use the term state to define a boolean condition as to whether a libary/module was sucessfully imported/can be used +# get the state of each modules as a dict key-val e.g. "numpy" : True def get_import_states(modules: list = modules) -> dict: """ Check the ability to import modules and store the results as a boolean value. Returns as a dict. @@ -163,16 +199,17 @@ def get_import_states(modules: list = modules) -> dict: return import_states_dict +# Check def get_module_states(state_dict: dict) -> dict: - """_summary_ - - Args: - state_dict (dict): _description_ + """ + given a state dict for all modules e.g. "numpy" : True, + this parses through and checks if all modules required for a state are true - Returns: - dict: _description_ + returns dict "base": True, "ai-ml": False etc. 
""" + # get the modules_dict + module_depenencies = get_modules_dict() # create an empty dict to put module states into: module_states = {} @@ -196,13 +233,12 @@ def get_module_states(state_dict: dict) -> dict: def print_import_states(import_states: dict) -> None: - """_summary_ - - Args: - import_states (dict): _description_ + """ + print with colours if the library could be imported or not + takes dict + "numpy" : True -> prints success + "pymatgen" : False -> prints failure - Returns: - _type_: _description_ """ # m is the name of the import module # state is whether it was importable @@ -223,13 +259,11 @@ def print_import_states(import_states: dict) -> None: def print_module_states(module_states: dict) -> None: - """_summary_ - - Args: - module_states (dict): _description_ - - Returns: - _type_: _description_ + """ + print with colours if all the imports required for module could be imported or not + takes dict + "base" : True -> prints success + "ai-ml" : Fasle -> prints failure """ # Print out the state of all the modules in colour code # key is the name of a py4DSTEM Module @@ -248,25 +282,33 @@ def print_module_states(module_states: dict) -> None: return None -def perfrom_extra_checks( +def perform_extra_checks( import_states: dict, verbose: bool, gratuitously_verbose: bool, **kwargs ) -> None: """_summary_ Args: - import_states (dict): _description_ - verbose (bool): _description_ - gratuitously_verbose (bool): _description_ + import_states (dict): dict of modules and if they could be imported or not + verbose (bool): will show module states and all import states + gratuitously_verbose (bool): will run extra checks - Currently only for cupy Returns: _type_: _description_ """ - - # print a output module - extra_checks_message = "Running Extra Checks" - extra_checks_message = create_bold(extra_checks_message) - print(f"{extra_checks_message}") - # For modules that import run any extra checks + if gratuitously_verbose: + # print a output module + 
extra_checks_message = "Running Extra Checks" + extra_checks_message = create_bold(extra_checks_message) + print(f"{extra_checks_message}") + # For modules that import run any extra checks + # get all the dependencies + dependencies = requires("py4DSTEM") + # Extract only the module names with versions + depends_with_requirements = [ + dependency.split(";")[0] for dependency in dependencies + ] + # print(depends_with_requirements) + # need to go from for key, val in import_states.items(): if val: # s = create_underline(key.capitalize()) @@ -281,7 +323,10 @@ def perfrom_extra_checks( if gratuitously_verbose: s = create_underline(key.capitalize()) print(s) - print_no_extra_checks(key) + # check + generic_versions( + key, depends_with_requires=depends_with_requirements + ) else: pass @@ -304,7 +349,7 @@ def import_tester(m: str) -> bool: # try and import the module try: importlib.import_module(m) - except: + except Exception: state = False return state @@ -324,6 +369,7 @@ def check_module_functionality(state_dict: dict) -> None: # create an empty dict to put module states into: module_states = {} + module_depenencies = get_modules_dict() # key is the name of the module e.g. 
ACOM # val is a list of its dependencies @@ -359,6 +405,45 @@ def check_module_functionality(state_dict: dict) -> None: #### ADDTIONAL CHECKS #### +def generic_versions(module: str, depends_with_requires: list[str]) -> None: + # module will be like numpy, skimage + # depends_with_requires look like: numpy >= 19.0, scikit-image + # get module_translated_name + # mapping scikit-image : skimage + for key, value in import_mapping_dict.items(): + # if skimage == skimage get scikit-image + # print(f"{key = } - {value = } - {module = }") + if module in value: + module_depend_name = key + break + else: + # if cant find mapping set the search name to the same + module_depend_name = module + # print(f"{module_depend_name = }") + # find the requirement + for depend in depends_with_requires: + if module_depend_name in depend: + spec_required = depend + # print(f"{spec_required = }") + # get the version installed + spec_installed = find_spec(module) + if spec_installed is None: + s = f"{module} unable to import - {spec_required} required" + s = create_failure(s) + s = f"{s: <80}" + print(s) + + else: + try: + version = importlib.metadata.version(module_depend_name) + except Exception: + version = "Couldn't test version" + s = f"{module} imported: {version = } - {spec_required} required" + s = create_warning(s) + s = f"{s: <80}" + print(s) + + def check_cupy_gpu(gratuitously_verbose: bool, **kwargs): """ This function performs some additional tests which may be useful in @@ -375,25 +460,18 @@ def check_cupy_gpu(gratuitously_verbose: bool, **kwargs): # check that CUDA is detected correctly cuda_availability = cp.cuda.is_available() if cuda_availability: - s = " CUDA is Available " + s = f" CUDA is Available " s = create_success(s) s = f"{s: <80}" print(s) else: - s = " CUDA is Unavailable " + s = f" CUDA is Unavailable " s = create_failure(s) s = f"{s: <80}" print(s) # Count how many GPUs Cupy can detect - # probably should change this to a while loop ... 
- for i in range(24): - try: - d = cp.cuda.Device(i) - hasattr(d, "attributes") - except: - num_gpus_detected = i - break + num_gpus_detected = cp.cuda.runtime.getDeviceCount() # print how many GPUs were detected, filter for a couple of special conditons if num_gpus_detected == 0: @@ -448,7 +526,9 @@ def print_no_extra_checks(m: str): # dict of extra check functions -funcs_dict = {"cupy": check_cupy_gpu} +funcs_dict = { + "cupy": check_cupy_gpu, +} #### main function used to check the configuration of the installation @@ -493,7 +573,7 @@ def check_config( print_import_states(states_dict) - perfrom_extra_checks( + perform_extra_checks( import_states=states_dict, verbose=verbose, gratuitously_verbose=gratuitously_verbose, diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index 103751b29..08777437f 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1 +1 @@ -__version__ = "0.14.9" +__version__ = "0.14.14" diff --git a/py4DSTEM/visualize/overlay.py b/py4DSTEM/visualize/overlay.py index 32baff443..8c6c06d9f 100644 --- a/py4DSTEM/visualize/overlay.py +++ b/py4DSTEM/visualize/overlay.py @@ -553,8 +553,16 @@ def add_bragg_index_labels(ax, d): assert "bragg_directions" in d.keys() bragg_directions = d["bragg_directions"] assert isinstance(bragg_directions, PointList) - for k in ("qx", "qy", "h", "k"): - assert k in bragg_directions.data.dtype.fields + # check pointlist has ("qx", "qy", "h", "k") or ("qx", "qy", "g1_ind", "g2_ind") fields + assert all( + key in bragg_directions.data.dtype.names for key in ("qx", "qy", "h", "k") + ) or all( + key in bragg_directions.data.dtype.names + for key in ("qx", "qy", "g1_ind", "g2_ind") + ), 'pointlist must contain ("qx", "qy", "h", "k") or ("qx", "qy", "g1_ind", "g2_ind") fields' + + # for k in ("qx", "qy", "h", "k"): + # assert k in bragg_directions.data.dtype.fields include_l = True if "l" in bragg_directions.data.dtype.fields else False # offsets hoffset = d["hoffset"] if "hoffset" in d.keys() else 0 @@ -586,10 
+594,21 @@ def add_bragg_index_labels(ax, d): x, y = bragg_directions.data["qx"][i], bragg_directions.data["qy"][i] x -= voffset y += hoffset - h, k = bragg_directions.data["h"][i], bragg_directions.data["k"][i] + # bit of a hack to get either "h" or "g1_ind" + h = ( + bragg_directions.data["h"][i] + if not ValueError + else bragg_directions.data["g1_ind"][i] + ) + k = ( + bragg_directions.data["k"][i] + if not ValueError + else bragg_directions.data["g2_ind"][i] + ) h = str(h) if h >= 0 else r"$\overline{{{}}}$".format(np.abs(h)) k = str(k) if k >= 0 else r"$\overline{{{}}}$".format(np.abs(k)) s = h + "," + k + # TODO might need to to add g3_ind check if include_l: l = bragg_directions.data["l"][i] l = str(l) if l >= 0 else r"$\overline{{{}}}$".format(np.abs(l)) diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 8462eec7d..7430992e0 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -76,6 +76,7 @@ def show( theta=None, title=None, show_fft=False, + apply_hanning_window=True, show_cbar=False, **kwargs, ): @@ -305,6 +306,8 @@ def show( which will attempt to use it to overlay a scalebar. If True, uses calibraiton or pixelsize/pixelunits for scalebar. If False, no scalebar is added. 
show_fft (bool): if True, plots 2D-fft of array + apply_hanning_window (bool) + If True, a 2D Hann window is applied to the array before applying the FFT show_cbar (bool) : if True, adds cbar **kwargs: any keywords accepted by matplotlib's ax.matshow() @@ -369,9 +372,12 @@ def show( from py4DSTEM.visualize import show if show_fft: - n0 = ar.shape - w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] - ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) + if apply_hanning_window: + n0 = ar.shape + w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] + ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) + else: + ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) for a0 in range(num_images): im = show( ar[a0], @@ -451,7 +457,12 @@ def show( # Otherwise, plot one image if show_fft: if combine_images is False: - ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) + if apply_hanning_window: + n0 = ar.shape + w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] + ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) + else: + ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) # get image from a masked array if mask is not None: diff --git a/py4DSTEM/visualize/vis_grid.py b/py4DSTEM/visualize/vis_grid.py index d24b0b8d8..cb2581e01 100644 --- a/py4DSTEM/visualize/vis_grid.py +++ b/py4DSTEM/visualize/vis_grid.py @@ -205,7 +205,7 @@ def show_image_grid( ax = axs[i, j] N = i * W + j # make titles - if type(title) == list: + if type(title) == list and N < len(title): print_title = title[N] else: print_title = None @@ -281,7 +281,11 @@ def show_image_grid( ) else: _, _ = show( - ar, figax=(fig, ax), returnfig=True, title=print_title, **kwargs + ar, + figax=(fig, ax), + returnfig=True, + title=print_title, + **kwargs, ) except IndexError: ax.axis("off") diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 88b7d7815..2db48e371 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -504,20 
+504,34 @@ def show_origin_meas(data): show_image_grid(get_ar=lambda i: [qx, qy][i], H=1, W=2, cmap="RdBu") -def show_origin_fit(data): +def show_origin_fit( + data, + plot_range=None, + axsize=(3, 3), +): """ Show the measured, fit, and residuals of the origin positions. - Args: - data (DataCube or Calibration or (3,2)-tuple of arrays - ((qx0_meas,qy0_meas),(qx0_fit,qy0_fit),(qx0_residuals,qy0_residuals)) + Parameters + ---------- + data: (DataCube or Calibration or (3,2)-tuple of arrays + ((qx0_meas,qy0_meas),(qx0_fit,qy0_fit),(qx0_residuals,qy0_residuals)) + plot_range: (tuple, list, or np.array) + Plotting range in units of pixels + axsize: (tuple) + Size of each plot axis + + Returns + ------- + + """ from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube if isinstance(data, tuple): assert len(data) == 3 - qx0_meas, qy_meas = data[0] + qx0_meas, qy0_meas = data[0] qx0_fit, qy0_fit = data[1] qx0_residuals, qy0_residuals = data[2] elif isinstance(data, DataCube): @@ -531,18 +545,39 @@ def show_origin_fit(data): else: raise Exception("data must be of type Datacube or Calibration or tuple") + # Centered intensity plots + qx0_meas_plot = qx0_meas - np.median(qx0_meas) + qy0_meas_plot = qy0_meas - np.median(qy0_meas) + qx0_fit_plot = qx0_fit - np.median(qx0_fit) + qy0_fit_plot = qy0_fit - np.median(qy0_fit) + + # Determine plotting range + if plot_range is None: + plot_range = np.array((-1.0, 1.0)) * np.max( + (np.abs(qx0_fit_plot), np.abs(qy0_fit_plot)) + ) + else: + plot_range = np.array(plot_range) + if plot_range.ndim == 0: + plot_range = np.array((-1.0, 1.0)) * plot_range + + # plotting show_image_grid( get_ar=lambda i: [ - qx0_meas, - qx0_fit, + qx0_meas_plot, + qx0_fit_plot, qx0_residuals, - qy0_meas, - qy0_fit, + qy0_meas_plot, + qy0_fit_plot, qy0_residuals, ][i], H=2, W=3, cmap="RdBu", + intensity_range="absolute", + vmin=plot_range[0], + vmax=plot_range[1], + axsize=axsize, ) @@ -801,22 +836,31 @@ def show_complex( ) if scalebar is 
True: scalebar = { - "Nx": ar_complex[0].shape[0], - "Ny": ar_complex[0].shape[1], + "Nx": rgb[0].shape[0], + "Ny": rgb[0].shape[1], "pixelsize": pixelsize, "pixelunits": pixelunits, } add_scalebar(ax[0, 0], scalebar) else: + figsize = kwargs.pop("axsize", None) + figsize = kwargs.pop("figsize", figsize) + fig, ax = show( - rgb, vmin=0, vmax=1, intensity_range="absolute", returnfig=True, **kwargs + rgb, + vmin=0, + vmax=1, + intensity_range="absolute", + returnfig=True, + figsize=figsize, + **kwargs, ) if scalebar is True: scalebar = { - "Nx": ar_complex.shape[0], - "Ny": ar_complex.shape[1], + "Nx": rgb.shape[0], + "Ny": rgb.shape[1], "pixelsize": pixelsize, "pixelunits": pixelunits, } @@ -826,7 +870,7 @@ def show_complex( # add color bar if cbar: if is_grid: - for ax_flat in ax.flatten(): + for ax_flat in ax.flatten()[: len(rgb)]: divider = make_axes_locatable(ax_flat) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) diff --git a/setup.py b/setup.py index 2d828289a..bdb2c48ff 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ author_email="ben.savitzky@gmail.com", license="GNU GPLv3", keywords="STEM 4DSTEM", - python_requires=">=3.9,<3.13", + python_requires=">=3.10", install_requires=[ "numpy >= 1.19", "scipy >= 1.5.2", @@ -34,7 +34,7 @@ "scikit-optimize >= 0.9.0", "tqdm >= 4.46.1", "dill >= 0.3.3", - "gdown >= 4.7.1", + "gdown >= 5.1.0", "dask >= 2.3.0", "distributed >= 2.3.0", "emdfile >= 0.0.14", @@ -47,12 +47,18 @@ "ipyparallel": ["ipyparallel >= 6.2.4", "dill >= 0.3.3"], "cuda": ["cupy >= 10.0.0"], "acom": ["pymatgen >= 2022", "mp-api == 0.24.1"], - "aiml": ["tensorflow == 2.4.1", "tensorflow-addons <= 0.14.0", "crystal4D"], + "aiml": [ + "tensorflow <= 2.10.0", + "tensorflow-addons <= 0.16.1", + "crystal4D", + "typeguard == 2.7", + ], "aiml-cuda": [ - "tensorflow == 2.4.1", - "tensorflow-addons <= 0.14.0", + "tensorflow <= 2.10.0", + "tensorflow-addons <= 0.16.1", "crystal4D", "cupy >= 
10.0.0", + "typeguard == 2.7", ], "numba": ["numba >= 0.49.1"], },