diff --git a/phonopy/api_phonopy.py b/phonopy/api_phonopy.py index d086923b..668adb0b 100644 --- a/phonopy/api_phonopy.py +++ b/phonopy/api_phonopy.py @@ -69,10 +69,10 @@ from phonopy.interface.calculator import get_default_physical_units from phonopy.interface.fc_calculator import get_fc2 from phonopy.interface.phonopy_yaml import PhonopyYaml from phonopy.interface.pypolymlp import ( - PypolymlpData, PypolymlpParams, - develop_polymlp, + develop_mlp_by_pypolymlp, evalulate_polymlp, + load_polymlp, ) from phonopy.phonon.animation import write_animation from phonopy.phonon.band_structure import BandStructure, get_band_qpoints_by_seekpath @@ -1242,7 +1242,11 @@ class Phonopy: if self._primitive.masses is not None: self._set_dynamical_matrix() - def develop_mlp(self, params: Optional[Union[PypolymlpParams, dict]] = None): + def develop_mlp( + self, + params: Optional[Union[PypolymlpParams, dict, str]] = None, + test_size: float = 0.1, + ): """Develop MLP of pypolymlp. Parameters @@ -1250,36 +1254,29 @@ class Phonopy: params : PypolymlpParams or dict, optional Parameters for developing MLP. Default is None. When dict is given, PypolymlpParams instance is created from the dict. + test_size : float, optional + Training and test data are split by this ratio. test_size=0.1 + means the first 90% of the data is used for training and the rest + is used for test. Default is 0.1. 
""" if self._mlp_dataset is None: raise RuntimeError("MLP dataset is not set.") - if isinstance(params, dict): - _params = PypolymlpParams(**params) - else: - _params = params - - disps = self._mlp_dataset["displacements"] - forces = self._mlp_dataset["forces"] - energies = self._mlp_dataset["supercell_energies"] - n = int(len(disps) * 0.9) - train_data = PypolymlpData( - displacements=disps[:n], forces=forces[:n], supercell_energies=energies[:n] - ) - test_data = PypolymlpData( - displacements=disps[n:], forces=forces[n:], supercell_energies=energies[n:] - ) - self._mlp = develop_polymlp( + self._mlp = develop_mlp_by_pypolymlp( + self._mlp_dataset, self._supercell, - train_data, - test_data, - params=_params, - verbose=self._log_level - 1 > 0, + params=params, + test_size=test_size, + log_level=self._log_level, ) + def load_mlp(self, filename: str = "phonopy.pmlp"): + """Load machine learning potential of pypolymlp.""" + self._mlp = load_polymlp(filename=filename) + def evaluate_mlp(self): - """Evaluate the machine learning potential of pypolymlp. + """Evaluate machine learning potential of pypolymlp. This method calculates the supercell energies and forces from the MLP for the displacements in self._dataset of type 2. The results are stored diff --git a/phonopy/cui/load.py b/phonopy/cui/load.py index 4624c665..b597ec67 100644 --- a/phonopy/cui/load.py +++ b/phonopy/cui/load.py @@ -75,6 +75,8 @@ def load( is_symmetry: bool = True, symmetrize_fc: bool = True, is_compact_fc: bool = True, + use_pypolymlp: bool = False, + mlp_params: Optional[dict] = None, store_dense_svecs: bool = False, use_SNF_supercell: bool = False, symprec: float = 1e-5, @@ -201,6 +203,10 @@ def load( True: (primitive, supecell, 3, 3) False: (supercell, supecell, 3, 3) where 'supercell' and 'primitive' indicate number of atoms in these cells. Default is True. + use_pypolymlp : bool, optional + Use pypolymlp for generating force constants. Default is False. 
+ mlp_params : dict, optional + A set of parameters used by machine learning potentials. store_dense_svecs : bool, optional Dataset of shortest vectors between atoms in primitive cell and supercell is stored in the dense format when this is True. Default is @@ -317,6 +323,8 @@ def load( produce_fc=produce_fc, symmetrize_fc=symmetrize_fc, is_compact_fc=is_compact_fc, + use_pypolymlp=use_pypolymlp, + mlp_params=mlp_params, log_level=log_level, ) diff --git a/phonopy/cui/load_helper.py b/phonopy/cui/load_helper.py index fad4b3c0..e70a26a4 100644 --- a/phonopy/cui/load_helper.py +++ b/phonopy/cui/load_helper.py @@ -37,7 +37,8 @@ from __future__ import annotations import pathlib -from typing import Optional +from dataclasses import asdict +from typing import Optional, Union import numpy as np @@ -57,6 +58,10 @@ from phonopy.interface.calculator import ( get_force_constant_conversion_factor, read_crystal_structure, ) +from phonopy.interface.pypolymlp import ( + PypolymlpParams, + parse_mlp_params, +) from phonopy.structure.atoms import PhonopyAtoms from phonopy.structure.cells import get_primitive_matrix from phonopy.structure.dataset import forces_in_dataset @@ -175,8 +180,9 @@ def read_force_constants_from_hdf5( def set_dataset_and_force_constants( phonon: Phonopy, - dataset: dict, - fc: Optional[np.ndarray], # From phonopy_yaml + dataset: Optional[dict], + phonopy_yaml_filename: Optional[str] = None, + fc: Optional[np.ndarray] = None, # From phonopy_yaml force_constants_filename: Optional[str] = None, force_sets_filename: Optional[str] = None, fc_calculator: Optional[str] = None, @@ -185,23 +191,191 @@ def set_dataset_and_force_constants( symmetrize_fc: bool = True, is_compact_fc: bool = True, use_pypolymlp: bool = False, + mlp_params: Optional[dict] = None, + displacement_distance: Optional[float] = None, + number_of_snapshots: Optional[int] = None, + random_seed: Optional[int] = None, log_level: int = 0, ): """Set displacement-force dataset and force constants.""" + 
_set_dataset( + phonon, + dataset, + phonopy_yaml_filename, + force_sets_filename, + use_pypolymlp, + log_level, + ) + + _set_force_constants( + phonon, + phonopy_yaml_filename, + fc, + force_constants_filename, + is_compact_fc, + log_level, + ) + + if use_pypolymlp: + _run_pypolymlp_to_compute_forces( + phonon, + mlp_params, + displacement_distance=displacement_distance, + number_of_snapshots=number_of_snapshots, + random_seed=random_seed, + log_level=log_level, + ) + + if ( + phonon.force_constants is None + and produce_fc + and forces_in_dataset(phonon.dataset) + ): + _produce_force_constants( + phonon, + fc_calculator, + fc_calculator_options, + symmetrize_fc, + is_compact_fc, + log_level, + ) + + +def check_nac_params(nac_params: dict, unitcell: PhonopyAtoms, pmat: np.ndarray): + """Check number of Born effective charges.""" + borns = nac_params["born"] + if len(borns) != np.rint(len(unitcell) * np.linalg.det(pmat)).astype(int): + msg = "Number of Born effective charges is not consistent with the cell." + raise ValueError(msg) + + +def _set_dataset( + phonon: Phonopy, + dataset: Optional[dict], + phonopy_yaml_filename: Optional[str] = None, + force_sets_filename: Optional[str] = None, + use_pypolymlp: bool = False, + log_level: int = 0, +): natom = len(phonon.supercell) - - # dataset and fc are those obtained from phonopy_yaml unless None. 
- if dataset is not None: - if use_pypolymlp: - phonon.mlp_dataset = dataset - else: - phonon.dataset = dataset - if fc is not None: - phonon.force_constants = fc - - _fc = None _dataset = None - if force_constants_filename is not None: + _force_sets_filename = None + if forces_in_dataset(dataset): + _dataset = dataset + _force_sets_filename = phonopy_yaml_filename + elif force_sets_filename is not None: + _dataset = parse_FORCE_SETS(natom=natom, filename=force_sets_filename) + _force_sets_filename = force_sets_filename + elif pathlib.Path("FORCE_SETS").exists(): + _dataset = parse_FORCE_SETS(natom=natom) + _force_sets_filename = "FORCE_SETS" + else: + _dataset = dataset + + if log_level: + if forces_in_dataset(_dataset): + print(f'Displacement-force dataset was read from "{_force_sets_filename}".') + elif _dataset is not None: + print(f'Displacement dataset was read from "{_force_sets_filename}".') + + if use_pypolymlp: + phonon.mlp_dataset = _dataset + else: + phonon.dataset = _dataset + + +def _run_pypolymlp_to_compute_forces( + phonon: Phonopy, + mlp_params: Union[str, dict, PypolymlpParams], + displacement_distance: Optional[float] = None, + number_of_snapshots: Optional[int] = None, + random_seed: Optional[int] = None, + mlp_filename: str = "phonopy.pmlp", + log_level: int = 0, +): + """Run pypolymlp to compute forces.""" + if log_level: + print("-" * 29 + " pypolymlp start " + "-" * 30) + print("Pypolymlp is a generator of polynomial machine learning potentials.") + print("Please cite the paper: A. Seko, J. Appl. Phys. 
133, 011101 (2023).") + print("Pypolymlp is developed at https://github.com/sekocha/pypolymlp.") + if mlp_params: + print("Parameters:") + for k, v in asdict(parse_mlp_params(mlp_params)).items(): + if v is not None: + print(f" {k}: {v}") + + if forces_in_dataset(phonon.mlp_dataset): + if log_level: + print("Developing MLPs by pypolymlp...", flush=True) + phonon.develop_mlp(params=mlp_params) + phonon.mlp.save_mlp(filename=mlp_filename) + if log_level: + print(f'MLPs were written into "{mlp_filename}"', flush=True) + else: + if pathlib.Path(mlp_filename).exists(): + if log_level: + print(f'Load MLPs from "{mlp_filename}".') + phonon.load_mlp(mlp_filename) + else: + raise RuntimeError(f'"{mlp_filename}" is not found.') + + if log_level: + print("-" * 30 + " pypolymlp end " + "-" * 31, flush=True) + + if displacement_distance is None: + _displacement_distance = 0.001 + else: + _displacement_distance = displacement_distance + + if log_level: + if number_of_snapshots: + print("Generate random displacements") + print( + " Twice of number of snapshots will be generated " + "for plus-minus displacements." + ) + else: + print("Generate displacements") + print( + f" Displacement distance: {_displacement_distance:.5f}".rstrip("0").rstrip( + "." + ) + ) + phonon.generate_displacements( + distance=_displacement_distance, + is_plusminus=True, + number_of_snapshots=number_of_snapshots, + random_seed=random_seed, + ) + + if log_level: + print( + f"Evaluate forces in {len(phonon.displacements)} supercells " + "by pypolymlp", + flush=True, + ) + + if phonon.supercells_with_displacements is None: + raise RuntimeError("Displacements are not set. 
Run generate_displacements.") + + phonon.evaluate_mlp() + + +def _set_force_constants( + phonon: Phonopy, + phonopy_yaml_filename: Optional[str] = None, + fc: Optional[np.ndarray] = None, # From phonopy_yaml + force_constants_filename: Optional[str] = None, + is_compact_fc: bool = True, + log_level: int = 0, +): + _fc = None + _force_constants_filename = None + if fc is not None: + _fc = fc + _force_constants_filename = phonopy_yaml_filename + elif force_constants_filename is not None: _fc = _read_force_constants_file( phonon, force_constants_filename, @@ -209,10 +383,7 @@ def set_dataset_and_force_constants( log_level=log_level, ) _force_constants_filename = force_constants_filename - elif force_sets_filename is not None: - _dataset = parse_FORCE_SETS(natom=natom, filename=force_sets_filename) - _force_sets_filename = force_sets_filename - elif phonon.forces is None and phonon.force_constants is None: + elif phonon.force_constants is None: # unless provided these from phonopy_yaml. if pathlib.Path("FORCE_CONSTANTS").exists(): _fc = _read_force_constants_file( @@ -230,50 +401,15 @@ def set_dataset_and_force_constants( log_level=log_level, ) _force_constants_filename = "force_constants.hdf5" - elif pathlib.Path("FORCE_SETS").exists(): - _dataset = parse_FORCE_SETS(natom=natom) - _force_sets_filename = "FORCE_SETS" if _fc is not None: + if not is_compact_fc and _fc.shape[0] != _fc.shape[1]: + _fc = compact_fc_to_full_fc(phonon, _fc, log_level=log_level) + elif is_compact_fc and _fc.shape[0] == _fc.shape[1]: + _fc = full_fc_to_compact_fc(phonon, _fc, log_level=log_level) phonon.force_constants = _fc if log_level: - print('Force constants were read from "%s".' 
% _force_constants_filename) - - if phonon.force_constants is None: - # Overwrite dataset - if _dataset is not None: - if phonon.dataset is None: - is_overwritten = False - else: - is_overwritten = ( - "first_atoms" in phonon.dataset or "displacements" in phonon.dataset - ) - phonon.dataset = _dataset - if log_level: - print(f'Force sets were read from "{_force_sets_filename}".') - if is_overwritten: - print( - f"Displacements in dataset were overwritten by those in " - f'"{_force_sets_filename}".' - ) - - if produce_fc and forces_in_dataset(phonon.dataset): - _produce_force_constants( - phonon, - fc_calculator, - fc_calculator_options, - symmetrize_fc, - is_compact_fc, - log_level, - ) - - -def check_nac_params(nac_params: dict, unitcell: PhonopyAtoms, pmat: np.ndarray): - """Check number of Born effective charges.""" - borns = nac_params["born"] - if len(borns) != np.rint(len(unitcell) * np.linalg.det(pmat)).astype(int): - msg = "Number of Born effective charges is not consistent with the cell." 
- raise ValueError(msg) + print(f'Force constants were read from "{_force_constants_filename}".') def _read_force_constants_file( diff --git a/phonopy/cui/phonopy_argparse.py b/phonopy/cui/phonopy_argparse.py index d0aa855a..b14f0307 100644 --- a/phonopy/cui/phonopy_argparse.py +++ b/phonopy/cui/phonopy_argparse.py @@ -518,6 +518,15 @@ def get_parser(fc_symmetry=False, is_nac=False, load_phonopy_yaml=False): default=None, help="Same behavior as MP tag", ) + parser.add_argument( + "--mlp-params", + dest="mlp_params", + default=None, + help=( + "Parameters for machine learning potentials as comma separated " + "string with the style of key = values" + ), + ) parser.add_argument( "--moment", dest="is_moment", diff --git a/phonopy/cui/phonopy_script.py b/phonopy/cui/phonopy_script.py index a2e7259e..fdb11219 100644 --- a/phonopy/cui/phonopy_script.py +++ b/phonopy/cui/phonopy_script.py @@ -260,11 +260,19 @@ def files_exist( def _finalize_phonopy( - log_level, settings: Settings, confs, phonon, filename="phonopy.yaml" + log_level, settings: Settings, confs, phonon: Phonopy, filename="phonopy.yaml" ): """Finalize phonopy.""" units = get_default_physical_units(phonon.calculator) + if phonon.mlp_dataset is not None: + mlp_eval_filename = "phonopy_mlp_eval_dataset.yaml" + if log_level: + print( + f'Dataset generated using MMLPs was written in "{mlp_eval_filename}".' 
+ ) + phonon.save(mlp_eval_filename) + if settings.save_params: exists_fc_only = ( not forces_in_dataset(phonon.dataset) and phonon.force_constants is not None @@ -566,6 +574,41 @@ def _create_FORCE_SETS_from_settings( ) +def _produce_force_constants_load_phonopy_yaml( + phonon: Phonopy, + settings: Settings, + phpy_yaml: PhonopyYaml, + unitcell_filename: str, + log_level: int, +): + is_full_fc = settings.fc_spg_symmetry or settings.is_full_fc + (fc_calculator, fc_calculator_options) = _get_fc_calculator_params(settings) + + try: + set_dataset_and_force_constants( + phonon, + phpy_yaml.dataset, + phonopy_yaml_filename=unitcell_filename, + fc=phpy_yaml.force_constants, + fc_calculator=fc_calculator, + fc_calculator_options=fc_calculator_options, + produce_fc=True, + symmetrize_fc=False, + is_compact_fc=(not is_full_fc), + use_pypolymlp=settings.use_pypolymlp, + mlp_params=settings.mlp_params, + displacement_distance=settings.displacement_distance, + number_of_snapshots=settings.random_displacements, + random_seed=settings.random_seed, + log_level=log_level, + ) + except (RuntimeError, ValueError) as e: + print_error_message(str(e)) + if log_level: + print_error() + sys.exit(1) + + def _produce_force_constants( phonon: Phonopy, settings: Settings, @@ -601,7 +644,7 @@ def _produce_force_constants( force_sets = read_force_sets_from_phonopy_yaml(phpy_yaml) if log_level: if force_sets is None: - print('Force sets were not found in "%s".' % unitcell_filename) + print(f'Force sets were not found in "{unitcell_filename}".') else: print( 'Forces and displacements were read from "%s".' @@ -748,69 +791,9 @@ def _store_force_constants( p2s_map = phonon.primitive.p2s_map if load_phonopy_yaml: - is_full_fc = settings.fc_spg_symmetry or settings.is_full_fc - if phpy_yaml.force_constants is not None: - if log_level: - print('Force constants were read from "%s".' 
% unitcell_filename) - - fc = phpy_yaml.force_constants - - if fc.shape[1] != len(phonon.supercell): - error_text = ( - "Number of atoms in supercell is not consistent " - "with the matrix shape of\nforce constants read " - "from %s." % unitcell_filename - ) - print_error_message(error_text) - if log_level: - print_error() - sys.exit(1) - - # Compact fc is expanded to full fc when full fc is required. - if is_full_fc and fc.shape[0] != fc.shape[1]: - fc = compact_fc_to_full_fc(phonon, fc, log_level=log_level) - elif not is_full_fc and fc.shape[0] == fc.shape[1]: - fc = full_fc_to_compact_fc(phonon, fc, log_level=log_level) - - phonon.set_force_constants(fc, show_drift=(log_level > 0)) - - if settings.read_force_constants: - _read_force_constants_from_file( - settings, phonon, unitcell_filename, is_full_fc, log_level - ) - else: - if forces_in_dataset(phpy_yaml.dataset): - if log_level: - text = 'Force sets were read from "%s"' % unitcell_filename - if phpy_yaml.force_constants is not None: - text += " but not to be used." - else: - text += "." 
- print(text) - - if phpy_yaml.force_constants is None: - (fc_calculator, fc_calculator_options) = _get_fc_calculator_params( - settings - ) - - try: - set_dataset_and_force_constants( - phonon, - phpy_yaml.dataset, - None, - fc_calculator=fc_calculator, - fc_calculator_options=fc_calculator_options, - produce_fc=True, - symmetrize_fc=False, - is_compact_fc=(not is_full_fc), - use_pypolymlp=settings.use_pypolymlp, - log_level=log_level, - ) - except (RuntimeError, ValueError) as e: - print_error_message(str(e)) - if log_level: - print_error() - sys.exit(1) + _produce_force_constants_load_phonopy_yaml( + phonon, settings, phpy_yaml, unitcell_filename, log_level + ) else: _produce_force_constants( phonon, settings, phpy_yaml, unitcell_filename, log_level @@ -874,6 +857,64 @@ def _store_force_constants( return True +def _create_random_displacements_at_finite_temperature( + phonon: Phonopy, + settings: Settings, + confs: dict, + optional_structure_info: Optional[tuple], + log_level: int, +): + if ( + settings.random_displacements + and settings.random_displacement_temperature is not None + ): + if file_exists("phonopy_disp.yaml", log_level=log_level, is_any=True): + if log_level: + print( + '"phonopy_disp.yaml" is already existing in the current directory.' 
+ ) + print('Please rename it not to lose "phonopy_disp.yaml".') + print_error() + sys.exit(1) + + phonon.generate_displacements( + number_of_snapshots=settings.random_displacements, + random_seed=settings.random_seed, + temperature=settings.random_displacement_temperature, + cutoff_frequency=settings.cutoff_frequency, + ) + + if log_level: + rd_comm_points = phonon.random_displacements.qpoints + rd_integrated_modes = phonon.random_displacements.integrated_modes + rd_frequencies = phonon.random_displacements.frequencies + print( + "Sampled q-points for generating displacements " + "(number of integrated modes):" + ) + for q, integrated_modes, freqs in zip( + rd_comm_points, rd_integrated_modes, rd_frequencies + ): + print(f"{q} ({integrated_modes.sum()})") + if log_level > 1: + print(" ", " ".join([f"{f:.3f}" for f in freqs])) + if np.prod(rd_integrated_modes.shape) - rd_integrated_modes.sum() != 3: + msg_lines = [ + "*****************************************************************", + "* Tiny frequencies can induce unexpectedly large displacements. *", + "* Please check force constants symmetry, e.g., --sym-fc option. 
*", + "*****************************************************************", + ] + print("\n".join(msg_lines)) + if log_level < 2: + print('Phonon frequencies can be shown by "-v" option.') + print() + + _write_displacements_files_then_exit( + phonon, settings, confs, optional_structure_info, log_level + ) + + def store_nac_params( phonon, settings, @@ -2037,25 +2078,28 @@ def main(**argparse_control): ######################################################### # Create constant amplitude displacements and then exit # ######################################################### - if ( - settings.create_displacements or settings.random_displacements - ) and settings.random_displacement_temperature is None: - if settings.displacement_distance is None: - displacement_distance = get_default_displacement_distance(phonon.calculator) - else: - displacement_distance = settings.displacement_distance + if not settings.use_pypolymlp: + if ( + settings.create_displacements or settings.random_displacements + ) and settings.random_displacement_temperature is None: + if settings.displacement_distance is None: + displacement_distance = get_default_displacement_distance( + phonon.calculator + ) + else: + displacement_distance = settings.displacement_distance - phonon.generate_displacements( - distance=displacement_distance, - is_plusminus=settings.is_plusminus_displacement, - is_diagonal=settings.is_diagonal_displacement, - is_trigonal=settings.is_trigonal_displacement, - number_of_snapshots=settings.random_displacements, - random_seed=settings.random_seed, - ) - _write_displacements_files_then_exit( - phonon, settings, confs, cell_info["optional_structure_info"], log_level - ) + phonon.generate_displacements( + distance=displacement_distance, + is_plusminus=settings.is_plusminus_displacement, + is_diagonal=settings.is_diagonal_displacement, + is_trigonal=settings.is_trigonal_displacement, + number_of_snapshots=settings.random_displacements, + random_seed=settings.random_seed, + ) + 
_write_displacements_files_then_exit( + phonon, settings, confs, cell_info["optional_structure_info"], log_level + ) ################### # Force constants # @@ -2075,53 +2119,8 @@ def main(**argparse_control): ################################################################### # Create random displacements at finite temperature and then exit # ################################################################### - if ( - settings.random_displacements - and settings.random_displacement_temperature is not None - ): - if file_exists("phonopy_disp.yaml", log_level=log_level, is_any=True): - if log_level: - print( - '"phonopy_disp.yaml" is already existing in the current directory.' - ) - print('Please rename it not to lose "phonopy_disp.yaml".') - print_error() - sys.exit(1) - - phonon.generate_displacements( - number_of_snapshots=settings.random_displacements, - random_seed=settings.random_seed, - temperature=settings.random_displacement_temperature, - cutoff_frequency=settings.cutoff_frequency, - ) - - if log_level: - rd_comm_points = phonon.random_displacements.qpoints - rd_integrated_modes = phonon.random_displacements.integrated_modes - rd_frequencies = phonon.random_displacements.frequencies - print( - "Sampled q-points for generating displacements " - "(number of integrated modes):" - ) - for q, integrated_modes, freqs in zip( - rd_comm_points, rd_integrated_modes, rd_frequencies - ): - print(f"{q} ({integrated_modes.sum()})") - if log_level > 1: - print(" ", " ".join([f"{f:.3f}" for f in freqs])) - if np.prod(rd_integrated_modes.shape) - rd_integrated_modes.sum() != 3: - msg_lines = [ - "*****************************************************************", - "* Tiny frequencies can induce unexpectedly large displacements. *", - "* Please check force constants symmetry, e.g., --sym-fc option. 
*", - "*****************************************************************", - ] - print("\n".join(msg_lines)) - if log_level < 2: - print('Phonon frequencies can be shown by "-v" option.') - print() - - _write_displacements_files_then_exit( + if not settings.use_pypolymlp: + _create_random_displacements_at_finite_temperature( phonon, settings, confs, cell_info["optional_structure_info"], log_level ) diff --git a/phonopy/interface/pypolymlp.py b/phonopy/interface/pypolymlp.py index cc1af879..370fe1fd 100644 --- a/phonopy/interface/pypolymlp.py +++ b/phonopy/interface/pypolymlp.py @@ -119,7 +119,7 @@ def develop_polymlp( test_data: PypolymlpData, params: Optional[PypolymlpParams] = None, verbose: bool = False, -): +) -> Pypolymlp: # type: ignore """Develop polynomial MLPs of pypolymlp. Parameters @@ -287,3 +287,57 @@ def load_polymlp(filename: str) -> Pypolymlp: # type: ignore mlp = Pypolymlp() mlp.load_mlp(filename=filename) return mlp + + +def develop_mlp_by_pypolymlp( + mlp_dataset: dict, + supercell: PhonopyAtoms, + params: Optional[Union[PypolymlpParams, dict, str]] = None, + test_size: float = 0.1, + log_level: int = 0, +) -> Pypolymlp: # type: ignore + """Develop MLPs by pypolymlp.""" + if params is not None: + _params = parse_mlp_params(params) + else: + _params = params + + if _params is not None and _params.ntrain is not None and _params.ntest is not None: + ntrain = _params.ntrain + ntest = _params.ntest + disps = mlp_dataset["displacements"] + forces = mlp_dataset["forces"] + energies = mlp_dataset["supercell_energies"] + train_data = PypolymlpData( + displacements=disps[:ntrain], + forces=forces[:ntrain], + supercell_energies=energies[:ntrain], + ) + test_data = PypolymlpData( + displacements=disps[-ntest:], + forces=forces[-ntest:], + supercell_energies=energies[-ntest:], + ) + else: + disps = mlp_dataset["displacements"] + forces = mlp_dataset["forces"] + energies = mlp_dataset["supercell_energies"] + n = int(len(disps) * (1 - test_size)) + train_data = 
PypolymlpData( + displacements=disps[:n], + forces=forces[:n], + supercell_energies=energies[:n], + ) + test_data = PypolymlpData( + displacements=disps[n:], + forces=forces[n:], + supercell_energies=energies[n:], + ) + mlp = develop_polymlp( + supercell, + train_data, + test_data, + params=_params, + verbose=log_level - 1 > 0, + ) + return mlp