Source code for nemo_gym.train_data_utils

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
from abc import abstractmethod
from collections import Counter, defaultdict
from math import sqrt
from pathlib import Path
from shutil import copyfileobj
from typing import Any, Dict, List, Literal, Optional, Self, Tuple, Union

from devtools import pprint
from omegaconf import DictConfig
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from tdigest import TDigest
from tqdm.auto import tqdm

from nemo_gym.base_resources_server import BaseRunRequest
from nemo_gym.config_types import (
    AGENT_REF_KEY,
    AgentServerRef,
    BaseNeMoGymCLIConfig,
    DatasetConfig,
    DatasetType,
    DownloadJsonlDatasetGitlabConfig,
    DownloadJsonlDatasetHuggingFaceConfig,
    ServerInstanceConfig,
)
from nemo_gym.gitlab_utils import download_jsonl_dataset
from nemo_gym.global_config import (
    GlobalConfigDictParser,
    GlobalConfigDictParserConfig,
    get_global_config_dict,
)
from nemo_gym.hf_utils import (
    download_hf_dataset_as_jsonl,
)


class TrainDataProcessorConfig(BaseNeMoGymCLIConfig):
    """
    Prepare and validate training data, generating metrics and statistics for datasets.

    Examples:

    ```bash
    config_paths="resources_servers/example_multi_step/configs/example_multi_step.yaml,\\
    responses_api_models/openai_model/configs/openai_model.yaml"
    ng_prepare_data "+config_paths=[${config_paths}]" \
        +output_dirpath=data/example_multi_step \
        +mode=example_validation
    ```
    """

    output_dirpath: str = Field(description="Directory path where processed datasets and metrics will be saved.")
    mode: Union[Literal["train_preparation"], Literal["example_validation"]] = Field(
        description="Processing mode: 'train_preparation' prepares train/validation datasets for training, 'example_validation' validates example data for PR submission."
    )
    should_download: bool = Field(
        default=False,
        description="Whether to automatically download missing datasets from remote registries (default: False).",
    )
    data_source: Literal["gitlab", "huggingface"] = Field(
        default="huggingface",
        description="Where to download missing datasets from: 'gitlab' (NVIDIA internal) or 'huggingface' (external).",
    )

    @property
    def in_scope_dataset_types(self) -> List[DatasetType]:
        if self.mode == "train_preparation":
            return ["train", "validation"]
        elif self.mode == "example_validation":
            return ["example"]
        else:
            raise NotImplementedError

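# --- Illustrative usage sketch (not part of nemo_gym.train_data_utils): how the CLI overrides in the
# docstring above map onto TrainDataProcessorConfig. This assumes BaseNeMoGymCLIConfig adds no extra
# required fields; the output_dirpath value is a hypothetical example path. ---
def _example_train_data_processor_config() -> None:
    config = TrainDataProcessorConfig(
        output_dirpath="data/example_multi_step",  # from `+output_dirpath=...`
        mode="example_validation",  # from `+mode=...`
    )
    # `example_validation` only considers `example` datasets; `train_preparation` would
    # return ["train", "validation"] instead.
    assert config.in_scope_dataset_types == ["example"]
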
class Accumulator(BaseModel):
    is_aggregated: bool = Field(default=False, exclude=True)

    def add(self: Self, other: Self) -> None:
        assert not self.is_aggregated
        assert not other.is_aggregated
        self._add(other)

    @abstractmethod
    def _add(self: Self, other: Self) -> None:
        pass

    def aggregate(self: Self) -> Self:
        res = self._aggregate()
        res.is_aggregated = True
        return res

    @abstractmethod
    def _aggregate(self: Self) -> Self:
        pass

class AvgMinMax(Accumulator):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    total: int = Field(serialization_alias="Total # non-null values", default=0)
    average: float = Field(serialization_alias="Average", default=0)
    min: float = Field(serialization_alias="Min", default=float("inf"))
    max: float = Field(serialization_alias="Max", default=float("-inf"))
    median: float = Field(serialization_alias="Median", default=0)
    stddev: float = Field(serialization_alias="Standard deviation", default=0)

    # Internal state
    mean: float = Field(default=0, exclude=True)  # running value (before final average)
    M2: float = Field(default=0, exclude=True)  # sum of squared differences (for variance)
    tdigest: TDigest = Field(default_factory=TDigest, exclude=True)
    """
    T-Digest is used to estimate the Median without storing and sorting all values.
    The Median is essentially an approximation using the 50th percentile, which is very close to the true Median.
    """

    def observe(self, x: float) -> None:
        if x < self.min:
            self.min = x
        if x > self.max:
            self.max = x

        # Update running mean and variance
        self.total += 1
        delta = x - self.mean
        self.mean += delta / self.total
        self.M2 += delta * (x - self.mean)

        # Update quantile estimator (for median)
        self.tdigest.update(x)

    def _add(self: Self, other: Self) -> None:
        # Merge accumulators
        if other.total == 0:
            return
        if self.total == 0:
            self.total = other.total
            self.mean = other.mean
            self.M2 = other.M2
            self.min = other.min
            self.max = other.max
            self.tdigest = TDigest()
            self.tdigest = self.tdigest + other.tdigest
            return

        # Merge mean and variance
        n1, n2 = self.total, other.total
        delta = other.mean - self.mean
        n = n1 + n2
        self.mean = self.mean + delta * (n2 / n)
        self.M2 = self.M2 + other.M2 + (delta * delta) * (n1 * n2 / n)
        self.total = n

        if other.min < self.min:
            self.min = other.min
        if other.max > self.max:
            self.max = other.max

        # Merge t-digests for quantiles/median
        self.tdigest = self.tdigest + other.tdigest

    def _aggregate(self: Self) -> Self:
        def round_metric(x: float) -> float:
            if x >= 1 or x <= -1:
                return round(x, 2)
            return round(x, 3)

        n = self.total
        mean = self.mean if n > 0 else 0.0
        stddev = sqrt(self.M2 / (n - 1)) if n > 1 else 0.0
        med = float(self.tdigest.percentile(50)) if n > 0 and self.tdigest.n > 0 else 0.0

        params = {
            "total": self.total,
            "average": mean,
            "min": self.min if n > 0 else 0.0,
            "max": self.max if n > 0 else 0.0,
            "median": med,
            "stddev": stddev,
        }
        final_params = {k: round_metric(v) if isinstance(v, float) else v for k, v in params.items()}
        return AvgMinMax(**final_params)

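# --- Illustrative usage sketch (not part of nemo_gym.train_data_utils): the AvgMinMax lifecycle.
# Values are observed one at a time (Welford running mean/variance plus a T-Digest for the median),
# partial accumulators are merged with `add`, and `aggregate` produces the rounded summary. The
# numbers below are arbitrary. ---
def _example_avg_min_max() -> None:
    a, b = AvgMinMax(), AvgMinMax()
    for x in (1.0, 2.0, 3.0):
        a.observe(x)
    for x in (4.0, 5.0):
        b.observe(x)
    a.add(b)  # merge the two partial accumulators
    summary = a.aggregate()
    # Expected (approximately): total=5, average=3.0, min=1.0, max=5.0, median ~3.0
    print(summary.model_dump(by_alias=True))
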
class StringMetrics(BaseModel):
    unique_count: int
    total_count: int

class DatasetMetrics(Accumulator):
    model_config = ConfigDict(extra="allow")  # Allow any arbitrary fields

    number_of_examples: int = Field(serialization_alias="Number of examples", default=0)
    number_of_tools: AvgMinMax = Field(serialization_alias="Number of tools", default_factory=AvgMinMax)
    json_dumped_number_of_words: AvgMinMax = Field(
        serialization_alias="Json-dumped number of words (proxy for token count)",
        default_factory=AvgMinMax,
    )
    number_of_turns: AvgMinMax = Field(serialization_alias="Number of turns", default_factory=AvgMinMax)
    temperature: AvgMinMax = Field(serialization_alias="Temperature", default_factory=AvgMinMax)
    # TODO: Number of unique create params, Number of unique user messages, other sampling params, etc

    def _add(self: Self, other: Self) -> None:
        self.number_of_examples += other.number_of_examples
        self.number_of_tools.add(other.number_of_tools)
        self.json_dumped_number_of_words.add(other.json_dumped_number_of_words)
        self.number_of_turns.add(other.number_of_turns)
        self.temperature.add(other.temperature)

        # Merge extra fields safely
        if other.model_extra:
            for k, v in other.model_extra.items():
                if k in DatasetMetrics.model_fields.keys():
                    continue
                setattr(self, k, v)

    def _aggregate(self: Self) -> Self:
        extras = {}
        if self.model_extra:
            for k, v in self.model_extra.items():
                if k in DatasetMetrics.model_fields.keys():
                    continue
                extras[k] = v

        return DatasetMetrics(
            number_of_examples=self.number_of_examples,
            number_of_tools=self.number_of_tools.aggregate(),
            json_dumped_number_of_words=self.json_dumped_number_of_words.aggregate(),
            number_of_turns=self.number_of_turns.aggregate(),
            temperature=self.temperature.aggregate(),
            **extras,
        )

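# --- Illustrative usage sketch (not part of nemo_gym.train_data_utils): how per-sample
# DatasetMetrics roll up into one dataset-level summary. Each sample contributes a DatasetMetrics
# with number_of_examples=1 (see compute_sample_metrics below); `add` merges them field by field
# and `aggregate` finalizes the AvgMinMax accumulators once at the end. ---
def _example_dataset_metrics_rollup(per_sample_metrics: List[DatasetMetrics]) -> DatasetMetrics:
    dataset_metrics = DatasetMetrics()
    for m in per_sample_metrics:
        dataset_metrics.add(m)
    return dataset_metrics.aggregate()
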
def aggregate_other_metrics(metrics: Dict[str, Any], sample: Dict[str, Any]) -> None:
    """Combines misc items (those other than response/response create params) into current metrics"""
    for k, v in sample.items():
        if k in ("responses_create_params", "response"):
            continue
        values = v if isinstance(v, list) else [v]
        for item in values:
            if isinstance(item, bool):
                item = int(item)
            if isinstance(item, (int, float)):
                if k not in metrics:
                    metrics[k] = AvgMinMax()
                metrics[k].observe(item)
            elif isinstance(item, str):
                if k not in metrics:
                    metrics[k] = Counter()
                metrics[k][item] += 1

def postprocess_other_metrics(metrics: DatasetMetrics, other_metrics: Dict[str, Any]) -> None:
    """Aggregates metrics and merges current metrics (containing only AvgMinMax) with StringMetrics"""
    for k, v in other_metrics.items():
        if isinstance(v, AvgMinMax):
            setattr(metrics, k, v.aggregate())
        elif isinstance(v, Counter):
            setattr(metrics, k, StringMetrics(unique_count=len(v), total_count=sum(v.values())))

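# --- Illustrative usage sketch (not part of nemo_gym.train_data_utils): sample fields other than
# `responses_create_params`/`response` flowing through the two helpers above. The keys `difficulty`
# and `category` are made up for demonstration. ---
def _example_other_metrics() -> None:
    other_metrics: Dict[str, Any] = {}
    aggregate_other_metrics(other_metrics, {"difficulty": 0.7, "category": "math"})
    aggregate_other_metrics(other_metrics, {"difficulty": 0.9, "category": "math"})
    metrics = DatasetMetrics()
    postprocess_other_metrics(metrics, other_metrics)
    # Numeric fields end up as aggregated AvgMinMax extras, string fields as StringMetrics, e.g.
    # {'difficulty': AvgMinMax(total=2, ...), 'category': StringMetrics(unique_count=1, total_count=2)}
    print(metrics.model_extra)
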
def compute_sample_metrics(sample_dict_str: str) -> Tuple[DatasetMetrics, bool]:
    try:
        sample_dict = json.loads(sample_dict_str)
    except json.JSONDecodeError:
        return DatasetMetrics(), True

    try:
        sample = BaseRunRequest.model_validate(sample_dict)
    except ValidationError:
        return DatasetMetrics(), True

    responses_create_params = sample.responses_create_params
    responses_create_params = responses_create_params.model_dump(exclude_unset=True)
    inputs = responses_create_params.get("input")

    number_of_tools_metrics = AvgMinMax()
    if responses_create_params.get("tools") is not None:
        number_of_tools = len(responses_create_params["tools"])
        number_of_tools_metrics.observe(number_of_tools)

    if isinstance(inputs, str):
        inputs = [{"role": "user", "content": inputs}]
    user_inputs = [i for i in inputs if i.get("role") == "user"] if inputs else []
    number_of_turns_metrics = AvgMinMax()
    if user_inputs:
        number_of_turns = len(user_inputs)
        number_of_turns_metrics.observe(number_of_turns)

    temperature_metrics = AvgMinMax()
    if responses_create_params.get("temperature") is not None:
        temperature = responses_create_params["temperature"]
        temperature_metrics.observe(temperature)

    json_dumped_number_of_words_metrics = AvgMinMax()
    json_dumped_number_of_words = len(json.dumps(responses_create_params).split())
    json_dumped_number_of_words_metrics.observe(json_dumped_number_of_words)

    metrics = DatasetMetrics(
        number_of_examples=1,
        number_of_tools=number_of_tools_metrics,
        json_dumped_number_of_words=json_dumped_number_of_words_metrics,
        number_of_turns=number_of_turns_metrics,
        temperature=temperature_metrics,
    )
    return metrics, False

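# --- Illustrative usage sketch (not part of nemo_gym.train_data_utils): one JSONL line through
# compute_sample_metrics. The payload shape is an assumption: it presumes a sample containing only
# `responses_create_params` validates against BaseRunRequest; if it does not, the line is simply
# flagged as offending. ---
def _example_compute_sample_metrics() -> None:
    line = json.dumps(
        {
            "responses_create_params": {
                "input": [{"role": "user", "content": "What is 2 + 2?"}],
                "temperature": 0.6,
            }
        }
    )
    metrics, is_offending = compute_sample_metrics(line)
    # Offending lines (unparsable JSON or schema violations) are recorded by index and
    # excluded from aggregation by the caller.
    if not is_offending:
        print(metrics.number_of_turns.total, metrics.temperature.mean)
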
class DatasetValidatorState(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    metrics: DatasetMetrics = Field(default_factory=DatasetMetrics)
    key_counts: Counter = Field(default_factory=Counter)
    offending_example_idxs: List[int] = Field(default_factory=list)
    other_metrics: Dict[str, Any] = Field(default_factory=dict)

class TrainDataProcessor(BaseModel):
    def run(self, global_config_dict: DictConfig):  # pragma: no cover
        """
        See the README section "How To: Prepare and validate data for PR submission or RL training"
        """
        config = TrainDataProcessorConfig.model_validate(global_config_dict)

        self._print_title("Load and validate server instance configs")
        server_instance_configs = self.load_and_validate_server_instance_configs(config, global_config_dict)

        self._print_title(
            f"Load datasets. Missing datasets {'**WILL**' if config.should_download else 'will **NOT**'} be downloaded."
        )
        self.load_datasets(config, server_instance_configs)

        self._print_title("Validate samples and aggregate metrics")
        dataset_type_to_aggregate_metrics = self.validate_samples_and_aggregate_metrics(server_instance_configs)

        self._print_title("Collate samples and aggregate metrics")
        self.collate_samples(config, server_instance_configs, dataset_type_to_aggregate_metrics)

        self._print_title("Finished!")

    def _print_title(self, title: str) -> None:  # pragma: no cover
        print(f"""
{"#" * 100}
#
# {title}
#
{"#" * 100}
""")

    def load_and_validate_server_instance_configs(
        self, config: TrainDataProcessorConfig, global_config_dict: DictConfig
    ) -> List[ServerInstanceConfig]:
        parser = GlobalConfigDictParser()
        server_instance_configs = parser.filter_for_server_instance_configs(global_config_dict)
        agent_configs: List[ServerInstanceConfig] = [
            c for c in server_instance_configs if c.SERVER_TYPE == "responses_api_agents"
        ]

        server_names_list_str = "\n- ".join([""] + [f"{c.name} ({c.SERVER_TYPE})" for c in server_instance_configs])
        print(
            f"Found {len(server_instance_configs)} server instance configs ({len(agent_configs)} agent configs):{server_names_list_str}\n\n"
        )

        agent_configs_with_data: List[ServerInstanceConfig] = []
        agent_configs_without_data: List[ServerInstanceConfig] = []
        for agent_config in agent_configs:
            if agent_config.datasets:
                agent_configs_with_data.append(agent_config)
            else:
                agent_configs_without_data.append(agent_config)

        server_names_list_str = "\n- ".join([""] + [f"{c.name} ({c.SERVER_TYPE})" for c in agent_configs_without_data])
        print(
            f"Found {len(agent_configs_without_data)} agent server instance configs WITHOUT datasets:{server_names_list_str}\n\n"
        )

        server_names_list_str = ""
        for c in agent_configs_with_data:
            server_str = f"\n- {c.name}"
            datasets_str = "\n - ".join([""] + [f"{d.name} ({d.type})" for d in c.datasets])
            server_names_list_str += f"{server_str}{datasets_str}"
        print(
            f"Found {len(agent_configs_with_data)} agent server instance configs WITH datasets:{server_names_list_str}\n\n"
        )

        # Filter for in scope depending on the mode.
        in_scope_dataset_types = config.in_scope_dataset_types
        agent_configs_with_in_scope_datasets: List[ServerInstanceConfig] = []
        for agent_config in agent_configs_with_data:
            in_scope_datasets = [d for d in agent_config.datasets if d.type in in_scope_dataset_types]
            if not in_scope_datasets:
                continue
            inner_config = agent_config.get_inner_run_server_config()
            inner_config.datasets = in_scope_datasets
            agent_configs_with_in_scope_datasets.append(agent_config)

        server_names_list_str = ""
        for c in agent_configs_with_in_scope_datasets:
            server_str = f"\n- {c.name}"
            datasets_str = "\n - ".join([""] + [f"{d.name} ({d.type})" for d in c.datasets])
            server_names_list_str += f"{server_str}{datasets_str}"
        print(f"In scope dataset types for `{config.mode}` mode: {in_scope_dataset_types}")
        print(
            f"Found {len(agent_configs_with_data)} agent server instance configs with in-scope datasets:{server_names_list_str}"
        )

        return agent_configs_with_data

    def load_datasets(
        self,
        config: TrainDataProcessorConfig,
        server_instance_configs: List[ServerInstanceConfig],
    ) -> None:
        # Check if all the dataset paths exist. Mapping of server name to list of dataset config
        local_datasets_found: Dict[str, List[DatasetConfig]] = defaultdict(list)
        local_datasets_not_found: Dict[str, List[DatasetConfig]] = defaultdict(list)
        for c in server_instance_configs:
            for d in c.datasets:
                jsonl_fpath = Path(d.jsonl_fpath)
                if jsonl_fpath.exists():
                    local_datasets_found[c.name].append(d)
                else:
                    local_datasets_not_found[c.name].append(d)

        server_names_list_str = ""
        for server_name, datasets in local_datasets_found.items():
            datasets_str = "\n - ".join([""] + [f"{d.name} ({d.type})" for d in datasets])
            server_names_list_str += f"\n- {server_name}{datasets_str}"
        print(f"FOUND the following datasets at their local paths:{server_names_list_str}\n\n")

        server_names_list_str = ""
        for server_name, datasets in local_datasets_not_found.items():
            datasets_str = "\n - ".join([""] + [f"{d.name} ({d.type})" for d in datasets])
            server_names_list_str += f"\n- {server_name}{datasets_str}"
        print(f"MISSING the following datasets:{server_names_list_str}\n\n")

        if config.mode == "example_validation":
            assert not local_datasets_not_found, "You must provide the above missing example jsonl files!"

        if not config.should_download:
            assert not local_datasets_not_found, (
                "Missing local datasets. You must provide local datasets since download is disabled. Run with `+should_download=true` to enable downloading."
            )

        if not local_datasets_not_found:
            return

        backend = config.data_source
        is_valid, error_msg = validate_backend_credentials(backend)
        global_config = get_global_config_dict()
        if not is_valid:
            print(f"Cannot download datasets: {error_msg}")
            sys.exit(1)

        for (
            server_name,
            datasets,
        ) in local_datasets_not_found.items():  # pragma: no cover
            for d in datasets:
                try:
                    if backend == "gitlab":
                        if d.gitlab_identifier is None:
                            print(f"Dataset `{d.name}` missing gitlab_identifier for GitLab backend")
                            continue
                        download_config = DownloadJsonlDatasetGitlabConfig.model_validate(
                            d.gitlab_identifier.model_dump() | {"output_fpath": d.jsonl_fpath}
                        )
                        print(
                            f"Downloading dataset `{d.name}` for `{server_name}` from {backend} using {download_config}"
                        )
                        download_jsonl_dataset(download_config)
                    elif backend == "huggingface":
                        hf_identifier = d.huggingface_identifier
                        if hf_identifier is None:
                            print(f"Dataset `{d.name}` missing huggingface_identifier for HuggingFace backend")
                            continue
                        download_config = DownloadJsonlDatasetHuggingFaceConfig.model_validate(
                            {
                                "repo_id": hf_identifier.repo_id,
                                "artifact_fpath": hf_identifier.artifact_fpath,
                                "output_fpath": d.jsonl_fpath,
                                # Only pass split if artifact_fpath is not set
                                **({"split": d.type} if not hf_identifier.artifact_fpath else {}),
                                "hf_token": global_config.get("hf_token"),
                            }
                        )
                        print(f"Downloading '{d.type}' split from {hf_identifier.repo_id} to {d.jsonl_fpath}...")
                        download_hf_dataset_as_jsonl(download_config)
                except Exception as e:
                    print(f"Failed to download dataset `{d.name}` from {backend}: {e}")

    ########################################
    # Validate samples and aggregate metrics
    ########################################

    def _validate_samples_and_aggregate_metrics_single_sample(
        self, state: DatasetValidatorState, sample_idx: int, sample_dict_str: str
    ) -> None:
        metrics, is_offending = compute_sample_metrics(sample_dict_str)
        if is_offending:
            state.offending_example_idxs.append(sample_idx)
            return

        sample_dict = json.loads(sample_dict_str)
        state.key_counts.update(sample_dict.keys())
        state.metrics.add(metrics)
        aggregate_other_metrics(state.other_metrics, sample_dict)

    def _iter_dataset_lines(self, dataset_config: DatasetConfig):
        repeats = dataset_config.num_repeats

        # Print dataset repetition info
        if repeats > 1:
            print(
                f"Dataset {dataset_config.name} repeating {repeats}x: each line repeated {repeats} times (e.g. pattern: abc -> aaaabbbbcccc)"
            )

        # Don't load everything into memory at once. Throw things away immediately.
        with open(dataset_config.jsonl_fpath) as f:
            for line in tqdm(f, desc=f"{dataset_config.jsonl_fpath}"):
                for _ in range(repeats):
                    yield line

    def _validate_samples_and_aggregate_metrics_single_dataset(
        self, dataset_config: DatasetConfig
    ) -> DatasetValidatorState:
        state = DatasetValidatorState()
        map_fn = self._validate_samples_and_aggregate_metrics_single_sample
        for sample_idx, sample_dict_str in enumerate(self._iter_dataset_lines(dataset_config)):
            map_fn(state, sample_idx, sample_dict_str)

        postprocess_other_metrics(state.metrics, state.other_metrics)
        return state

    def _validate_aggregate_metrics(self, aggregate_metrics_dict: Dict, metrics_fpath: Path) -> Optional[Path]:
        """
        Returns the conflicting metrics fpath if invalid. Else returns None
        """
        if not metrics_fpath.exists():
            return

        with open(metrics_fpath) as f:
            previous_aggregate_metrics_dict = json.load(f)

        def numeric_close(a: float, b: float) -> bool:
            """Helper to compare numbers with a tolerance"""
            if a == b:
                return True
            try:
                a_f = float(a)
                b_f = float(b)
            except Exception:
                return False
            scale = max(abs(a_f), abs(b_f))  # Adjuster for tolerance
            # may need to adjust this threshold:
            tol = 5e-3 if scale >= 1 else 5e-4  # Higher threshold for larger numbers
            return abs(a_f - b_f) <= max(tol, 1e-9)  # Allow small differences

        def diff_values(prev_v, new_v, path: str, diffs: List[str]) -> None:
            """
            Recursively compare values at the given path.
            Keys from previous dict must be present in new dict. Additional fields in new dict are allowed.
            """
            if isinstance(prev_v, dict) and isinstance(new_v, dict):
                for k in prev_v.keys():
                    sub_path = f"{path}.{k}" if path else k
                    if k not in new_v:
                        diffs.append(f"Missing key in new metrics: {sub_path}")
                        continue
                    diff_values(prev_v[k], new_v[k], sub_path, diffs)
                return

            # Lists: Check for equality regardless of order
            if isinstance(prev_v, list) and isinstance(new_v, list):
                if len(prev_v) != len(new_v):
                    diffs.append(f"List length differs at {path}: {len(prev_v)} != {len(new_v)}")
                    return
                try:
                    prev_counter = Counter(prev_v)
                    new_counter = Counter(new_v)
                    if prev_counter != new_counter:
                        diffs.append(f"Multiset mismatch at {path}: {prev_counter} != {new_counter}")
                    return
                except TypeError:
                    # Manual fallback for unhashable elements
                    used = set()
                    for i, pv in enumerate(prev_v):
                        found = False
                        for j, nv in enumerate(new_v):
                            if j in used:
                                continue
                            sub_diffs = []
                            diff_values(pv, nv, f"{path}[{i}]", sub_diffs)
                            if not sub_diffs:
                                used.add(j)
                                found = True
                                break
                        if not found:
                            diffs.append(f"No matching element for {path}[{i}] in new metrics (unordered)")
                    return

            if isinstance(prev_v, float) and isinstance(new_v, float):
                if not numeric_close(prev_v, new_v):
                    diffs.append(f"Numeric mismatch at {path}: {prev_v} != {new_v}")
                return

            if prev_v != new_v:
                diffs.append(f"Value differs at {path}: {prev_v} != {new_v}")

        diffs: List[str] = []
        diff_values(previous_aggregate_metrics_dict, aggregate_metrics_dict, path="", diffs=diffs)

        if diffs:
            print("Differences found in aggregate metrics:")
            pprint(diffs)
            conflicting_metrics_fpath = metrics_fpath.with_name(f"{metrics_fpath.stem}_conflict.json")
            with open(conflicting_metrics_fpath, "w") as f:
                json.dump(aggregate_metrics_dict, f, indent=4)
            return conflicting_metrics_fpath

    def validate_samples_and_aggregate_metrics(
        self, server_instance_configs: List[ServerInstanceConfig]
    ) -> Dict[str, DatasetMetrics]:
        conflicting_fpaths: List[str] = []
        dataset_type_to_aggregate_metrics: Dict[str, DatasetMetrics] = defaultdict(DatasetMetrics)
        for c in server_instance_configs:
            for d in c.datasets:
                state = self._validate_samples_and_aggregate_metrics_single_dataset(d)
                dataset_type_to_aggregate_metrics[d.type].add(state.metrics)

                aggregate_metrics = state.metrics.aggregate()
                aggregate_metrics_dict = aggregate_metrics.model_dump(mode="json", by_alias=True)
                aggregate_metrics_dict = d.model_dump() | aggregate_metrics_dict

                data_fpath = Path(d.jsonl_fpath)
                metrics_fpath = data_fpath.with_name(f"{data_fpath.stem}_metrics.json")
                maybe_conflicting_metrics_fpath = self._validate_aggregate_metrics(
                    aggregate_metrics_dict, metrics_fpath
                )
                if maybe_conflicting_metrics_fpath is not None:
                    conflicting_fpaths.append(str(maybe_conflicting_metrics_fpath))
                    continue

                with open(metrics_fpath, "w") as f:
                    json.dump(aggregate_metrics_dict, f, indent=4)

                print(f"Aggregate metrics for {metrics_fpath}")
                pprint(aggregate_metrics_dict)

        if conflicting_fpaths:
            conflicting_fpaths_str = "\n- ".join([""] + conflicting_fpaths)
            target_fpaths_str = "\n- ".join(
                [""] + [fp.replace("_conflict.json", ".json") for fp in conflicting_fpaths]
            )
            raise ValueError(f"""
Found conflicting aggregate metrics that need to be corrected:{conflicting_fpaths_str}

This could be due to a change in how metrics are calculated, leading to outdated metrics. Try deleting the below file(s) and rerunning data preparation:{target_fpaths_str}
""")

        return dict(dataset_type_to_aggregate_metrics)

    ########################################
    # Collate samples
    ########################################

    def _collate_samples_single_type(
        self,
        type: DatasetType,
        server_instance_configs: List[ServerInstanceConfig],
    ) -> List[Path]:
        paths_to_collate = []
        for c in server_instance_configs:
            for d in c.datasets:
                if d.type != type:
                    continue
                data_path = Path(d.jsonl_fpath)
                prepare_path = data_path.with_name(f"{data_path.stem}_prepare.jsonl")
                with open(prepare_path, "w") as target:
                    for line in self._iter_dataset_lines(d):
                        d = json.loads(line)
                        d[AGENT_REF_KEY] = AgentServerRef(type="responses_api_agents", name=c.name).model_dump()
                        target.write(f"{json.dumps(d)}\n")
                paths_to_collate.append(prepare_path)
        return paths_to_collate

    def collate_samples(
        self,
        config: TrainDataProcessorConfig,
        server_instance_configs: List[ServerInstanceConfig],
        dataset_type_to_aggregate_metrics: Dict[str, DatasetMetrics],
    ) -> None:
        final_fpaths: Dict[DatasetType, Path] = dict()
        conflicting_fpaths: List[str] = []
        for type in config.in_scope_dataset_types:
            if type not in dataset_type_to_aggregate_metrics:
                continue
            aggregate_metrics = dataset_type_to_aggregate_metrics[type]
            aggregate_metrics = aggregate_metrics.aggregate()
            aggregate_metrics_dict = aggregate_metrics.model_dump(mode="json", by_alias=True)

            parent = Path(config.output_dirpath)
            parent.mkdir(exist_ok=True)

            metrics_fpath = parent / f"{type}_metrics.json"
            maybe_conflicting_metrics_fpath = self._validate_aggregate_metrics(
                aggregate_metrics_dict=aggregate_metrics_dict,
                metrics_fpath=metrics_fpath,
            )
            if maybe_conflicting_metrics_fpath is not None:
                conflicting_fpaths.append(str(maybe_conflicting_metrics_fpath))
                continue

            with open(metrics_fpath, "w") as f:
                json.dump(aggregate_metrics_dict, f, indent=4)

            paths_to_collate = self._collate_samples_single_type(
                type=type,
                server_instance_configs=server_instance_configs,
            )

            collated_fpath = parent / f"{type}.jsonl"
            with open(collated_fpath, "wb") as outfile:
                for path in tqdm(paths_to_collate, desc=f"Collating {type} datasets"):
                    with open(path, "rb") as infile:
                        copyfileobj(infile, outfile)

            print(f"Aggregate metrics for {metrics_fpath}")
            pprint(aggregate_metrics_dict)

            final_fpaths[type] = collated_fpath

        if conflicting_fpaths:
            conflicting_fpaths_str = "\n- ".join([""] + conflicting_fpaths)
            target_fpaths_str = "\n- ".join(
                [""] + [fp.replace("_conflict.json", ".json") for fp in conflicting_fpaths]
            )
            raise ValueError(f"""
Found conflicting aggregate metrics that need to be corrected:{conflicting_fpaths_str}

This could be due to a change in how metrics are calculated, leading to outdated metrics. Try deleting the below file(s) and rerunning data preparation:{target_fpaths_str}
""")

        final_fpaths_str = "\n- ".join([""] + [f"{type}: {fpath}" for type, fpath in final_fpaths.items()])
        print(f"View your final data!{final_fpaths_str}")

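# --- Illustrative usage sketch (not part of nemo_gym.train_data_utils): the metrics regression
# check in TrainDataProcessor._validate_aggregate_metrics in isolation. A previously written
# *_metrics.json acts as the baseline; small float drift is tolerated (roughly 5e-3 for values >= 1,
# 5e-4 below 1), anything larger is written to a *_conflict.json and its path returned. The paths
# and numbers here are hypothetical. ---
def _example_metrics_regression_check(tmp_dir: Path) -> None:
    baseline_fpath = tmp_dir / "train_metrics.json"
    baseline_fpath.write_text(json.dumps({"Number of examples": 100, "Temperature": {"Average": 0.7}}))
    processor = TrainDataProcessor()
    conflict = processor._validate_aggregate_metrics(
        {"Number of examples": 100, "Temperature": {"Average": 0.7003}}, baseline_fpath
    )
    assert conflict is None  # within tolerance; e.g. 0.8 would instead return a *_conflict.json path
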
def validate_backend_credentials(backend: str) -> tuple[bool, str]:
    """Check if required env variables are present for the chosen backend"""
    global_config = get_global_config_dict()

    if backend == "gitlab":
        required = ["mlflow_tracking_uri", "mlflow_tracking_token"]
        missing = [k for k in required if k not in global_config or not global_config[k]]
        if missing:
            return False, (
                f"GitLab backend selected but missing credentials: {missing}\n"
                f"Add to env.yaml:\n"
                f"  mlflow_tracking_uri: <your_gitlab_uri>\n"
                f"  mlflow_tracking_token: <your_gitlab_token>"
            )
    elif backend == "huggingface":
        required = ["hf_token"]
        missing = [k for k in required if k not in global_config or not global_config[k]]
        if missing:
            return False, (
                f"HuggingFace backend selected but missing credentials: {missing}\n"
                f"Add to env.yaml:\n"
                f"  hf_token: <your_hf_token>\n"
            )

    return True, ""

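# --- Illustrative usage sketch (not part of nemo_gym.train_data_utils): the credential check above.
# It reads the global config (typically populated from env.yaml), so this assumes the global config
# has already been initialized; the key names follow the `required` lists in
# validate_backend_credentials. ---
def _example_credential_check() -> None:
    ok, error_msg = validate_backend_credentials("huggingface")
    if not ok:
        # Printed when `hf_token` is missing; the "gitlab" backend checks
        # `mlflow_tracking_uri` and `mlflow_tracking_token` instead.
        print(error_msg)
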
def prepare_data():  # pragma: no cover
    global_config_dict = get_global_config_dict(
        global_config_dict_parser_config=GlobalConfigDictParserConfig(
            initial_global_config_dict=GlobalConfigDictParserConfig.NO_MODEL_GLOBAL_CONFIG_DICT,
        )
    )

    data_processor = TrainDataProcessor()
    data_processor.run(global_config_dict)