Source code for nemo_gym.gitlab_utils

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from os import environ
from pathlib import Path

import requests
from mlflow import MlflowClient
from mlflow.artifacts import get_artifact_repository
from mlflow.environment_variables import MLFLOW_TRACKING_TOKEN
from mlflow.exceptions import RestException
from pydantic import BaseModel

from nemo_gym.config_types import (
    DownloadJsonlDatasetGitlabConfig,
    UploadJsonlDatasetGitlabConfig,
)
from nemo_gym.server_utils import get_global_config_dict


class MLFlowConfig(BaseModel):
    """Settings required to build an MLflow client.

    Instances are validated from the global config dict by
    ``create_mlflow_client``.
    """

    # Passed as ``tracking_uri`` when the ``MlflowClient`` is constructed.
    mlflow_tracking_uri: str
    # Exported as the ``MLFLOW_TRACKING_TOKEN`` environment variable.
    mlflow_tracking_token: str
def create_mlflow_client() -> MlflowClient:  # pragma: no cover
    """Build an ``MlflowClient`` from the global configuration.

    Side effect: the configured token is written to the
    ``MLFLOW_TRACKING_TOKEN`` environment variable of the current process.

    Returns:
        A client pointed at the configured ``mlflow_tracking_uri``.
    """
    settings = MLFlowConfig.model_validate(get_global_config_dict())

    # The token is handed over through the process environment rather than
    # as a constructor argument, so it is exported before the client is built.
    environ["MLFLOW_TRACKING_TOKEN"] = settings.mlflow_tracking_token

    return MlflowClient(tracking_uri=settings.mlflow_tracking_uri)
def upload_jsonl_dataset(
    config: UploadJsonlDatasetGitlabConfig,
) -> None:  # pragma: no cover
    """Upload a local JSONL file as an artifact of a registered model version.

    The registered model named ``config.dataset_name`` and its version
    ``config.version`` are created when they cannot be fetched. The file at
    ``config.input_jsonl_fpath`` is then logged as an artifact of the run
    attached to that model version, and the matching download command is
    printed to stdout.

    Args:
        config: Dataset name, version and path of the JSONL file to upload.

    Raises:
        RestException: If creating the model version fails. Errors from
            creating the registered model itself are deliberately ignored.
    """
    client = create_mlflow_client()

    try:
        client.create_registered_model(config.dataset_name)
    except RestException:
        # Best effort: presumably the model already exists. Any real problem
        # (auth, connectivity) resurfaces on the registry calls below.
        pass

    try:
        model_version = client.get_model_version(config.dataset_name, config.version)
    except RestException:
        # The version could not be fetched, so create it with the version tag.
        model_version = client.create_model_version(
            config.dataset_name,
            config.version,
            tags={"gitlab.version": config.version},
        )

    # An empty artifact_path places the file at the root of the run's artifacts.
    client.log_artifact(model_version.run_id, config.input_jsonl_fpath, artifact_path="")

    filename = Path(config.input_jsonl_fpath).name
    print(f"""Download this artifact:
ng_download_dataset_from_gitlab \\
    +dataset_name={config.dataset_name} \\
    +version={config.version} \\
    +artifact_fpath={filename} \\
    +output_fpath={config.input_jsonl_fpath}
""")
def upload_jsonl_dataset_cli() -> None:  # pragma: no cover
    """Command-line entry point for ``upload_jsonl_dataset``.

    Validates the global config dict as an ``UploadJsonlDatasetGitlabConfig``
    and forwards it to ``upload_jsonl_dataset``.
    """
    upload_config = UploadJsonlDatasetGitlabConfig.model_validate(get_global_config_dict())
    upload_jsonl_dataset(upload_config)
def download_jsonl_dataset(
    config: DownloadJsonlDatasetGitlabConfig,
) -> None:  # pragma: no cover
    """Download one artifact of a registered model version to a local file.

    The artifact URI of the run attached to
    ``config.dataset_name``/``config.version`` is resolved, joined with
    ``config.artifact_fpath`` and fetched over HTTP with the MLflow tracking
    token as a bearer token. The payload is written unchanged to
    ``config.output_fpath``.

    Args:
        config: Dataset name, version, artifact path and output file path.

    Raises:
        RestException: If the model version cannot be fetched.
        requests.HTTPError: If the artifact request returns an error status.
    """
    # TODO: There is probably a much better way to do this, but it is not clear at the moment.
    client = create_mlflow_client()
    model_version = client.get_model_version(config.dataset_name, config.version)

    repo = get_artifact_repository(
        artifact_uri=f"runs:/{model_version.run_id}",
        tracking_uri=client.tracking_uri,
    )
    # The ``runs:/`` repository wraps the concrete one; its URI is the download base.
    artifact_uri = repo.repo.artifact_uri
    # Normalise the slashes so the join yields exactly one separator.
    download_link = f"{artifact_uri.rstrip('/')}/{config.artifact_fpath.lstrip('/')}"

    response = requests.get(
        download_link,
        headers={"Authorization": f"Bearer {MLFLOW_TRACKING_TOKEN.get()}"},
    )
    # Fail loudly on 4xx/5xx instead of saving an error page as the dataset.
    response.raise_for_status()

    # Write raw bytes so the file is stored exactly as served, independent of
    # the platform's default text encoding and newline translation.
    with open(config.output_fpath, "wb") as f:
        f.write(response.content)
def download_jsonl_dataset_cli() -> None:  # pragma: no cover
    """Command-line entry point for ``download_jsonl_dataset``.

    Validates the global config dict as a ``DownloadJsonlDatasetGitlabConfig``
    and forwards it to ``download_jsonl_dataset``.
    """
    download_config = DownloadJsonlDatasetGitlabConfig.model_validate(get_global_config_dict())
    download_jsonl_dataset(download_config)
def is_model_in_gitlab(model_name: str) -> bool:  # pragma: no cover
    """Report whether a registered model with the given name can be fetched.

    Args:
        model_name: Name of the registered model to look up.

    Returns:
        ``True`` when the lookup succeeds. ``False`` when it raises
        ``RestException``, in which case a message is printed to stdout.
    """
    registry = create_mlflow_client()

    # model_name in gitlab is case sensitive
    try:
        registry.get_registered_model(model_name)
    except RestException as err:
        print(f"[Nemo-Gym] - Model '{model_name}' not found in Gitlab: {err}")
        found = False
    else:
        found = True
    return found
def delete_model_from_gitlab(model_name: str) -> None:  # pragma: no cover
    """Delete the registered model with the given name.

    Args:
        model_name: Name of the registered model to delete.
    """
    create_mlflow_client().delete_registered_model(model_name)