nemo_rl.models.policy.workers.patches
Module Contents
Functions
_get_transformer_engine_file | Return absolute path to a Transformer Engine file or raise if it cannot be found.
apply_transformer_engine_patch | Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files.
API
- nemo_rl.models.policy.workers.patches._get_transformer_engine_file(relative_path: str) → str
Return absolute path to a Transformer Engine file or raise if it cannot be found.
The relative_path should be a POSIX-style path under the transformer_engine package root, e.g. “pytorch/triton/permutation.py”.
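The lookup described above can be sketched with importlib.metadata, which reads an installed distribution's RECORD file without executing the package's __init__. This is a minimal sketch under stated assumptions, not the nemo_rl implementation: find_distribution_file is a hypothetical helper, and the distribution name and package root are left as parameters because the source does not state which values nemo_rl uses.

```python
from importlib import metadata
from pathlib import Path, PurePosixPath


def find_distribution_file(dist_name: str, package_root: str, relative_path: str) -> str:
    """Hypothetical helper: return the absolute path of package_root/relative_path
    inside an installed distribution, without importing the package."""
    # Build the expected POSIX-style path as it appears in the RECORD file.
    target = PurePosixPath(package_root) / PurePosixPath(relative_path)
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError as exc:
        raise FileNotFoundError(f"distribution {dist_name!r} is not installed") from exc

    # dist.files lists PackagePath entries parsed from RECORD; it can be None.
    for entry in dist.files or []:
        if PurePosixPath(entry.as_posix()) == target:
            path = Path(entry.locate()).resolve()
            if path.is_file():
                return str(path)
    raise FileNotFoundError(f"{target} not found in distribution {dist_name!r}")
```

For Transformer Engine, a call would look like find_distribution_file(<distribution name>, "transformer_engine", "pytorch/triton/permutation.py"), where the distribution name is an assumption to be checked against the installed package.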
- nemo_rl.models.policy.workers.patches.apply_transformer_engine_patch()
Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files.
This locates the target file via importlib.metadata instead of importing transformer_engine, to avoid side effects during initialization. If the permutation module has already been imported, it is reloaded so that the patched source takes effect.
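The patch-then-reload step can be sketched as follows, assuming the patch is a plain text substitution in the located source file. patch_and_reload, its old/new snippet arguments, and its idempotency check are hypothetical; the actual change from the upstream pull request is not reproduced here.

```python
import importlib
import importlib.util
import sys
from pathlib import Path


def patch_and_reload(file_path: str, old: str, new: str, module_name: str) -> bool:
    """Hypothetical helper: replace `old` with `new` in a source file, then reload
    `module_name` if it is already imported. Returns True if the file changed.

    Assumes `old` does not occur inside `new`, so a patched file no longer
    matches `old` and repeated calls are no-ops.
    """
    path = Path(file_path)
    source = path.read_text(encoding="utf-8")
    if old in source:
        path.write_text(source.replace(old, new, 1), encoding="utf-8")
        changed = True
    elif new in source:
        changed = False  # already patched
    else:
        raise RuntimeError(f"expected snippet not found in {path}")

    module = sys.modules.get(module_name)
    if changed and module is not None:
        # Remove cached bytecode so reload() recompiles the edited source.
        try:
            Path(importlib.util.cache_from_source(str(path))).unlink()
        except (FileNotFoundError, NotImplementedError):
            pass
        importlib.invalidate_caches()
        importlib.reload(module)
    return changed
```

Deleting the cached .pyc matters because reload() may otherwise reuse stale bytecode when the edited file keeps the same size and modification second.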