Add entrypoint for flagos multi-backend plugin system#3107

Open
lxd-cumt wants to merge 4 commits into
NVIDIA:release_v2.14from
lxd-cumt:rc2.14_with_flagos

@lxd-cumt lxd-cumt commented Jun 9, 2026

FlagOS Proposal: Plugin Architecture & Device-Agnostic Abstraction for TransformerEngine

Plugin System: Initialization-time Backend Loading

We propose a plugin architecture where TransformerEngine (TE) loads backend implementations at initialization time via an explicit plugin interface, while the actual multi-backend plugins reside in a separate repository (TransformerEngine-Plugin-FL).

TransformerEngine-Plugin-FL: https://github.com/lxd-cumt/TransformerEngine-Plugin-FL

Current State

TE already has a prototype (NVTE_ENABLE_PLUGIN=1 in common/__init__.py) that registers the original CUDA pybind module as transformer_engine_torch_nv and delegates to an external load_plugins() entry point.

Proposed Design

  • TE defines a stable plugin API contract (operator signatures, quantization interfaces, communication primitives).
  • At load_framework_extension() time, if a plugin is present, TE dispatches backend calls through the plugin registry; otherwise it falls back to the native TE implementation.
  • The TransformerEngine-Plugin-FL repository is independently installable, contains multiple backend implementations for diverse accelerators, and supports more training scenarios.
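
The dispatch-with-fallback behaviour in the bullets above might look roughly like the following. This is a minimal sketch and not the code in this PR: `BackendRegistry`, `register`, `resolve`, and the two `gemm` functions are hypothetical names introduced only for illustration, since the actual plugin API contract is not specified here.

```python
from typing import Callable, Dict


class BackendRegistry:
    """Hypothetical registry mapping operator names to plugin implementations."""

    def __init__(self) -> None:
        self._ops: Dict[str, Callable] = {}

    def register(self, name: str, fn: Callable) -> None:
        self._ops[name] = fn

    def resolve(self, name: str, native: Callable) -> Callable:
        # Use the plugin implementation if one was registered,
        # otherwise fall back to the native implementation.
        return self._ops.get(name, native)


def native_gemm(a, b):
    # Stand-in for a native TE operator.
    return ("native", a * b)


def plugin_gemm(a, b):
    # Stand-in for an operator supplied by an external plugin.
    return ("plugin", a * b)


registry = BackendRegistry()
no_plugin = registry.resolve("gemm", native_gemm)(2, 3)

registry.register("gemm", plugin_gemm)
with_plugin = registry.resolve("gemm", native_gemm)(2, 3)
```

With no plugin registered the call resolves to the native function; after registration the same call site picks up the plugin without being changed.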

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 9, 2026

greptile-apps Bot commented Jun 9, 2026

Contributor

Greptile Summary

This PR wires up a plugin entrypoint for the FlagOS multi-backend system. At load_framework_extension("torch") time, if NVTE_ENABLE_PLUGIN=1, TE registers the native CUDA pybind module as transformer_engine_torch_nv and delegates to transformer_engine_plugin_fl.load_plugins(); on any failure it rolls back sys.modules and emits a RuntimeWarning. A matching module-level block in dot_product_attention.py optionally overrides FlashAttention and dpa_utils.get_attention_backend with plugin-provided implementations.

  • common/__init__.py: All previously flagged concerns (bare module name, JAX double-invocation, stdout print, narrow ImportError catch, incomplete sys.modules rollback) have been addressed with namespaced import, framework == "torch" guard, warnings.warn, broad except Exception, and pre-attempt capture of _original_module.
  • dot_product_attention.py: Both overrides use getattr with safe fallbacks, preventing the AttributeError on the native module that was flagged earlier. The dpa_utils.get_attention_backend patch mutates the module object and is correct for attribute-access call sites, but would miss any future call sites that import the function directly by name.
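
The caveat about direct-name importers can be reproduced in isolation. The sketch below uses a stand-in module (`fake_dpa_utils`, a hypothetical name) rather than the real dpa_utils, and only illustrates standard Python name-binding behaviour.

```python
import types

# Stand-in for dpa_utils; the real module is not imported here.
utils = types.ModuleType("fake_dpa_utils")


def native_backend():
    return "native"


utils.get_attention_backend = native_backend


# Call site A: attribute access, looked up on the module at call time.
def call_via_attribute():
    return utils.get_attention_backend()


# Call site B: equivalent to `from fake_dpa_utils import get_attention_backend`,
# which binds the function object once, at import time.
get_attention_backend = utils.get_attention_backend


def call_via_direct_name():
    return get_attention_backend()


# Plugin-style patch: replace the attribute on the module object.
utils.get_attention_backend = lambda: "plugin"

via_attribute = call_via_attribute()
via_direct_name = call_via_direct_name()
```

After the patch, the attribute-access call site sees the plugin function while the direct-name call site still holds the native one, which is the silent miss the summary warns about.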

Confidence Score: 5/5

Safe to merge; all previously blocking issues have been addressed and the remaining nits do not affect correctness on any current code path.

The plugin block is torch-only, uses a namespaced import, wraps everything in a broad exception handler, correctly captures and restores sys.modules state on failure, and falls back gracefully via getattr in the attention module. No current call site will be silently broken.

Both changed files are straightforward; the dpa_utils monkey-patch in dot_product_attention.py is worth keeping in mind if new call sites import get_attention_backend by name in the future.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/__init__.py | Adds the plugin loading block: namespaced import, torch-only guard, broad exception handling with sys.modules rollback, and warnings.warn, addressing all previously raised concerns. One dead else-branch in the rollback path (always non-None) is the only residual nit. |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Adds module-level plugin overrides for FlashAttention and get_attention_backend, both with safe getattr fallbacks. The dpa_utils patch is effective for attribute-access call sites but would silently miss any direct-name importers. |

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant PyTorchInit as pytorch/__init__.py
    participant CommonInit as common/__init__.py
    participant SysModules as sys.modules
    participant Plugin as transformer_engine_plugin_fl
    participant DPA as dot_product_attention.py

    User->>PyTorchInit: import transformer_engine.pytorch
    PyTorchInit->>CommonInit: load_framework_extension("torch")
    CommonInit->>SysModules: "set transformer_engine_torch = solib (native)"
    alt "NVTE_ENABLE_PLUGIN=1"
        CommonInit->>CommonInit: "_original_module = sys.modules.get(module_name)"
        CommonInit->>Plugin: from transformer_engine_plugin_fl import load_plugins
        CommonInit->>SysModules: "set transformer_engine_torch_nv = solib"
        CommonInit->>Plugin: load_plugins()
        alt load_plugins() succeeds
            Plugin-->>SysModules: replace transformer_engine_torch with plugin stub
            Plugin-->>CommonInit: return
        else load_plugins() raises
            CommonInit->>SysModules: pop transformer_engine_torch_nv
            CommonInit->>SysModules: "restore transformer_engine_torch = _original_module"
            CommonInit-->>User: warnings.warn(RuntimeWarning)
        end
    end
    CommonInit-->>PyTorchInit: return
    PyTorchInit->>DPA: "import dot_product_attention (tex = transformer_engine_torch)"
    alt "NVTE_ENABLE_PLUGIN=1"
        DPA->>DPA: "FlashAttention = getattr(tex, flash_attention, FlashAttentionNative)"
        DPA->>DPA: "_plugin_get_attention_backend = getattr(tex, get_attention_backend, None)"
        opt plugin backend found
            DPA->>DPA: "dpa_utils.get_attention_backend = _plugin_get_attention_backend"
        end
    end
    DPA-->>User: DotProductAttention ready

Reviews (5): Last reviewed commit: "fix: complete sys.modules rollback on pl..."

_FlashAttentionNative = FlashAttention
FlashAttention = getattr(tex, "flash_attention", _FlashAttentionNative)
dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend
dpa_utils.get_attention_backend = tex.get_attention_backend

Contributor

P1 Unguarded AttributeError crashes the entire TE import

tex.get_attention_backend is a bare attribute access with no fallback. When NVTE_ENABLE_PLUGIN=1 but the plugin import fails (caught and swallowed by the except ImportError in common/__init__.py), sys.modules["transformer_engine_torch"] still points to the native CUDA pybind module. That module has no get_attention_backend attribute — raising AttributeError at module-import time and making all of TransformerEngine unusable. Unlike line 154 which uses getattr(..., _FlashAttentionNative) as a safe fallback, line 156 has no equivalent guard.

The two changes compound: the except in __init__.py prints a message and continues, silently leaving the env var active, and then this module-level code fatally re-runs against the unmodified native module.
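
One way to guard the override, in the same spirit as the getattr fallback already used for FlashAttention, is sketched below. It runs against stand-in modules (`fake_tex`, `fake_dpa_utils`), and `apply_plugin_override` is a hypothetical helper, not a function in TE.

```python
import types

# Stand-ins: `tex` mimics the native pybind module (no get_attention_backend),
# `dpa_utils` mimics the utility module that owns the native selector.
tex = types.ModuleType("fake_tex")
dpa_utils = types.ModuleType("fake_dpa_utils")
dpa_utils.get_attention_backend = lambda: "native"


def apply_plugin_override(tex_module, utils_module):
    """Patch the selector only if the plugin actually provides one."""
    plugin_fn = getattr(tex_module, "get_attention_backend", None)
    if plugin_fn is not None:
        utils_module.get_attention_backend = plugin_fn
        return True
    return False


# Native module only: nothing is patched and no AttributeError is raised.
patched_without_plugin = apply_plugin_override(tex, dpa_utils)
result_without_plugin = dpa_utils.get_attention_backend()

# Simulate a plugin stub that exposes the attribute.
tex.get_attention_backend = lambda: "plugin"
patched_with_plugin = apply_plugin_override(tex, dpa_utils)
result_with_plugin = dpa_utils.get_attention_backend()
```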

Comment thread transformer_engine/common/__init__.py Outdated
if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1":
    sys.modules[module_name + "_nv"] = solib
    try:
        from plugin import load_plugins

Contributor

P1 Bare plugin module name is too generic

from plugin import load_plugins will import any top-level package named plugin already present in the Python environment before it finds the intended TE plugin. This could silently call the wrong load_plugins function or raise a confusing ImportError message referencing an unrelated package. The module should be imported under a namespaced name (e.g., transformer_engine_plugin) so the intent is unambiguous and collision-free.

Comment thread transformer_engine/common/__init__.py Outdated
Comment on lines +194 to +202
# Plugin system: if NVTE_ENABLE_PLUGIN=1, let plugin stub take over
# transformer_engine_torch and register original pybind as _nv for CUDA backend.
if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1":
    sys.modules[module_name + "_nv"] = solib
    try:
        from plugin import load_plugins
        load_plugins()
    except ImportError as e:
        print(f"[TE] NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}")

Contributor

P1 Plugin block runs for all frameworks, not just torch

load_framework_extension is called for both "torch" and "jax" (see transformer_engine/jax/__init__.py). The _nv alias registration and load_plugins() call are torch-specific constructs, but this block will also execute when the JAX extension is loaded. This registers a meaningless transformer_engine_jax_nv entry in sys.modules and re-invokes load_plugins() a second time, which may double-register backends or trigger unintended side effects.
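
A torch-only guard avoids the double invocation described above. The following is a minimal sketch, assuming the framework name is available as a string at the call site; `should_load_plugins` is a hypothetical helper, and appending to `calls` stands in for the real load_plugins() call.

```python
import os


def should_load_plugins(framework: str, environ=os.environ) -> bool:
    """Hypothetical helper: plugin loading is opt-in and torch-only."""
    return framework == "torch" and environ.get("NVTE_ENABLE_PLUGIN", "0") == "1"


calls = []
env = {"NVTE_ENABLE_PLUGIN": "1"}

# load_framework_extension is invoked once per framework; only the torch
# invocation should reach the plugin loader.
for framework in ("torch", "jax"):
    if should_load_plugins(framework, env):
        calls.append(framework)  # stands in for load_plugins()
```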

Comment thread transformer_engine/common/__init__.py Outdated
Comment on lines +201 to +202
except ImportError as e:
    print(f"[TE] NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}")

Contributor

P2 Silent print to stdout is easy to miss in production logging pipelines. Using warnings.warn (or at minimum writing to sys.stderr) ensures the failure is visible under standard Python warning filters and won't be swallowed by log-capture tools.

Suggested change
- except ImportError as e:
-     print(f"[TE] NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}")
+ except ImportError as e:
+     import warnings
+     warnings.warn(
+         f"[TE] NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}",
+         RuntimeWarning,
+         stacklevel=2,
+     )

Comment thread transformer_engine/common/__init__.py Outdated
Comment on lines +199 to +200
from plugin import load_plugins
load_plugins()

Member

Does load_plugins perform generic plugin discovery and loading? If so, we could have that in the main TE repo as well. If it is specific to FlagOS, then I agree that it should live there.

Author

Thanks for the review. Two questions I'd like to confirm before deciding on the next step:

  • Should the load_plugins entrypoint be merged into the main branch, so it can follow future NVIDIA TE releases?
  • Should the generic framework components in transformer_engine_plugin_fl (e.g., OpManager, SelectionPolicy) also be upstreamed into TE, or should we keep the current approach where TE only exposes a minimal load_plugins hook and all dispatch logic stays external?

@ptrendx

ptrendx commented Jun 9, 2026

Member

Hi @lxd-cumt , thank you for the contribution. A few things:

@lxd-cumt lxd-cumt force-pushed the rc2.14_with_flagos branch from 1886932 to 727a116 Compare June 10, 2026 09:21
lxd-cumt added 2 commits June 15, 2026 17:25
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
@lxd-cumt lxd-cumt force-pushed the rc2.14_with_flagos branch from 727a116 to c590d4d Compare June 15, 2026 09:26
@lxd-cumt

lxd-cumt commented Jun 15, 2026

Author

> Hi @lxd-cumt , thank you for the contribution. A few things:

Thanks for the review! I've signed off on both commits and addressed the Greptile comments.

@lxd-cumt lxd-cumt closed this Jun 15, 2026
@lxd-cumt lxd-cumt reopened this Jun 15, 2026
Comment thread transformer_engine/common/__init__.py Outdated
Comment on lines +199 to +208
sys.modules[module_name + "_nv"] = solib
try:
    from transformer_engine_plugin_fl import load_plugins
    load_plugins()
except ImportError as e:
    warnings.warn(
        f"NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}",
        ImportWarning,
        stacklevel=2,
    )

Contributor

P1 Plugin initialization errors beyond ImportError left uncaught

sys.modules[module_name + "_nv"] is registered unconditionally at line 199 before load_plugins() runs. If transformer_engine_plugin_fl is installed but load_plugins() itself raises a non-ImportError (e.g., RuntimeError during backend registration, an AttributeError inside the plugin, or an OSError loading a shared library), the exception escapes the try/except block, crashes load_framework_extension, and leaves sys.modules in a partially inconsistent state: transformer_engine_torch_nv exists (native pybind) but transformer_engine_torch was never replaced by the plugin stub. TE initialization fails with an opaque traceback instead of the intended graceful fallback message.

Author

fixed

Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
@lxd-cumt lxd-cumt force-pushed the rc2.14_with_flagos branch from 3cbb72d to 9a869ec Compare June 15, 2026 09:54
Comment on lines +199 to +210
try:
    from transformer_engine_plugin_fl import load_plugins
    sys.modules[module_name + "_nv"] = solib
    load_plugins()
except Exception as e:
    # Rollback _nv registration if plugin failed to fully initialize
    sys.modules.pop(module_name + "_nv", None)
    warnings.warn(
        f"NVTE_ENABLE_PLUGIN=1 but plugin loading failed: {e}",
        RuntimeWarning,
        stacklevel=2,
    )

Contributor

P1 Incomplete sys.modules rollback on load_plugins() failure

sys.modules[module_name] (i.e., transformer_engine_torch) is set to solib at line 192, before this try block. If load_plugins() partially succeeds — for example, replaces sys.modules["transformer_engine_torch"] with the plugin stub before raising a RuntimeError during backend registration — the except block pops _nv but leaves sys.modules["transformer_engine_torch"] pointing to the partially-initialized stub. TE then continues with a broken tex module even though the warning says plugin loading failed.

Capture the pre-attempt value of sys.modules.get(module_name) before calling load_plugins() and restore it in the except block alongside the _nv pop.
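
The capture-and-restore pattern suggested here could be sketched as follows. This is an illustration under assumptions, not the code that landed in the PR: `load_with_rollback` is a hypothetical helper, and the demo uses a stand-in module name so that no real TE module in sys.modules is touched.

```python
import sys
import types
import warnings


def load_with_rollback(module_name, native_module, load_plugins):
    """Try to load plugins; restore sys.modules on any failure."""
    nv_name = module_name + "_nv"
    original = sys.modules.get(module_name)  # pre-attempt snapshot
    try:
        sys.modules[nv_name] = native_module
        load_plugins()
        return True
    except Exception as exc:
        # Undo the _nv alias and whatever the loader did to module_name.
        sys.modules.pop(nv_name, None)
        if original is not None:
            sys.modules[module_name] = original
        else:
            sys.modules.pop(module_name, None)
        warnings.warn(
            f"NVTE_ENABLE_PLUGIN=1 but plugin loading failed: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )
        return False


# Demo with a stand-in name (assumption: not a real TE module).
NAME = "fake_transformer_engine_torch"
native = types.ModuleType(NAME)
sys.modules[NAME] = native


def failing_loader():
    # Simulate a partial success: swap in a stub, then fail.
    sys.modules[NAME] = types.ModuleType("half_initialized_stub")
    raise RuntimeError("backend registration failed")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ok = load_with_rollback(NAME, native, failing_loader)
```

After the failed attempt, sys.modules[NAME] points at the native module again and the `_nv` alias is gone, so later imports see a consistent state.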

Author

fixed

Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
@lxd-cumt

Author

The vermin check in CI is failing with AttributeError: module 'ast' has no attribute 'Str'. The CI environment runs Python 3.14, but the pinned vermin rev (c75aca72) uses ast.Str, which was deprecated in Python 3.8 and removed in Python 3.14. I noticed that the upstream main branch has already updated the vermin rev to b70ff9611a01a2bf2f702aa537d14e71e330edba. Should I pick up this fix, or is this check non-blocking?
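
For context on the failure mode, code that inspects string literals can use ast.Constant, which has been available since Python 3.8, instead of ast.Str. The sketch below is not vermin's implementation; `string_literals` is a hypothetical helper that only illustrates the replacement pattern.

```python
import ast


def string_literals(source: str) -> list:
    """Collect string literals via ast.Constant rather than ast.Str."""
    found = []
    for node in ast.walk(ast.parse(source)):
        # ast.Str is missing on newer interpreters; ast.Constant with a
        # str value covers the same nodes.
        if isinstance(node, ast.Constant) and isinstance(node.value, str):
            found.append(node.value)
    return found


literals = string_literals("x = 'a'\ny = 1\nz = 'b'")

# True only on interpreters that still ship the deprecated alias.
has_ast_str = hasattr(ast, "Str")
```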

2 participants