# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Dict, List
from gradio import JSON, Blocks, Chatbot, ChatMessage, Dropdown
from gradio.components.chatbot import MetadataDict
from openai.types.responses.response_input_param import (
EasyInputMessageParam,
FunctionCallOutput,
ResponseFunctionToolCallParam,
ResponseInputItemParam,
ResponseReasoningItemParam,
)
from pydantic import ConfigDict, Field
from tqdm.auto import tqdm
from nemo_gym.base_resources_server import BaseVerifyResponse
from nemo_gym.config_types import BaseNeMoGymCLIConfig
from nemo_gym.server_utils import get_global_config_dict
from nemo_gym.train_data_utils import (
DatasetMetrics,
aggregate_other_metrics,
compute_sample_metrics,
postprocess_other_metrics,
)
[docs]
class DatasetViewerVerifyResponse(BaseVerifyResponse):
    """Verify-response record in the form the dataset viewer loads it from JSONL."""

    # extra="allow": fields that the base model does not declare are accepted
    # and kept on the instance instead of being rejected during validation.
    model_config = ConfigDict(extra="allow")
[docs]
def convert_single_message(m: ResponseInputItemParam) -> List[ChatMessage]:
if not m.get("type") and m.get("role"):
m["type"] = "message"
match m["type"]:
case "function_call_output": # "tool"
return format_function_call_output(m)
case "function_call": # "assistant tool call"
return format_function_call(m)
case "message": # "assistant chat"
return format_message(m)
case "reasoning": # "assistant reasoning"
return format_reasoning(m)
case _: # pragma: no cover
raise NotImplementedError(f"Unsupported message type: {m}")
[docs]
def rollout_to_messages(create_params: dict, response: dict) -> List[ChatMessage]:
    """Render one rollout (request parameters plus response output) as chat messages.

    The result starts with a "Sampling params" panel holding every request
    parameter except ``input`` and ``tools``, followed by a "Tools" panel when
    the request carries a non-empty ``tools`` entry. After that, every request
    input item and every response output item is converted with
    ``convert_single_message`` and its title is prefixed with the current
    ``Turn <n> Step <n> - `` position.

    Args:
        create_params: Request parameters; must contain ``"input"`` (a string
            or a list of item dicts) and may contain ``"tools"``.
        response: Response dict; its ``"output"`` list is appended to the input items.

    Returns:
        The chat messages in display order.
    """

    def json_panel(title: str, payload) -> ChatMessage:
        # Pretty-printed JSON inside a fenced ```json block, shown as an assistant message.
        body = json.dumps(payload, indent=4)
        return ChatMessage(
            content=f"```json\n{body}\n```",
            role="assistant",
            metadata=MetadataDict(title=title, status="done"),
        )

    # Everything except the conversation itself and the tool list counts as a sampling parameter.
    request_settings = create_params.copy()
    request_settings.pop("input")
    request_settings.pop("tools", None)
    rendered = [json_panel("Sampling params", request_settings)]

    tool_specs = create_params.get("tools")
    if tool_specs:
        rendered.append(json_panel("Tools", tool_specs))

    # A bare string input is shorthand for a single user message.
    prompt_items = create_params["input"]
    if isinstance(prompt_items, str):
        prompt_items = [{"role": "user", "content": prompt_items}]

    turn_no = 0
    step_no = 0
    for item in prompt_items + response["output"]:
        # A user message opens a new turn and resets the step counter.
        if item.get("role") == "user":
            turn_no, step_no = turn_no + 1, 0
        # Each function call opens a new step within the current turn.
        if item.get("type") == "function_call":
            step_no += 1

        for chat_message in convert_single_message(item):
            base_title = chat_message.metadata["title"]
            chat_message.metadata["title"] = f"Turn {turn_no} Step {step_no} - {base_title}"
            rendered.append(chat_message)

    return rendered
[docs]
class JsonlDatasetViewerConfig(BaseNeMoGymCLIConfig):
    """
    Launch a Gradio interface to view and explore dataset rollouts interactively.

    Examples:

    ```bash
    # Launch viewer with default settings (accessible from localhost only)
    ng_viewer +jsonl_fpath=weather_rollouts.jsonl

    # Accept requests from anywhere (e.g., for remote access)
    ng_viewer +jsonl_fpath=weather_rollouts.jsonl +server_host=0.0.0.0

    # Use a custom port
    ng_viewer +jsonl_fpath=weather_rollouts.jsonl +server_port=8080
    ```
    """

    # Required: the JSONL file whose records are shown.
    jsonl_fpath: str = Field(description="Filepath to a local jsonl file to view.")

    # Optional bind address; None is forwarded unchanged to the launcher.
    server_host: str | None = Field(
        description=(
            "Network address where the viewer accepts requests. "
            'Defaults to "127.0.0.1" (localhost only). '
            'Set to "0.0.0.0" to accept requests from anywhere.'
        ),
        default=None,
    )

    # Optional port; None is forwarded unchanged to the launcher.
    server_port: int | None = Field(
        description=(
            "Port where the viewer accepts requests. Defaults to 7860. "
            "If the specified port is unavailable, Gradio will search for the next available port."
        ),
        default=None,
    )
