From 9b23427cbf162433e51ca6c65364a8a42e27cda4 Mon Sep 17 00:00:00 2001 From: Atsushi Togo Date: Thu, 5 Sep 2024 16:38:13 +0900 Subject: [PATCH] Separate detail of Phono3py.develop_mlp() into a function in phonopy --- phono3py/api_phono3py.py | 51 +++++----------------------------------- 1 file changed, 6 insertions(+), 45 deletions(-) diff --git a/phono3py/api_phono3py.py b/phono3py/api_phono3py.py index b42d4a16..c1d8da07 100644 --- a/phono3py/api_phono3py.py +++ b/phono3py/api_phono3py.py @@ -58,6 +58,7 @@ from phonopy.interface.fc_calculator import get_fc2 from phonopy.interface.pypolymlp import ( PypolymlpData, PypolymlpParams, + develop_mlp_by_pypolymlp, develop_polymlp, evalulate_polymlp, load_polymlp, @@ -2204,52 +2205,12 @@ class Phono3py: if self._mlp_dataset is None: raise RuntimeError("MLP dataset is not set.") - if params is not None: - _params = parse_mlp_params(params) - else: - _params = params - - if ( - _params is not None - and _params.ntrain is not None - and _params.ntest is not None - ): - ntrain = _params.ntrain - ntest = _params.ntest - disps = self._mlp_dataset["displacements"] - forces = self._mlp_dataset["forces"] - energies = self._mlp_dataset["supercell_energies"] - train_data = PypolymlpData( - displacements=disps[:ntrain], - forces=forces[:ntrain], - supercell_energies=energies[:ntrain], - ) - test_data = PypolymlpData( - displacements=disps[-ntest:], - forces=forces[-ntest:], - supercell_energies=energies[-ntest:], - ) - else: - disps = self._mlp_dataset["displacements"] - forces = self._mlp_dataset["forces"] - energies = self._mlp_dataset["supercell_energies"] - n = int(len(disps) * (1 - test_size)) - train_data = PypolymlpData( - displacements=disps[:n], - forces=forces[:n], - supercell_energies=energies[:n], - ) - test_data = PypolymlpData( - displacements=disps[n:], - forces=forces[n:], - supercell_energies=energies[n:], - ) - self._mlp = develop_polymlp( + self._mlp = develop_mlp_by_pypolymlp( + self._mlp_dataset, self._supercell, - train_data, - test_data, - params=_params, - verbose=self._log_level - 1 > 0, + params=params, + test_size=test_size, + log_level=self._log_level, ) def load_mlp(self, filename: str = "phono3py.pmlp"):