Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions launch/utilities/logger.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
"""
Logging utilities for launch operations with file and console output.
"""
import io, sys
import logging
from pathlib import Path

from rich.logging import RichHandler

import io, sys
from rich.console import Console


# https://github.com/microsoft/RepoLaunch/pull/31
_shared_console: Console | None = None
def _get_shared_console() -> Console:
global _shared_console
if _shared_console is None:
# Wrap stdout in a UTF-8 TextIOWrapper; replace any bad chars defensively
utf8_stdout = io.TextIOWrapper(
sys.stdout.buffer, encoding="utf-8", errors="replace"
)
_shared_console = Console(file=utf8_stdout, soft_wrap=True)
return _shared_console


def setup_logger(instance_id: str, log_file: Path | list[Path], printing: bool = True) -> logging.Logger:
"""
Setup logger with file and optional console output for an instance.
Expand Down Expand Up @@ -43,10 +54,7 @@ def setup_logger(instance_id: str, log_file: Path | list[Path], printing: bool =
# logger.addHandler(ch)
# add rich handler
if printing:
# Wrap stdout in a UTF-8 TextIOWrapper; replace any bad chars defensively
utf8_stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
console = Console(file=utf8_stdout, soft_wrap=True)
rh = RichHandler(console=console, rich_tracebacks=True, show_path=False)
rh = RichHandler(console=_get_shared_console(), rich_tracebacks=True, show_path=False)
rh.setLevel(logging.INFO)
logger.addHandler(rh)
return logger
Expand Down
72 changes: 72 additions & 0 deletions tests/logger_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import gc
import io
import sys

from launch.utilities import logger as logger_module


class FakeStdout:
def __init__(self):
self.buffer = io.BytesIO()
self.encoding = "utf-8"
self.errors = "replace"

def write(self, text: str) -> int:
return len(text)

def flush(self) -> None:
pass

def isatty(self) -> bool:
return False


def _drop_shared_console_without_closing_buffer() -> None:
console = getattr(logger_module, "_shared_console", None)
stream = getattr(console, "file", None)
if stream is not None and hasattr(stream, "detach"):
try:
stream.detach()
except (ValueError, OSError, io.UnsupportedOperation):
pass

if hasattr(logger_module, "_shared_console"):
logger_module._shared_console = None


def test_repeated_console_logger_setup_keeps_stdout_buffer_open(
tmp_path, monkeypatch
):
'''
Issue #30
FAIL_TO_PASS at PR#31
'''
logger_name = "test_repeated_console_logger_setup"
fake_stdout = FakeStdout()

_drop_shared_console_without_closing_buffer()
monkeypatch.setattr(sys, "stdout", fake_stdout)

try:
setup_logger = logger_module.setup_logger(
logger_name, tmp_path / "setup.log", printing=True
)
setup_logger.info("setup stage")
logger_module.clean_logger(setup_logger)
del setup_logger
gc.collect()

assert not fake_stdout.buffer.closed

organize_logger = logger_module.setup_logger(
logger_name, tmp_path / "organize.log", printing=True
)
organize_logger.info("organize stage")
logger_module.clean_logger(organize_logger)
del organize_logger
gc.collect()

assert not fake_stdout.buffer.closed
finally:
logger_module.clean_logger(logger_name)
_drop_shared_console_without_closing_buffer()