[docs]
# Compute aggregate metrics over the loaded samples and return them as a dict.
#
# Each sample is dumped to a dict and serialized to a JSON string, which is
# passed to compute_sample_metrics; that call yields (metrics, is_offending).
# Metrics of samples that are not offending are added to a DatasetMetrics
# accumulator, and the sample (parsed back from the JSON string) is folded into
# `other_metrics` via aggregate_other_metrics. After the loop,
# postprocess_other_metrics is called with both accumulators, the dataset
# metrics are aggregated, and the result is dumped with by_alias=True.
#
# NOTE(review): this view of the file has lost its indentation, so it cannot be
# determined here whether the json.loads / aggregate_other_metrics pair runs for
# every sample or only for non-offending ones -- confirm against the real file.
def get_aggregate_metrics(data: List[DatasetViewerVerifyResponse]) -> Dict[str, Any]:
dataset_metrics = DatasetMetrics()
other_metrics = {}
for line in data:
# Rebinds the loop variable: from here on `line` is the JSON string, not the model.
line = json.dumps(line.model_dump())
metrics, is_offending = compute_sample_metrics(line)
if not is_offending:
dataset_metrics.add(metrics)
# Round-trips through JSON, so sample_dict holds only JSON-native values.
sample_dict = json.loads(line)
aggregate_other_metrics(other_metrics, sample_dict)
postprocess_other_metrics(dataset_metrics, other_metrics)
aggregate_metrics = dataset_metrics.aggregate()
# by_alias=True: keys in the returned dict use the fields' aliases where defined.
aggregate_metrics_dict = aggregate_metrics.model_dump(by_alias=True)
return aggregate_metrics_dict
[docs]
def build_jsonl_dataset_viewer(config: JsonlDatasetViewerConfig) -> Blocks:
    """Build the Gradio app used to browse the rollouts stored in a JSONL file.

    Every non-blank line of ``config.jsonl_fpath`` is validated into a
    ``DatasetViewerVerifyResponse``. The returned app contains an
    "Aggregate Metrics" JSON panel, a "Samples" dropdown with one entry per
    record, and a "Rollout" chatbot that renders the selected record.

    Args:
        config: Viewer configuration; only ``jsonl_fpath`` is read here.

    Returns:
        The assembled ``Blocks`` app. It is not launched by this function.
    """
    # JSON Lines files are UTF-8; being explicit avoids depending on the
    # platform's locale encoding.
    with open(config.jsonl_fpath, encoding="utf-8") as f:
        data = [
            DatasetViewerVerifyResponse.model_validate_json(line)
            for line in tqdm(f, desc="Loading data")
            # Tolerate blank lines (e.g. extra trailing newlines) instead of
            # failing validation on them.
            if line.strip()
        ]

    # Dropdown entries are (label, value) pairs; the value is the index into `data`.
    choices = [(f"Sample {i + 1} - Responses ID {d.response.id}", i) for i, d in enumerate(data)]

    def select_item(value: int):
        """Return the chat messages for the sample at index ``value``."""
        d = data[value]
        return extra_info_to_messages(d.model_dump()) + rollout_to_messages(
            d.responses_create_params.model_dump(), d.response.model_dump()
        )

    # Hides the page's <footer> element.
    css = "\nfooter {\nvisibility: hidden;\n}\n"

    with Blocks(analytics_enabled=False, css=css) as demo:
        aggregate_dicts = get_aggregate_metrics(data)
        JSON(value=aggregate_dicts, label="Aggregate Metrics", open=False)

        # With no records there is no sample 0 to preselect or render, so the
        # dropdown and chatbot start empty instead of raising IndexError.
        item_dropdown = Dropdown(choices=choices, value=0 if data else None, label="Samples")
        chatbot = Chatbot(
            value=select_item(0) if data else [],
            type="messages",
            height="80vh",
            layout="panel",
            label="Rollout",
        )
        item_dropdown.select(fn=select_item, inputs=item_dropdown, outputs=chatbot, show_api=False)

    return demo
[docs]
def main():  # pragma: no cover
    """CLI entry point: validate the global config, build the viewer and launch it."""
    viewer_config = JsonlDatasetViewerConfig.model_validate(get_global_config_dict())
    app = build_jsonl_dataset_viewer(viewer_config)
    app.launch(
        server_name=viewer_config.server_host,
        server_port=viewer_config.server_port,
        enable_monitoring=False,
    )