Minor refactoring

This commit is contained in:
Atsushi Togo 2024-02-16 15:46:49 +09:00
parent 21d7e837f6
commit 2966ffcc93
2 changed files with 48 additions and 25 deletions

View File

@ -39,6 +39,7 @@ from __future__ import annotations
import sys import sys
import textwrap import textwrap
import warnings import warnings
from collections.abc import Sequence
from typing import Optional, Union from typing import Optional, Union
import numpy as np import numpy as np
@ -145,8 +146,8 @@ class Phonopy:
def __init__( def __init__(
self, self,
unitcell, unitcell,
supercell_matrix=None, supercell_matrix: Optional[Union[Sequence, np.ndarray]] = None,
primitive_matrix=None, primitive_matrix: Optional[Union[str, Sequence, np.ndarray]] = None,
nac_params=None, nac_params=None,
factor=VaspToTHz, factor=VaspToTHz,
frequency_scale_factor=None, frequency_scale_factor=None,
@ -3869,7 +3870,7 @@ class Phonopy:
def _set_primitive_matrix( def _set_primitive_matrix(
self, primitive_matrix self, primitive_matrix
) -> Optional[Union[str, list, np.ndarray]]: ) -> Optional[Union[str, np.ndarray]]:
pmat = get_primitive_matrix(primitive_matrix, symprec=self._symprec) pmat = get_primitive_matrix(primitive_matrix, symprec=self._symprec)
if isinstance(pmat, str) and pmat == "auto": if isinstance(pmat, str) and pmat == "auto":
return guess_primitive_matrix(self._unitcell, symprec=self._symprec) return guess_primitive_matrix(self._unitcell, symprec=self._symprec)

View File

@ -1644,17 +1644,24 @@ def determinant(m):
def get_primitive_matrix( def get_primitive_matrix(
pmat: Optional[Union[str, np.ndarray, Sequence]] = None, pmat: Optional[Union[str, np.ndarray, Sequence]] = None,
symprec=1e-5, symprec: float = 1e-5,
) -> Optional[Union[str, list, np.ndarray]]: ) -> Optional[Union[str, np.ndarray]]:
"""Find primitive matrix from primitive cell. """Find primitive matrix from primitive cell.
None is equivalent to "P" but None is returned. None is equivalent to "P" but None is returned.
``pmat`` can be Parameters
----------
pmat : str, np.ndarray, Sequency, or None
symbol of centring type: "P", "F", "I", "A", "C", "R"
"auto" : estimates a centring type.
3x3 matrix (can be flattened, i.e., 9 elements)
symprec : float
Tolerance.
- a symbol of centring type: "P", "F", "I", "A", "C", "R" Returns
- "auto" : estimates a centring type. -------
- 3x3 matrix (can be flattened, i.e., 9 elements) None or 3x3 np.ndarray representing transformation matrix to primitive cell.
""" """
if isinstance(pmat, str) and pmat in ("P", "F", "I", "A", "C", "R", "auto"): if isinstance(pmat, str) and pmat in ("P", "F", "I", "A", "C", "R", "auto"):
@ -1683,28 +1690,43 @@ def get_primitive_matrix(
return _pmat return _pmat
def get_primitive_matrix_by_centring(centring): def get_primitive_matrix_by_centring(centring) -> Optional[np.ndarray]:
"""Return primitive matrix corresponding to centring.""" """Return primitive matrix corresponding to centring."""
if centring == "P": if centring == "P":
return [[1, 0, 0], [0, 1, 0], [0, 0, 1]] return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype="double")
elif centring == "F": elif centring == "F":
return [[0, 1.0 / 2, 1.0 / 2], [1.0 / 2, 0, 1.0 / 2], [1.0 / 2, 1.0 / 2, 0]] return np.array(
[[0.0, 1.0 / 2, 1.0 / 2], [1.0 / 2, 0, 1.0 / 2], [1.0 / 2, 1.0 / 2, 0.0]],
dtype="double",
)
elif centring == "I": elif centring == "I":
return [ return np.array(
[-1.0 / 2, 1.0 / 2, 1.0 / 2], [
[1.0 / 2, -1.0 / 2, 1.0 / 2], [-1.0 / 2, 1.0 / 2, 1.0 / 2],
[1.0 / 2, 1.0 / 2, -1.0 / 2], [1.0 / 2, -1.0 / 2, 1.0 / 2],
] [1.0 / 2, 1.0 / 2, -1.0 / 2],
],
dtype="double",
)
elif centring == "A": elif centring == "A":
return [[1, 0, 0], [0, 1.0 / 2, -1.0 / 2], [0, 1.0 / 2, 1.0 / 2]] return np.array(
[[1.0, 0.0, 0.0], [0.0, 1.0 / 2, -1.0 / 2], [0.0, 1.0 / 2, 1.0 / 2]],
dtype="double",
)
elif centring == "C": elif centring == "C":
return [[1.0 / 2, 1.0 / 2, 0], [-1.0 / 2, 1.0 / 2, 0], [0, 0, 1]] return np.array(
[[1.0 / 2, 1.0 / 2, 0], [-1.0 / 2, 1.0 / 2, 0], [0.0, 0.0, 1.0]],
dtype="double",
)
elif centring == "R": elif centring == "R":
return [ return np.array(
[2.0 / 3, -1.0 / 3, -1.0 / 3], [
[1.0 / 3, 1.0 / 3, -2.0 / 3], [2.0 / 3, -1.0 / 3, -1.0 / 3],
[1.0 / 3, 1.0 / 3, 1.0 / 3], [1.0 / 3, 1.0 / 3, -2.0 / 3],
] [1.0 / 3, 1.0 / 3, 1.0 / 3],
],
dtype="double",
)
else: else:
return None return None
@ -1722,7 +1744,7 @@ def guess_primitive_matrix(unitcell: PhonopyAtoms, symprec: float = 1e-5):
return np.array(np.dot(np.linalg.inv(tmat), pmat), dtype="double", order="C") return np.array(np.dot(np.linalg.inv(tmat), pmat), dtype="double", order="C")
def shape_supercell_matrix(smat: Optional[Union[int, float, np.ndarray]]) -> np.ndarray: def shape_supercell_matrix(smat: Optional[Union[Sequence, np.ndarray]]) -> np.ndarray:
"""Reshape supercell matrix.""" """Reshape supercell matrix."""
if smat is None: if smat is None:
_smat = np.eye(3, dtype="intc", order="C") _smat = np.eye(3, dtype="intc", order="C")