# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from asyncio import sleep
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Required,
TypeAlias,
Union,
)
from openai.types.chat import (
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartTextParam,
ChatCompletionDeveloperMessageParam,
ChatCompletionMessage,
ChatCompletionMessageToolCall,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_assistant_message_param import (
ContentArrayOfContentPart,
)
from openai.types.chat.completion_create_params import (
ChatCompletionAudioParam,
ChatCompletionPredictionContentParam,
ChatCompletionStreamOptionsParam,
ChatCompletionToolChoiceOptionParam,
ReasoningEffort,
ResponseFormat,
WebSearchOptions,
)
from openai.types.responses import (
FunctionToolParam,
Response,
ResponseInputTextParam,
)
from openai.types.responses.response_create_params import (
Metadata,
Reasoning,
ResponseIncludable,
ResponsePromptParam,
ResponsesModel,
ResponseTextConfigParam,
ToolChoice,
ToolParam,
)
from openai.types.responses.response_input_param import (
ResponseInputMessageContentListParam,
)
from openai.types.responses.response_output_text_param import Annotation, Logprob
from openai.types.responses.response_reasoning_item import (
Summary,
)
from openai.types.shared.chat_model import ChatModel
from openai.types.shared_params import FunctionDefinition
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import TypedDict
from nemo_gym.server_utils import MAX_NUM_TRIES, ClientResponse, raise_for_status, request
########################################
# Training-specific
########################################
[docs]
class TokenIDLogProbMixin(BaseModel):
    """Pydantic mixin carrying token-level data for training.

    Mixed into the ``*ForTraining`` model variants so that they additionally
    require prompt token IDs, generation token IDs and generation log-probs.
    """

    prompt_token_ids: list[int]
    generation_token_ids: list[int]
    # NOTE(review): presumably one entry per generation token ID; not enforced here.
    generation_log_probs: list[float]
[docs]
class TokenIDLogProbTypedDictMixin(TypedDict):
    """TypedDict counterpart of the token-ID / log-prob training mixin.

    Declares the same three keys as the Pydantic mixin, for message types that
    are TypedDicts rather than Pydantic models.
    """

    prompt_token_ids: list[int]
    generation_token_ids: list[int]
    generation_log_probs: list[float]
########################################
# Responses API inputs
########################################
[docs]
class NeMoGymSummary(Summary):
    """NeMo Gym subclass of the OpenAI reasoning-item ``Summary``; adds no fields."""
[docs]
class NeMoGymResponseReasoningItem(BaseModel):
    """Reasoning item of a Responses API conversation, declared as a Pydantic model."""

    id: str
    # Declared as a concrete list (instead of an Iterable) so Pydantic validation
    # does not produce lazy iterators.
    summary: list[NeMoGymSummary]
    type: Literal["reasoning"] = "reasoning"
    encrypted_content: str | None = None

    # As of Wed Sep 17, 2025, the OpenAI API with GPT-5 returns None for this status rather than a valid value here.
    # On subsequent calls to the OpenAI endpoints within a rollout, the status parameter is not accepted i.e. the OpenAI API returns a bad request when the status parameter is populated.
    # It's not clear whether or not this is intended. We comment out this status parameter here as a quick stop-gap to fix this issue in Gym re-queries.
    # status: Optional[Literal["in_progress", "completed", "incomplete"]] = None
[docs]
class NeMoGymResponseOutputText(BaseModel):
    """Text content part of an assistant output message."""

    # Declared as a concrete list (instead of an Iterable) so Pydantic validation
    # does not produce lazy iterators.
    annotations: list[Annotation]
    text: str
    type: Literal["output_text"] = "output_text"
    logprobs: list[Logprob] | None = None
[docs]
class NeMoGymResponseOutputRefusal(BaseModel):
    """Refusal content part of an assistant output message."""

    refusal: str
    type: Literal["refusal"] = "refusal"
# Content parts accepted in NeMoGymResponseOutputMessage.content: output text or a refusal.
NeMoGymContent: TypeAlias = Union[NeMoGymResponseOutputText, NeMoGymResponseOutputRefusal]
[docs]
class NeMoGymResponseOutputMessage(BaseModel):
    """Assistant message item of a Responses API conversation.

    ``role``, ``status`` and ``type`` default to ``"assistant"``, ``"completed"``
    and ``"message"`` respectively.
    """

    id: str
    # Declared as a concrete list (instead of an Iterable) so Pydantic validation
    # does not produce lazy iterators.
    content: list[NeMoGymContent]
    role: Literal["assistant"] = "assistant"
    status: Literal["in_progress", "completed", "incomplete"] = "completed"
    type: Literal["message"] = "message"
[docs]
class NeMoGymMessage(BaseModel):
    """Input message item with a ``user``, ``system`` or ``developer`` role.

    ``status`` defaults to ``"completed"`` and ``type`` to ``"message"``.
    """

    content: ResponseInputMessageContentListParam
    role: Literal["user", "system", "developer"]
    status: Literal["in_progress", "completed", "incomplete"] = "completed"
    type: Literal["message"] = "message"
[docs]
class NeMoGymFunctionCallOutput(BaseModel):
    """Result of a function call, fed back as a conversation item.

    Copy of openai.types.responses.response_input_param.FunctionCallOutput.
    The upstream type is a TypedDict; it is re-declared as a BaseModel so it can
    be a member of NeMoGymResponseOutputItem alongside the other item models.
    """

    call_id: str
    output: str
    type: Literal["function_call_output"] = "function_call_output"
    id: str | None = None
    status: Literal["in_progress", "completed", "incomplete"] | None = None
[docs]
class NeMoGymResponseInputText(ResponseInputTextParam):
    """NeMo Gym subclass of the OpenAI ``ResponseInputTextParam``; adds no keys."""
[docs]
class NeMoGymMessageForTraining(NeMoGymMessage, TokenIDLogProbMixin):
    """``NeMoGymMessage`` extended with the token-ID / log-prob training fields."""
[docs]
class NeMoGymResponseOutputMessageForTraining(NeMoGymResponseOutputMessage, TokenIDLogProbMixin):
    """``NeMoGymResponseOutputMessage`` extended with the token-ID / log-prob training fields."""
[docs]
class NeMoGymResponseReasoningItemForTraining(NeMoGymResponseReasoningItem, TokenIDLogProbMixin):
    """``NeMoGymResponseReasoningItem`` extended with the token-ID / log-prob training fields."""
# Lookup from each plain Responses item class to its ``*ForTraining`` counterpart.
# NeMoGymFunctionCallOutput has no entry here.
RESPONSES_TO_TRAIN = {
NeMoGymEasyInputMessage: NeMoGymEasyInputMessageForTraining,
NeMoGymMessage: NeMoGymMessageForTraining,
NeMoGymResponseOutputMessage: NeMoGymResponseOutputMessageForTraining,
NeMoGymResponseFunctionToolCall: NeMoGymResponseFunctionToolCallForTraining,
NeMoGymResponseReasoningItem: NeMoGymResponseReasoningItemForTraining,
}
# Every item type accepted in a Responses API input list: the plain item models
# first, followed by their training variants.
# NOTE(review): member order may influence how Pydantic resolves ambiguous
# payloads; keep the order unless that is verified.
NeMoGymResponseInputItem = Union[
NeMoGymEasyInputMessage,
NeMoGymMessage,
NeMoGymResponseOutputMessage,
NeMoGymResponseFunctionToolCall,
NeMoGymFunctionCallOutput,
NeMoGymResponseReasoningItem,
# For training:
NeMoGymEasyInputMessageForTraining,
NeMoGymMessageForTraining,
NeMoGymResponseOutputMessageForTraining,
NeMoGymResponseFunctionToolCallForTraining,
NeMoGymResponseReasoningItemForTraining,
]
# A full Responses API input: an ordered list of the items above.
NeMoGymResponseInput: TypeAlias = List[NeMoGymResponseInputItem]
[docs]
class NeMoGymResponseCreateParamsNonStreaming(BaseModel):
    """Strictly validated request body for a non-streaming Responses API call.

    This is a copy of
    openai.types.responses.response_create_params.ResponseCreateParamsNonStreaming.
    The upstream type is a TypedDict with no strict validation; it is
    re-declared as a Pydantic model because validation has to happen on the
    server side. Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    background: bool | None = None
    include: list[ResponseIncludable] | None = None
    # The only field without a default.
    input: str | NeMoGymResponseInput
    instructions: str | None = None
    max_output_tokens: int | None = None
    max_tool_calls: int | None = None
    metadata: Metadata | None = None
    model: ResponsesModel | None = None
    parallel_tool_calls: bool = True  # OpenAI default
    previous_response_id: str | None = None
    prompt: ResponsePromptParam | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    store: bool | None = None
    temperature: float | None = None
    text: ResponseTextConfigParam | None = None
    tool_choice: ToolChoice = "auto"  # OpenAI default
    # Declared as a concrete list (instead of an Iterable) so Pydantic validation
    # does not produce lazy iterators.
    tools: list[ToolParam] = Field(default_factory=list)
    top_logprobs: int | None = None
    top_p: float | None = None
    truncation: Literal["auto", "disabled"] | None = None
    user: str | None = None
    # Only ``False`` or ``None`` validate: streaming is not representable here.
    stream: Literal[False] | None = None
########################################
# Responses API outputs
########################################
# Output items reuse the input-item union unchanged; both names refer to the same type.
NeMoGymResponseOutputItem = NeMoGymResponseInputItem
[docs]
class NeMoGymResponse(Response):
    """OpenAI ``Response`` whose ``output`` is re-typed to a list of NeMo Gym output items."""

    output: list[NeMoGymResponseOutputItem]
########################################
# Chat Completion API outputs
########################################
[docs]
class NeMoGymFunction(BaseModel):
    """Name and argument string of a function call."""

    # NOTE(review): presumably a JSON-encoded string, as in the OpenAI API; not validated here.
    arguments: str
    name: str
[docs]
class NeMoGymChatCompletionMessage(ChatCompletionMessage):
    """OpenAI ``ChatCompletionMessage`` whose ``tool_calls`` is re-typed to the NeMo Gym tool-call model."""

    tool_calls: list[NeMoGymChatCompletionMessageToolCall] | None = None
[docs]
class NeMoGymChatCompletionMessageForTraining(NeMoGymChatCompletionMessage, TokenIDLogProbMixin):
    """``NeMoGymChatCompletionMessage`` extended with the token-ID / log-prob training fields."""
[docs]
class NeMoGymChoice(Choice):
    """OpenAI ``Choice`` whose ``message`` may be the plain or the training message model."""

    # The plain model is listed first, exactly as in the original union.
    message: NeMoGymChatCompletionMessage | NeMoGymChatCompletionMessageForTraining
[docs]
class NeMoGymChatCompletion(ChatCompletion):
    """OpenAI ``ChatCompletion`` whose ``choices`` is re-typed to a list of ``NeMoGymChoice``."""

    choices: list[NeMoGymChoice]
########################################
# Chat Completion API inputs
########################################
[docs]
class NeMoGymFunctionDefinition(FunctionDefinition):
    """NeMo Gym subclass of the OpenAI ``FunctionDefinition`` param; adds no keys."""
[docs]
class NeMoGymChatCompletionContentPartTextParam(ChatCompletionContentPartTextParam):
    """NeMo Gym subclass of the OpenAI text content-part param; adds no keys."""
[docs]
class NeMoGymChatCompletionUserMessageParam(ChatCompletionUserMessageParam):
    """User chat message whose ``content`` is a string or a concrete list of text parts."""

    # Replaces the upstream Iterable annotation, which is annoying to work with.
    content: Required[str | list[NeMoGymChatCompletionContentPartTextParam]]
[docs]
class NeMoGymChatCompletionSystemMessageParam(ChatCompletionSystemMessageParam):
    """System chat message whose ``content`` is a string or a concrete list of text parts."""

    # Replaces the upstream Iterable annotation, which is annoying to work with.
    content: Required[str | list[NeMoGymChatCompletionContentPartTextParam]]
[docs]
class NeMoGymChatCompletionDeveloperMessageParam(ChatCompletionDeveloperMessageParam):
    """Developer chat message whose ``content`` is a string or a concrete list of text parts."""

    # Replaces the upstream Iterable annotation, which is annoying to work with.
    content: Required[str | list[NeMoGymChatCompletionContentPartTextParam]]
[docs]
class NeMoGymChatCompletionAssistantMessageParam(ChatCompletionAssistantMessageParam, total=False):
    """Assistant chat message with ``content`` and ``tool_calls`` re-typed to concrete lists.

    Declared with ``total=False``, so the keys declared here are optional.
    """

    # Replaces the upstream Iterable annotation, which is annoying to work with.
    content: str | list[ContentArrayOfContentPart] | None
    # NOTE(review): assigning a default inside a TypedDict body is rejected by
    # type checkers; it is kept as-is because the runtime effect of removing it
    # has not been verified.
    tool_calls: list[NeMoGymChatCompletionMessageToolCallParam] | None = None
[docs]
class NeMoGymChatCompletionAssistantMessageForTrainingParam(
    NeMoGymChatCompletionAssistantMessageParam, TokenIDLogProbTypedDictMixin
):
    """Assistant chat message param that additionally carries the token-ID / log-prob training keys."""
# Every message type accepted in a Chat Completions ``messages`` list. The
# deprecated function-message param is intentionally left out; the training
# variant of the assistant message comes last.
NeMoGymChatCompletionMessageParam: TypeAlias = Union[
NeMoGymChatCompletionDeveloperMessageParam,
NeMoGymChatCompletionSystemMessageParam,
NeMoGymChatCompletionUserMessageParam,
NeMoGymChatCompletionAssistantMessageParam,
NeMoGymChatCompletionToolMessageParam,
# Don't add deprecated.
# NeMoGymChatCompletionFunctionMessageParam,
# Training:
NeMoGymChatCompletionAssistantMessageForTrainingParam,
]
[docs]
class NeMoGymChatCompletionCreateParamsNonStreaming(BaseModel):
    """Pydantic request body for a non-streaming Chat Completions call.

    ``messages`` is the only field without a default. The deprecated
    ``function_call`` / ``functions`` arguments are deliberately not declared.
    """

    messages: list[NeMoGymChatCompletionMessageParam]
    model: str | ChatModel | None = None
    audio: ChatCompletionAudioParam | None = None
    frequency_penalty: float | None = None
    logit_bias: dict[str, int] | None = None
    logprobs: bool | None = None
    max_completion_tokens: int | None = None
    max_tokens: int | None = None
    metadata: Metadata | None = None
    modalities: list[Literal["text", "audio"]] | None = None
    n: int | None = None
    parallel_tool_calls: bool = True  # OpenAI default
    prediction: ChatCompletionPredictionContentParam | None = None
    presence_penalty: float | None = None
    reasoning_effort: ReasoningEffort | None = None
    response_format: ResponseFormat | None = None
    seed: int | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    # Member order (str, None, list) matches the original union after flattening.
    stop: str | None | list[str] = None
    store: bool | None = None
    stream_options: ChatCompletionStreamOptionsParam | None = None
    temperature: float | None = None
    tool_choice: ChatCompletionToolChoiceOptionParam | None = None
    tools: list[NeMoGymChatCompletionToolParam] | None = None
    top_logprobs: int | None = None
    top_p: float | None = None
    user: str | None = None
    web_search_options: WebSearchOptions | None = None
    # Only ``False`` or ``None`` validate: streaming is not representable here.
    stream: Literal[False] | None = None

    # Disallow deprecated args
    # function_call: FunctionCall
    # functions: Iterable[Function]
########################################
# Clients
########################################
# HTTP status codes on which the client retries a request.
# See https://platform.openai.com/docs/guides/error-codes/api-errors
#   429 - too many requests (rate limited)
#   500 - internal server error, which may sporadically occur
#   502 - bad gateway (when the endpoint is overloaded)
#   503 - service unavailable
#   504 - gateway timeout (when the endpoint config has too low of a gateway timeout setting for the model to finish generating)
#   520 - non-standard code; NOTE(review): presumably a proxy/CDN "unknown error" - confirm
# Codes in RATE_LIMIT_ERROR_CODES are the ones the client treats as rate limiting;
# RETRY_ERROR_CODES additionally contains 500.
RATE_LIMIT_ERROR_CODES = [429, 502, 503, 504, 520]
RETRY_ERROR_CODES = [*RATE_LIMIT_ERROR_CODES, 500]
[docs]
class NeMoGymAsyncOpenAI(BaseModel):  # pragma: no cover
    """Stub async client for OpenAI-compatible endpoints.

    It does not use the OpenAI SDK; every call goes through the project's
    ``request`` helper. Responses whose status is in ``RETRY_ERROR_CODES`` are
    retried after a 0.5 s sleep.
    """

    base_url: str
    api_key: str
    internal: bool = Field(
        default=False,
        description="Set this to true if this particular client is only used to call internal NeMo Gym servers.",
    )

    async def _request(self, **request_kwargs: Any) -> ClientResponse:
        """Send one request, retrying on retryable status codes.

        Args:
            **request_kwargs: Keyword arguments forwarded to ``request``;
                ``_internal`` is always overridden with ``self.internal``.

        Returns:
            The first response whose status is not in ``RETRY_ERROR_CODES``.

        Raises:
            Whatever ``response.raise_for_status()`` raises for the last
            response once the retry budget is used up.
        """
        max_num_tries = MAX_NUM_TRIES
        tries = 0
        # Loop on the local budget, not on the MAX_NUM_TRIES constant; otherwise
        # the extension granted for rate-limit responses below has no effect.
        while tries < max_num_tries:
            tries += 1
            response = await request(**(request_kwargs | {"_internal": self.internal}))
            if response.status not in RETRY_ERROR_CODES:
                return response

            # If we hit a rate limit, we don't want to hit max num tries, so we increment both.
            if response.status in RATE_LIMIT_ERROR_CODES:
                max_num_tries += 1

            content = (await response.content.read()).decode()
            print(
                f"Hit a {response.status} trying to query an OpenAI endpoint (try {tries}). Sleeping 0.5s. Error message: {content}"
            )
            await sleep(0.5)

        # Retry budget used up: surface the last (failed) response as an error.
        response.raise_for_status()

    async def _raise_for_status(self, response: ClientResponse, request_kwargs: Dict[str, Any]) -> None:
        """Print the request kwargs for a non-OK response, then call ``raise_for_status``.

        NOTE(review): ``request_kwargs`` contains the Authorization header, so
        the API key is printed on failure; consider redacting it.
        """
        if not response.ok:
            print(f"Request kwargs: {json.dumps(request_kwargs)}")

        # NOTE(review): called for every response; presumably a no-op when the
        # response is OK - confirm against nemo_gym.server_utils.
        await raise_for_status(response)

    async def _post(self, url: str, payload: Dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``url`` with bearer auth and return the decoded JSON body."""
        request_kwargs = dict(
            url=url,
            json=payload,
            headers={"Authorization": f"Bearer {self.api_key}"},
        )
        response = await self._request(method="POST", **request_kwargs)
        await self._raise_for_status(response, request_kwargs)
        return await response.json()

    async def create_chat_completion(self, **kwargs):
        """POST ``kwargs`` to ``{base_url}/chat/completions`` and return the decoded JSON body."""
        return await self._post(f"{self.base_url}/chat/completions", kwargs)

    async def create_response(self, **kwargs):
        """POST ``kwargs`` to ``{base_url}/responses`` and return the decoded JSON body."""
        return await self._post(f"{self.base_url}/responses", kwargs)

    async def create_tokenize(self, **kwargs):
        """POST ``kwargs`` to ``/tokenize`` and return the decoded JSON body.

        A trailing ``/v1`` is stripped from ``base_url`` before ``/tokenize``
        is appended.
        """
        base_url = self.base_url.removesuffix("/v1")
        return await self._post(f"{base_url}/tokenize", kwargs)