Add entrypoint for flagos multi-backend plugin system #3107
Greptile Summary
This PR wires up a plugin entrypoint for the FlagOS multi-backend system. At
Confidence Score: 5/5
Safe to merge; all previously blocking issues have been addressed and the remaining nits do not affect correctness on any current code path. The plugin block is torch-only, uses a namespaced import, wraps everything in a broad exception handler, correctly captures and restores sys.modules state on failure, and falls back gracefully via getattr in the attention module. No current call site will be silently broken. Both changed files are straightforward; the dpa_utils monkey-patch in dot_product_attention.py is worth keeping in mind if new call sites import get_attention_backend by name in the future.
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User as User Code
participant PyTorchInit as pytorch/__init__.py
participant CommonInit as common/__init__.py
participant SysModules as sys.modules
participant Plugin as transformer_engine_plugin_fl
participant DPA as dot_product_attention.py
User->>PyTorchInit: import transformer_engine.pytorch
PyTorchInit->>CommonInit: load_framework_extension("torch")
CommonInit->>SysModules: "set transformer_engine_torch = solib (native)"
alt "NVTE_ENABLE_PLUGIN=1"
CommonInit->>CommonInit: "_original_module = sys.modules.get(module_name)"
CommonInit->>Plugin: from transformer_engine_plugin_fl import load_plugins
CommonInit->>SysModules: "set transformer_engine_torch_nv = solib"
CommonInit->>Plugin: load_plugins()
alt load_plugins() succeeds
Plugin-->>SysModules: replace transformer_engine_torch with plugin stub
Plugin-->>CommonInit: return
else load_plugins() raises
CommonInit->>SysModules: pop transformer_engine_torch_nv
CommonInit->>SysModules: "restore transformer_engine_torch = _original_module"
CommonInit-->>User: warnings.warn(RuntimeWarning)
end
end
CommonInit-->>PyTorchInit: return
PyTorchInit->>DPA: "import dot_product_attention (tex = transformer_engine_torch)"
alt "NVTE_ENABLE_PLUGIN=1"
DPA->>DPA: "FlashAttention = getattr(tex, flash_attention, FlashAttentionNative)"
DPA->>DPA: "_plugin_get_attention_backend = getattr(tex, get_attention_backend, None)"
opt plugin backend found
DPA->>DPA: "dpa_utils.get_attention_backend = _plugin_get_attention_backend"
end
end
DPA-->>User: DotProductAttention ready
Reviews (5): Last reviewed commit: "fix: complete sys.modules rollback on pl..."
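The summary's caveat about call sites that import get_attention_backend by name can be illustrated with a small sketch. It uses a throwaway module object named dpa_utils_demo and two toy functions, all of which are hypothetical stand-ins and not the real TransformerEngine code. It only shows the general Python behavior: rebinding a module attribute does not affect names that were bound to the old function earlier.

```python
import types

# Hypothetical stand-in for dpa_utils; not the real TransformerEngine module.
dpa_utils = types.ModuleType("dpa_utils_demo")


def _native_get_attention_backend():
    return "native"


dpa_utils.get_attention_backend = _native_get_attention_backend

# A call site that bound the function by name before the patch, which is what
# `from dpa_utils import get_attention_backend` would do.
get_attention_backend = dpa_utils.get_attention_backend


def _plugin_get_attention_backend():
    return "plugin"


# The monkey-patch rebinds the module attribute only.
dpa_utils.get_attention_backend = _plugin_get_attention_backend

via_module = dpa_utils.get_attention_backend()  # attribute lookup sees the patch
via_name = get_attention_backend()              # earlier binding still calls the native function
```

Call sites that go through the module attribute pick up the plugin selector; a call site that imported the function by name before the patch would keep using the native one.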
| _FlashAttentionNative = FlashAttention | ||
| FlashAttention = getattr(tex, "flash_attention", _FlashAttentionNative) | ||
| dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend | ||
| dpa_utils.get_attention_backend = tex.get_attention_backend |
Unguarded AttributeError crashes the entire TE import
tex.get_attention_backend is a bare attribute access with no fallback. When NVTE_ENABLE_PLUGIN=1 but the plugin import fails (caught and swallowed by the except ImportError in common/__init__.py), sys.modules["transformer_engine_torch"] still points to the native CUDA pybind module. That module has no get_attention_backend attribute — raising AttributeError at module-import time and making all of TransformerEngine unusable. Unlike line 154 which uses getattr(..., _FlashAttentionNative) as a safe fallback, line 156 has no equivalent guard.
The two changes compound: the except in __init__.py prints a message and continues, silently leaving the env var active, and then this module-level code fatally re-runs against the unmodified native module.
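A minimal sketch of the guarded lookup this comment asks for, mirroring the getattr fallback already used for FlashAttention. The objects native_tex and plugin_tex and the helper resolve_backend_selector are hypothetical stand-ins for illustration, not TransformerEngine APIs.

```python
import types

# Hypothetical stand-ins: `native_tex` mimics the native pybind module, which
# according to the comment above has no get_attention_backend attribute.
native_tex = types.SimpleNamespace()
plugin_tex = types.SimpleNamespace(get_attention_backend=lambda: "plugin")


def native_get_attention_backend():
    return "native"


def resolve_backend_selector(tex, default):
    """Return tex.get_attention_backend if it exists, otherwise `default`."""
    return getattr(tex, "get_attention_backend", default)


selector_without_plugin = resolve_backend_selector(native_tex, native_get_attention_backend)
selector_with_plugin = resolve_backend_selector(plugin_tex, native_get_attention_backend)
```

With this shape, a failed plugin load leaves the native selector in place instead of raising AttributeError at import time.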
| if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1": | ||
| sys.modules[module_name + "_nv"] = solib | ||
| try: | ||
| from plugin import load_plugins |
Bare plugin module name is too generic
from plugin import load_plugins will import any top-level package named plugin already present in the Python environment before it finds the intended TE plugin. This could silently call the wrong load_plugins function or raise a confusing ImportError message referencing an unrelated package. The module should be imported under a namespaced name (e.g., transformer_engine_plugin) so the intent is unambiguous and collision-free.
| # Plugin system: if NVTE_ENABLE_PLUGIN=1, let plugin stub take over | ||
| # transformer_engine_torch and register original pybind as _nv for CUDA backend. | ||
| if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1": | ||
| sys.modules[module_name + "_nv"] = solib | ||
| try: | ||
| from plugin import load_plugins | ||
| load_plugins() | ||
| except ImportError as e: | ||
| print(f"[TE] NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}") |
Plugin block runs for all frameworks, not just torch
load_framework_extension is called for both "torch" and "jax" (see transformer_engine/jax/__init__.py). The _nv alias registration and load_plugins() call are torch-specific constructs, but this block will also execute when the JAX extension is loaded. This registers a meaningless transformer_engine_jax_nv entry in sys.modules and re-invokes load_plugins() a second time, which may double-register backends or trigger unintended side effects.
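One way to express the torch-only condition is a small predicate evaluated before any sys.modules changes. The helper name should_load_plugins and its environ parameter are assumptions made for this sketch; the PR may structure the check differently.

```python
import os


def should_load_plugins(framework, environ=None):
    """Return True only for the torch extension with NVTE_ENABLE_PLUGIN=1.

    `environ` defaults to os.environ; it is a parameter here only so the
    sketch can be exercised without touching the real environment.
    """
    environ = os.environ if environ is None else environ
    return framework == "torch" and environ.get("NVTE_ENABLE_PLUGIN", "0") == "1"


torch_enabled = should_load_plugins("torch", {"NVTE_ENABLE_PLUGIN": "1"})
jax_enabled = should_load_plugins("jax", {"NVTE_ENABLE_PLUGIN": "1"})
torch_disabled = should_load_plugins("torch", {})
```

Gating on the framework name keeps the JAX path from registering an `_nv` alias or calling load_plugins() a second time.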
| except ImportError as e: | ||
| print(f"[TE] NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}") |
Silent print to stdout is easy to miss in production logging pipelines. Using warnings.warn (or at minimum writing to sys.stderr) ensures the failure is visible under standard Python warning filters and won't be swallowed by log-capture tools.
| except ImportError as e: | |
| print(f"[TE] NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}") | |
| except ImportError as e: | |
| import warnings | |
| warnings.warn( | |
| f"[TE] NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}", | |
| RuntimeWarning, | |
| stacklevel=2, | |
| ) |
| from plugin import load_plugins | ||
| load_plugins() |
Is load_plugins performing generic plugin discovery and loading? If so, we could have that in the main TE repo as well. If it is specific to FlagOS, then I agree that it should live there.
Thanks for the review. Two questions I'd like to confirm before deciding on the next step:
- Should the load_plugins entrypoint be merged into the main branch, so it can follow future NVIDIA TE releases?
- Should the generic framework components in transformer_engine_plugin_fl (e.g., OpManager, SelectionPolicy) also be upstreamed into TE, or should we keep the current approach where TE only exposes a minimal load_plugins hook and all dispatch logic stays external?
Hi @lxd-cumt, thank you for the contribution. A few things:
1886932 to 727a116
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
727a116 to c590d4d
Thanks for the review! I've signed off on both commits and addressed the Greptile comments.
| sys.modules[module_name + "_nv"] = solib | ||
| try: | ||
| from transformer_engine_plugin_fl import load_plugins | ||
| load_plugins() | ||
| except ImportError as e: | ||
| warnings.warn( | ||
| f"NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}", | ||
| ImportWarning, | ||
| stacklevel=2, | ||
| ) |
Plugin initialization errors beyond ImportError left uncaught
sys.modules[module_name + "_nv"] is registered unconditionally at line 199 before load_plugins() runs. If transformer_engine_plugin_fl is installed but load_plugins() itself raises a non-ImportError (e.g., RuntimeError during backend registration, an AttributeError inside the plugin, or an OSError loading a shared library), the exception escapes the try/except block, crashes load_framework_extension, and leaves sys.modules in a partially inconsistent state: transformer_engine_torch_nv exists (native pybind) but transformer_engine_torch was never replaced by the plugin stub. TE initialization fails with an opaque traceback instead of the intended graceful fallback message.
3cbb72d to 9a869ec
| try: | ||
| from transformer_engine_plugin_fl import load_plugins | ||
| sys.modules[module_name + "_nv"] = solib | ||
| load_plugins() | ||
| except Exception as e: | ||
| # Rollback _nv registration if plugin failed to fully initialize | ||
| sys.modules.pop(module_name + "_nv", None) | ||
| warnings.warn( | ||
| f"NVTE_ENABLE_PLUGIN=1 but plugin loading failed: {e}", | ||
| RuntimeWarning, | ||
| stacklevel=2, | ||
| ) |
Incomplete sys.modules rollback on load_plugins() failure
sys.modules[module_name] (i.e., transformer_engine_torch) is set to solib at line 192, before this try block. If load_plugins() partially succeeds — for example, replaces sys.modules["transformer_engine_torch"] with the plugin stub before raising a RuntimeError during backend registration — the except block pops _nv but leaves sys.modules["transformer_engine_torch"] pointing to the partially-initialized stub. TE then continues with a broken tex module even though the warning says plugin loading failed.
Capture the pre-attempt value of sys.modules.get(module_name) before calling load_plugins() and restore it in the except block alongside the _nv pop.
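The capture-and-restore pattern described here could look roughly like the sketch below. The function try_load_plugins, its parameters, and the demo module names are hypothetical; in the PR the logic sits inline in load_framework_extension. The demo uses a deliberately failing load_plugins that swaps in a stub before raising, which is the partial-success case the comment describes.

```python
import sys
import types
import warnings


def try_load_plugins(module_name, solib, load_plugins):
    """Call load_plugins(); on any failure, put sys.modules back as it was."""
    nv_name = module_name + "_nv"
    missing = object()  # sentinel: module_name was not registered beforehand
    original = sys.modules.get(module_name, missing)
    try:
        sys.modules[nv_name] = solib
        load_plugins()
        return True
    except Exception as e:
        # Undo the _nv alias and whatever the plugin did to module_name.
        sys.modules.pop(nv_name, None)
        if original is missing:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = original
        warnings.warn(
            f"NVTE_ENABLE_PLUGIN=1 but plugin loading failed: {e}",
            RuntimeWarning,
            stacklevel=2,
        )
        return False


# Demo with hypothetical names; nothing here touches TransformerEngine.
native = types.ModuleType("demo_te_torch")
sys.modules["demo_te_torch"] = native


def failing_load_plugins():
    # Simulates a plugin that replaces the module, then fails mid-registration.
    sys.modules["demo_te_torch"] = types.ModuleType("half_initialized_stub")
    raise RuntimeError("backend registration failed")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ok = try_load_plugins("demo_te_torch", native, failing_load_plugins)
```

After the failed attempt, `sys.modules["demo_te_torch"]` is the native module again and the `_nv` alias is gone, so later imports see a consistent state.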
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
FlagOS Proposal: Plugin Architecture & Device-Agnostic Abstraction for TransformerEngine
Plugin System: Initialization-time Backend Loading
We propose a plugin architecture where TransformerEngine (TE) loads backend implementations at initialization time via an explicit plugin interface, while the actual multi-backend plugins reside in a separate repository (TransformerEngine-Plugin-FL).
TransformerEngine-Plugin-FL: https://github.com/lxd-cumt/TransformerEngine-Plugin-FL
Current State
TE already has a prototype (NVTE_ENABLE_PLUGIN=1 in common/__init__.py) that registers the original CUDA pybind module as transformer_engine_torch_nv and delegates to an external load_plugins() entry point.
Proposed Design
- At load_framework_extension() time, if a plugin is present, TE dispatches backend calls through the plugin registry; otherwise it falls back to the native TE implementation.
- The TransformerEngine-Plugin-FL repository is independently installable, contains multiple backend implementations for diverse accelerators, and supports more training scenarios.
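To make the dispatch-with-fallback idea in the proposed design concrete, here is a minimal sketch. PluginRegistry, dispatch, and the op name "gemm" are hypothetical illustrations and are not the OpManager or SelectionPolicy components of TransformerEngine-Plugin-FL, whose actual interfaces are not shown in this thread.

```python
from typing import Any, Callable, Dict, Optional


class PluginRegistry:
    """Hypothetical registry mapping op names to plugin implementations."""

    def __init__(self) -> None:
        self._ops: Dict[str, Callable[..., Any]] = {}

    def register(self, op_name: str, impl: Callable[..., Any]) -> None:
        self._ops[op_name] = impl

    def lookup(self, op_name: str) -> Optional[Callable[..., Any]]:
        return self._ops.get(op_name)


def dispatch(op_name, native_impl, registry, *args, **kwargs):
    """Use the plugin implementation if one is registered, else the native one."""
    impl = registry.lookup(op_name) if registry is not None else None
    if impl is None:
        impl = native_impl
    return impl(*args, **kwargs)


def native_gemm(a, b):
    return ("native", a * b)


registry = PluginRegistry()
no_plugin_result = dispatch("gemm", native_gemm, None, 2, 3)
unregistered_result = dispatch("gemm", native_gemm, registry, 2, 3)

registry.register("gemm", lambda a, b: ("plugin", a * b))
plugin_result = dispatch("gemm", native_gemm, registry, 2, 3)
```

Keeping the fallback inside a single dispatch function means TE code paths behave identically when no plugin is installed.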