Skip to content

Commit 1543ce5

Browse files
committed
Supress stdin to not leave things hanging
1 parent 142b116 commit 1543ce5

5 files changed

Lines changed: 123 additions & 116 deletions

File tree

Lib/profiling/sampling/live_collector/collector.py

Lines changed: 46 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import collections
44
import contextlib
55
import curses
6+
from dataclasses import dataclass, field
67
import os
78
import site
89
import sys
@@ -43,6 +44,29 @@
4344
from .widgets import HeaderWidget, TableWidget, FooterWidget, HelpWidget
4445

4546

47+
@dataclass
48+
class ThreadData:
49+
"""Encapsulates all profiling data for a single thread."""
50+
51+
thread_id: int
52+
53+
# Function call statistics: {location: {direct_calls: int, cumulative_calls: int}}
54+
result: dict = field(default_factory=lambda: collections.defaultdict(
55+
lambda: dict(direct_calls=0, cumulative_calls=0)
56+
))
57+
58+
# Thread status statistics
59+
has_gil: int = 0
60+
on_cpu: int = 0
61+
gil_requested: int = 0
62+
unknown: int = 0
63+
total: int = 0 # Total status samples for this thread
64+
65+
# Sample counts
66+
sample_count: int = 0
67+
gc_frame_samples: int = 0
68+
69+
4670
class LiveStatsCollector(Collector):
4771
"""Collector that displays live top-like statistics using ncurses."""
4872

