113 lines
2.9 KiB
Python
113 lines
2.9 KiB
Python
"""Pydantic model for default configuration and validation."""
|
|
|
|
import subprocess
|
|
from typing import Optional, Union
|
|
import os
|
|
from pydantic import root_validator
|
|
|
|
# from pydantic import Field, root_validator, validator
|
|
from pydantic.typing import Literal
|
|
|
|
# from typing import List
|
|
from models.base import BaseSettings
|
|
from models.potnet import PotNetConfig
|
|
|
|
# Best-effort lookup of the current git commit hash; it is the default value
# of ``TrainingConfig.version``. Note that git runs in the process's current
# working directory, so this reflects whatever repository (if any) the
# process was started from.
try:
    VERSION = (
        subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            # Silence git's "fatal: not a git repository" message, which
            # would otherwise be printed on every import outside a repo.
            stderr=subprocess.DEVNULL,
        )
        .decode()
        .strip()
    )
except (OSError, subprocess.SubprocessError, ValueError):
    # OSError: git is not installed / not executable.
    # SubprocessError (CalledProcessError): not a repository or git failed.
    # ValueError (UnicodeDecodeError): undecodable output.
    # In every case fall back to a placeholder rather than failing the import.
    VERSION = "NA"
|
|
|
|
|
|
FEATURESET_SIZE = {"basic": 11, "atomic_number": 1, "cfid": 438, "cgcnn": 92}
|
|
|
|
|
|
# Allowed values for ``TrainingConfig.target`` (the property to regress).
# NOTE(review): several names contain spaces ("gap pbe", "bulk modulus",
# "shear modulus"); presumably they must match dataset column names verbatim,
# so do not normalize them without checking the data loaders.
TARGET_ENUM = Literal[
    # snake_case names, presumably dft_3d-style keys (confirm against loader)
    "formation_energy_peratom", "optb88vdw_bandgap",
    "bulk_modulus_kv", "shear_modulus_gv",
    "mbj_bandgap", "optb88vdw_total_energy", "ehull",
    # names used by another dataset (presumably megnet) -- confirm
    "gap pbe", "e_form", "e_hull",
    # generic names
    "formation_energy_per_atom", "band_gap",
    "bulk modulus", "shear modulus",
    "energy_per_atom", "target",
]
|
|
|
|
|
|
class TrainingConfig(BaseSettings):
    """Training config defaults and validation.

    Each field's default is used when the setting is not supplied. After all
    fields validate successfully, ``set_input_size`` copies the feature width
    for the chosen ``atom_features`` encoding (from ``FEATURESET_SIZE``) into
    ``model.atom_input_features``.
    """

    # Git commit of the code (or "NA"), resolved once at import time.
    version: str = VERSION

    # dataset configuration
    dataset: Literal[
        "dft_3d",
        "megnet",
    ] = "dft_3d"
    target: TARGET_ENUM = "formation_energy_peratom"
    atom_features: Literal["basic", "atomic_number", "cfid", "cgcnn"] = "cgcnn"
    id_tag: Literal["jid", "id"] = "jid"

    # logging configuration

    # training configuration
    random_seed: Optional[int] = 123
    n_val: Optional[int] = None
    n_test: Optional[int] = None
    n_train: Optional[int] = None
    train_ratio: Optional[float] = 0.8
    val_ratio: Optional[float] = 0.1
    test_ratio: Optional[float] = 0.1
    epochs: int = 500
    batch_size: int = 64
    weight_decay: float = 0.0
    learning_rate: float = 1e-3
    warmup_steps: int = 2000
    criterion: Literal["mse", "l1", "poisson"] = "mse"
    optimizer: Literal["adamw", "sgd"] = "adamw"
    scheduler: Literal["onecycle", "step", "none"] = "onecycle"
    pin_memory: bool = False
    write_checkpoint: bool = True
    write_predictions: bool = True
    store_outputs: bool = True
    progress: bool = True
    log_tensorboard: bool = False
    num_workers: int = 8
    normalize: bool = False
    euclidean: bool = False
    keep_data_order: bool = False
    cutoff: float = 8.0
    max_neighbors: int = 12
    # NOTE(review): these two are unannotated, so pydantic presumably infers
    # the field type from the default (a plain list, elements unchecked).
    # Left as-is because adding annotations would change field validation
    # and field ordering. The two lists appear to be paired element-wise
    # (same length) -- confirm with the consumer.
    infinite_funcs = ["zeta", "zeta", "exp"]
    infinite_params = [3.0, 0.5, 3.0]
    # NOTE(review): meaning of R is not evident from this file; confirm
    # with the code that consumes it.
    R: int = 5
    n_early_stopping: Optional[int] = None
    output_dir: str = ""
    process_dir: str = "processed"
    cache_dir: str = "cache"
    checkpoint_dir: str = "checkpoints"

    # model configuration
    model: Union[
        PotNetConfig,
    ] = PotNetConfig(name="potnet")

    @root_validator(skip_on_failure=True)
    def set_input_size(cls, values):
        """Automatically configure node feature dimensionality.

        Sets ``model.atom_input_features`` to the width of the selected
        ``atom_features`` encoding.

        ``skip_on_failure=True`` makes pydantic skip this validator when any
        field already failed validation. Otherwise ``values`` could lack
        ``"atom_features"`` or ``"model"``, and a bare ``KeyError`` raised
        here would hide the real ``ValidationError``.

        Args:
            values: Mapping of validated field values.

        Returns:
            The same ``values`` mapping; its ``"model"`` object is updated
            in place.
        """
        values["model"].atom_input_features = FEATURESET_SIZE[
            values["atom_features"]
        ]
        return values