Feature: Option to set the tracking URI for MLflowCallback. (#29032)

* Added option to set tracking URI for MLflowCallback.

* Added option to set tracking URI for MLflowCallback.

* Changed  to  in docstring.
This commit is contained in:
Sean (Seok-Won) Yi 2024-02-16 23:47:18 +09:00 committed by GitHub
parent be42c24d14
commit 161fe425c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 1 deletions

View File

@ -959,6 +959,9 @@ class MLflowCallback(TrainerCallback):
remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy each saved checkpoint on each save in
[`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
storage will just copy the files to your artifact location.
- **MLFLOW_TRACKING_URI** (`str`, *optional*, defaults to `""`):
Whether to store runs at a specific path or remote server. Default to an empty string which will store runs
at `./mlruns` locally.
- **MLFLOW_EXPERIMENT_NAME** (`str`, *optional*, defaults to `None`):
Whether to use an MLflow experiment_name under which to launch the run. Default to `None` which will point
to the `Default` experiment in MLflow. Otherwise, it is a case sensitive name of the experiment to be
@ -978,14 +981,22 @@ class MLflowCallback(TrainerCallback):
"""
self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "")
self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._run_id = os.getenv("MLFLOW_RUN_ID", None)
logger.debug(
f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run},"
f" tags={self._nested_run}"
f" tags={self._nested_run}, tracking_uri={self._tracking_uri}"
)
if state.is_world_process_zero:
self._ml_flow.set_tracking_uri(self._tracking_uri)
if self._tracking_uri == "":
logger.debug(f"MLflow tracking URI is not set. Runs will be stored at {os.path.realpath('./mlruns')}")
else:
logger.debug(f"MLflow tracking URI is set to {self._tracking_uri}")
if self._ml_flow.active_run() is None or self._nested_run or self._run_id:
if self._experiment_name:
# Use of set_experiment() ensure that Experiment is created if not exists