@@ -118,10 +142,7 @@ def __init__(
118142
self.current_thread_index = (
119143
0 # Index into thread_ids when in PER_THREAD mode
120144
)
121-
self.per_thread_result = {} # {thread_id: {func: {direct_calls, cumulative_calls}}}
122-
self.per_thread_status = {} # {thread_id: {has_gil: count, on_cpu: count, ...}}
123-
self.per_thread_samples = {} # {thread_id: sample_count}
124-
self.per_thread_gc_samples = {} # {thread_id: gc_frame_sample_count}
145+
self.per_thread_data = {} # {thread_id: ThreadData}
125146

126147
# Calculate common path prefixes to strip
127148
self._path_prefixes = self._get_common_path_prefixes()
@@ -201,15 +222,9 @@ def _process_frames(self, frames, thread_id=None):
201222

202223
# Also track per-thread if thread_id is provided
203224
if thread_id is not None:
204-
if thread_id not in self.per_thread_result:
205-
self.per_thread_result[thread_id] = (
206-
collections.defaultdict(
207-
lambda: dict(direct_calls=0, cumulative_calls=0)
208-
)
209-
)
210-
self.per_thread_result[thread_id][location][
211-
"cumulative_calls"
212-
] += 1
225+
if thread_id not in self.per_thread_data:
226+
self.per_thread_data[thread_id] = ThreadData(thread_id=thread_id)
227+
self.per_thread_data[thread_id].result[location]["cumulative_calls"] += 1
213228

214229
# The top frame gets counted as an inline call (directly executing)
215230
top_location = (
@@ -221,13 +236,9 @@ def _process_frames(self, frames, thread_id=None):
221236

222237
# Also track per-thread
223238
if thread_id is not None:
224-
if thread_id not in self.per_thread_result:
225-
self.per_thread_result[thread_id] = collections.defaultdict(
226-
lambda: dict(direct_calls=0, cumulative_calls=0)
227-
)
228-
self.per_thread_result[thread_id][top_location][
229-
"direct_calls"
230-
] += 1
239+
if thread_id not in self.per_thread_data:
240+
self.per_thread_data[thread_id] = ThreadData(thread_id=thread_id)
241+
self.per_thread_data[thread_id].result[top_location]["direct_calls"] += 1
231242

232243
def collect_failed_sample(self):
233244
self._failed_samples += 1
@@ -260,47 +271,31 @@ def collect(self, stack_frames):
260271
status_flags = getattr(thread_info, "status", 0)
261272
thread_id = getattr(thread_info, "thread_id", None)
262273

263-
# Initialize per-thread status tracking
264-
if (
265-
thread_id is not None
266-
and thread_id not in self.per_thread_status
267-
):
268-
self.per_thread_status[thread_id] = {
269-
"has_gil": 0,
270-
"on_cpu": 0,
271-
"gil_requested": 0,
272-
"unknown": 0,
273-
"total": 0,
274-
}
274+
# Initialize per-thread data if needed
275+
if thread_id is not None and thread_id not in self.per_thread_data:
276+
self.per_thread_data[thread_id] = ThreadData(thread_id=thread_id)
275277

276278
# Update aggregated counts
277279
if status_flags & THREAD_STATUS_HAS_GIL:
278280
temp_status_counts["has_gil"] += 1
279281
if thread_id is not None:
280-
self.per_thread_status[thread_id]["has_gil"] += 1
282+
self.per_thread_data[thread_id].has_gil += 1
281283
if status_flags & THREAD_STATUS_ON_CPU:
282284
temp_status_counts["on_cpu"] += 1
283285
if thread_id is not None:
284-
self.per_thread_status[thread_id]["on_cpu"] += 1
286+
self.per_thread_data[thread_id].on_cpu += 1
285287
if status_flags & THREAD_STATUS_GIL_REQUESTED:
286288
temp_status_counts["gil_requested"] += 1
287289
if thread_id is not None:
288-
self.per_thread_status[thread_id]["gil_requested"] += 1
290+
self.per_thread_data[thread_id].gil_requested += 1
289291
if status_flags & THREAD_STATUS_UNKNOWN:
290292
temp_status_counts["unknown"] += 1
291293
if thread_id is not None:
292-
self.per_thread_status[thread_id]["unknown"] += 1
294+
self.per_thread_data[thread_id].unknown += 1
293295

294296
# Update per-thread total count
295297
if thread_id is not None:
296-
self.per_thread_status[thread_id]["total"] += 1
297-
298-
# Initialize per-thread sample tracking
299-
if thread_id is not None:
300-
if thread_id not in self.per_thread_samples:
301-
self.per_thread_samples[thread_id] = 0
302-
if thread_id not in self.per_thread_gc_samples:
303-
self.per_thread_gc_samples[thread_id] = 0
298+
self.per_thread_data[thread_id].total += 1
304299

305300
# Process frames (respecting skip_idle)
306301
if self.skip_idle:
@@ -322,7 +317,7 @@ def collect(self, stack_frames):
322317

323318
# Increment per-thread sample count
324319
if thread_id is not None:
325-
self.per_thread_samples[thread_id] += 1
320+
self.per_thread_data[thread_id].sample_count += 1
326321

327322
# Check if any frame is in GC
328323
thread_has_gc_frame = False
@@ -335,7 +330,7 @@ def collect(self, stack_frames):
335330

336331
# Track per-thread GC samples
337332
if thread_has_gc_frame and thread_id is not None:
338-
self.per_thread_gc_samples[thread_id] += 1
333+
self.per_thread_data[thread_id].gc_frame_samples += 1
339334

340335
# Update cumulative thread status counts
341336
for key, count in temp_status_counts.items():
@@ -582,7 +577,10 @@ def _build_stats_list(self):
582577
# PER_THREAD mode - use specific thread result
583578
if self.current_thread_index < len(self.thread_ids):
584579
thread_id = self.thread_ids[self.current_thread_index]
585-
result_source = self.per_thread_result.get(thread_id, {})
580+
if thread_id in self.per_thread_data:
581+
result_source = self.per_thread_data[thread_id].result
582+
else:
583+
result_source = {}
586584
else:
587585
result_source = self.result
588586

@@ -648,10 +646,7 @@ def _build_stats_list(self):
648646
def reset_stats(self):
649647
"""Reset all collected statistics."""
650648
self.result.clear()
651-
self.per_thread_result.clear()
652-
self.per_thread_status.clear()
653-
self.per_thread_samples.clear()
654-
self.per_thread_gc_samples.clear()
649+
self.per_thread_data.clear()
655650
self.thread_ids.clear()
656651
self.view_mode = "ALL"
657652
self.current_thread_index = 0

Lib/profiling/sampling/live_collector/widgets.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -391,8 +391,15 @@ def draw_thread_status(self, line, width):
391391
thread_id = self.collector.thread_ids[
392392
self.collector.current_thread_index
393393
]
394-
if thread_id in self.collector.per_thread_status:
395-
status_counts = self.collector.per_thread_status[thread_id]
394+
if thread_id in self.collector.per_thread_data:
395+
thread_data = self.collector.per_thread_data[thread_id]
396+
status_counts = {
397+
"has_gil": thread_data.has_gil,
398+
"on_cpu": thread_data.on_cpu,
399+
"gil_requested": thread_data.gil_requested,
400+
"unknown": thread_data.unknown,
401+
"total": thread_data.total,
402+
}
396403
else:
397404
status_counts = self.collector._thread_status_counts
398405
else:
@@ -420,12 +427,13 @@ def draw_thread_status(self, line, width):
420427
thread_id = self.collector.thread_ids[
421428
self.collector.current_thread_index
422429
]
423-
thread_samples = self.collector.per_thread_samples.get(
424-
thread_id, 1
425-
)
426-
thread_gc_samples = self.collector.per_thread_gc_samples.get(
427-
thread_id, 0
428-
)
430+
if thread_id in self.collector.per_thread_data:
431+
thread_data = self.collector.per_thread_data[thread_id]
432+
thread_samples = thread_data.sample_count
433+
thread_gc_samples = thread_data.gc_frame_samples
434+
else:
435+
thread_samples = 1
436+
thread_gc_samples = 0
429437
total_samples = max(1, thread_samples)
430438
pct_gc = (thread_gc_samples / total_samples) * 100
431439
else:
@@ -485,7 +493,10 @@ def draw_function_stats(self, line, width, stats_list):
485493
if self.collector.view_mode == "PER_THREAD" and len(self.collector.thread_ids) > 0:
486494
if self.collector.current_thread_index < len(self.collector.thread_ids):
487495
thread_id = self.collector.thread_ids[self.collector.current_thread_index]
488-
result_set = self.collector.per_thread_result.get(thread_id, {})
496+
if thread_id in self.collector.per_thread_data:
497+
result_set = self.collector.per_thread_data[thread_id].result
498+
else:
499+
result_set = {}
489500
else:
490501
result_set = self.collector.result
491502
else:

Lib/profiling/sampling/sample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def _run_with_sync(original_cmd, suppress_output=False):
127127
# Suppress stdout/stderr if requested (for live mode)
128128
popen_kwargs = {}
129129
if suppress_output:
130+
popen_kwargs['stdin'] = subprocess.DEVNULL
130131
popen_kwargs['stdout'] = subprocess.DEVNULL
131132
popen_kwargs['stderr'] = subprocess.DEVNULL
132133

Lib/test/test_profiling/test_live_collector_core.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,12 @@ def test_process_frames_with_thread_id(self):
120120
self.assertEqual(collector.result[location]["cumulative_calls"], 1)
121121

122122
# Check per-thread result
123-
self.assertIn(123, collector.per_thread_result)
123+
self.assertIn(123, collector.per_thread_data)
124124
self.assertEqual(
125-
collector.per_thread_result[123][location]["direct_calls"], 1
125+
collector.per_thread_data[123].result[location]["direct_calls"], 1
126126
)
127127
self.assertEqual(
128-
collector.per_thread_result[123][location]["cumulative_calls"], 1
128+
collector.per_thread_data[123].result[location]["cumulative_calls"], 1
129129
)
130130

131131
def test_process_frames_multiple_threads(self):
@@ -139,23 +139,23 @@ def test_process_frames_multiple_threads(self):
139139
collector._process_frames(frames2, thread_id=456)
140140

141141
# Check that both threads have their own data
142-
self.assertIn(123, collector.per_thread_result)
143-
self.assertIn(456, collector.per_thread_result)
142+
self.assertIn(123, collector.per_thread_data)
143+
self.assertIn(456, collector.per_thread_data)
144144

145145
loc1 = ("test.py", 10, "test_func")
146146
loc2 = ("test.py", 20, "other_func")
147147

148148
# Thread 123 should only have func1
149149
self.assertEqual(
150-
collector.per_thread_result[123][loc1]["direct_calls"], 1
150+
collector.per_thread_data[123].result[loc1]["direct_calls"], 1
151151
)
152-
self.assertNotIn(loc2, collector.per_thread_result[123])
152+
self.assertNotIn(loc2, collector.per_thread_data[123].result)
153153

154154
# Thread 456 should only have func2
155155
self.assertEqual(
156-
collector.per_thread_result[456][loc2]["direct_calls"], 1
156+
collector.per_thread_data[456].result[loc2]["direct_calls"], 1
157157
)
158-
self.assertNotIn(loc1, collector.per_thread_result[456])
158+
self.assertNotIn(loc1, collector.per_thread_data[456].result)
159159

160160

161161
class TestLiveStatsCollectorCollect(unittest.TestCase):

0 commit comments

Comments
 (0)