nemo_rl.models.policy.workers.patches

Module Contents

Functions

_get_transformer_engine_file

Return the absolute path to a Transformer Engine file, or raise an error if it cannot be found.

apply_transformer_engine_patch

Apply the patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files.

API

nemo_rl.models.policy.workers.patches._get_transformer_engine_file(relative_path: str) → str

Return the absolute path to a Transformer Engine file, or raise an error if it cannot be found.

The relative_path should be a POSIX-style path under the transformer_engine package root, e.g. “pytorch/triton/permutation.py”.
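One way to resolve a path like this without importing the package is `importlib.metadata`, which reads only the installed distribution's metadata. The sketch below is a minimal illustration under stated assumptions, not the NeMo RL implementation: the helper `get_package_file` and its `dist_name`/`package_dir` parameters are hypothetical, and it assumes a regular (non-editable) install in which the package directory sits next to the distribution's `.dist-info` directory.

```python
from importlib import metadata
from pathlib import Path


def get_package_file(dist_name: str, package_dir: str, relative_path: str) -> str:
    """Hypothetical helper: return the absolute path of a file inside an
    installed package without importing that package.

    Assumes a regular install where `package_dir` lives next to the
    distribution's metadata directory (e.g. in site-packages).
    """
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError as exc:
        raise FileNotFoundError(f"distribution {dist_name!r} is not installed") from exc

    # locate_file() joins the path onto the distribution's install location;
    # reading metadata this way executes none of the package's module code.
    candidate = Path(dist.locate_file(f"{package_dir}/{relative_path}"))
    if not candidate.is_file():
        raise FileNotFoundError(f"{relative_path!r} not found under {package_dir!r}")
    return str(candidate.resolve())
```

For Transformer Engine, a call might look like `get_package_file("transformer_engine", "transformer_engine", "pytorch/triton/permutation.py")`, where the distribution name is an assumption that depends on how the package was installed.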

nemo_rl.models.policy.workers.patches.apply_transformer_engine_patch()

Apply the patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files.

The target file is located via importlib metadata rather than by importing transformer_engine, which avoids triggering the package's import-time side effects. If the permutation module has already been imported, it is reloaded so that the patched source takes effect.
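The patch-then-reload pattern can be sketched as follows. This is a minimal illustration, not the NeMo RL code: the helper `patch_source_and_reload` is hypothetical, it performs a plain text substitution, and the `old`/`new` strings and module name are placeholders. The actual edits from the pull request are not reproduced here.

```python
import importlib
import sys
from pathlib import Path


def patch_source_and_reload(path: str, old: str, new: str, module_name: str) -> bool:
    """Hypothetical helper: replace `old` with `new` in a source file, then
    reload `module_name` if it was already imported.

    Returns True if the file was modified, False if it was already patched.
    """
    source_path = Path(path)
    text = source_path.read_text(encoding="utf-8")

    if new in text:
        # Already patched: leave the file alone so repeated calls are safe.
        changed = False
    elif old in text:
        source_path.write_text(text.replace(old, new), encoding="utf-8")
        changed = True
    else:
        raise RuntimeError(f"patch target not found in {path}")

    # A module imported before the patch still holds the old code objects;
    # reload() re-executes the (now patched) source in the same module object.
    module = sys.modules.get(module_name)
    if changed and module is not None:
        importlib.reload(module)
    return changed
```

Checking for `new` before `old` keeps the helper idempotent. Note that `importlib.reload` only re-executes the module itself: other modules that bound names earlier with `from ... import ...` keep references to the old objects.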