Skip to content

Commit 142b116

Browse files
committed
Implement per-thread switcher
1 parent 20c4d27 commit 142b116

7 files changed

Lines changed: 628 additions & 99 deletions

File tree

Lib/profiling/sampling/collector.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ class Collector(ABC):
1111
def collect(self, stack_frames):
1212
"""Collect profiling data from stack frames."""
1313

14-
def collect_failed_sample(self, exeption):
14+
def collect_failed_sample(self):
1515
"""Collect data about a failed sample attempt."""
1616

1717
@abstractmethod

Lib/profiling/sampling/live_collector/collector.py

Lines changed: 168 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -96,11 +96,11 @@ def __init__(
9696

9797
# Thread status statistics (bit flags)
9898
self._thread_status_counts = {
99-
'has_gil': 0,
100-
'on_cpu': 0,
101-
'gil_requested': 0,
102-
'unknown': 0,
103-
'total': 0, # Total thread count across all samples
99+
"has_gil": 0,
100+
"on_cpu": 0,
101+
"gil_requested": 0,
102+
"unknown": 0,
103+
"total": 0, # Total thread count across all samples
104104
}
105105
self._gc_frame_samples = 0 # Track samples with GC frames
106106

@@ -112,6 +112,17 @@ def __init__(
112112
self.filter_input_buffer = "" # Buffer for filter input
113113
self.finished = False # Program has finished, showing final state
114114

115+
# Thread tracking state
116+
self.thread_ids = [] # List of thread IDs seen
117+
self.view_mode = "ALL" # "ALL" or "PER_THREAD"
118+
self.current_thread_index = (
119+
0 # Index into thread_ids when in PER_THREAD mode
120+
)
121+
self.per_thread_result = {} # {thread_id: {func: {direct_calls, cumulative_calls}}}
122+
self.per_thread_status = {} # {thread_id: {has_gil: count, on_cpu: count, ...}}
123+
self.per_thread_samples = {} # {thread_id: sample_count}
124+
self.per_thread_gc_samples = {} # {thread_id: gc_frame_sample_count}
125+
115126
# Calculate common path prefixes to strip
116127
self._path_prefixes = self._get_common_path_prefixes()
117128

@@ -173,8 +184,13 @@ def _simplify_path(self, filepath):
173184
# If no match, return the original path
174185
return filepath
175186

176-
def _process_frames(self, frames):
177-
"""Process a single thread's frame stack."""
187+
def _process_frames(self, frames, thread_id=None):
188+
"""Process a single thread's frame stack.
189+
190+
Args:
191+
frames: List of frame information
192+
thread_id: Thread ID for per-thread tracking (optional)
193+
"""
178194
if not frames:
179195
return
180196

@@ -183,6 +199,18 @@ def _process_frames(self, frames):
183199
location = (frame.filename, frame.lineno, frame.funcname)
184200
self.result[location]["cumulative_calls"] += 1
185201

202+
# Also track per-thread if thread_id is provided
203+
if thread_id is not None:
204+
if thread_id not in self.per_thread_result:
205+
self.per_thread_result[thread_id] = (
206+
collections.defaultdict(
207+
lambda: dict(direct_calls=0, cumulative_calls=0)
208+
)
209+
)
210+
self.per_thread_result[thread_id][location][
211+
"cumulative_calls"
212+
] += 1
213+
186214
# The top frame gets counted as an inline call (directly executing)
187215
top_location = (
188216
frames[0].filename,
@@ -191,6 +219,16 @@ def _process_frames(self, frames):
191219
)
192220
self.result[top_location]["direct_calls"] += 1
193221

222+
# Also track per-thread
223+
if thread_id is not None:
224+
if thread_id not in self.per_thread_result:
225+
self.per_thread_result[thread_id] = collections.defaultdict(
226+
lambda: dict(direct_calls=0, cumulative_calls=0)
227+
)
228+
self.per_thread_result[thread_id][top_location][
229+
"direct_calls"
230+
] += 1
231+
194232
def collect_failed_sample(self):
195233
self._failed_samples += 1
196234
self.total_samples += 1
@@ -203,32 +241,66 @@ def collect(self, stack_frames):
203241

204242
# Thread status counts for this sample
205243
temp_status_counts = {
206-
'has_gil': 0,
207-
'on_cpu': 0,
208-
'gil_requested': 0,
209-
'unknown': 0,
210-
'total': 0,
244+
"has_gil": 0,
245+
"on_cpu": 0,
246+
"gil_requested": 0,
247+
"unknown": 0,
248+
"total": 0,
211249
}
212250
has_gc_frame = False
213251

214252
# Always collect data, even when paused
215253
# Track thread status flags and GC frames
216254
for interpreter_info in stack_frames:
217-
threads = getattr(interpreter_info, 'threads', [])
255+
threads = getattr(interpreter_info, "threads", [])
218256
for thread_info in threads:
219-
temp_status_counts['total'] += 1
257+
temp_status_counts["total"] += 1
220258

221259
# Track thread status using bit flags
222-
status_flags = getattr(thread_info, 'status', 0)
223-
260+
status_flags = getattr(thread_info, "status", 0)
261+
thread_id = getattr(thread_info, "thread_id", None)
262+
263+
# Initialize per-thread status tracking
264+
if (
265+
thread_id is not None
266+
and thread_id not in self.per_thread_status
267+
):
268+
self.per_thread_status[thread_id] = {
269+
"has_gil": 0,
270+
"on_cpu": 0,
271+
"gil_requested": 0,
272+
"unknown": 0,
273+
"total": 0,
274+
}
275+
276+
# Update aggregated counts
224277
if status_flags & THREAD_STATUS_HAS_GIL:
225-
temp_status_counts['has_gil'] += 1
278+
temp_status_counts["has_gil"] += 1
279+
if thread_id is not None:
280+
self.per_thread_status[thread_id]["has_gil"] += 1
226281
if status_flags & THREAD_STATUS_ON_CPU:
227-
temp_status_counts['on_cpu'] += 1
282+
temp_status_counts["on_cpu"] += 1
283+
if thread_id is not None:
284+
self.per_thread_status[thread_id]["on_cpu"] += 1
228285
if status_flags & THREAD_STATUS_GIL_REQUESTED:
229-
temp_status_counts['gil_requested'] += 1
286+
temp_status_counts["gil_requested"] += 1
287+
if thread_id is not None:
288+
self.per_thread_status[thread_id]["gil_requested"] += 1
230289
if status_flags & THREAD_STATUS_UNKNOWN:
231-
temp_status_counts['unknown'] += 1
290+
temp_status_counts["unknown"] += 1
291+
if thread_id is not None:
292+
self.per_thread_status[thread_id]["unknown"] += 1
293+
294+
# Update per-thread total count
295+
if thread_id is not None:
296+
self.per_thread_status[thread_id]["total"] += 1
297+
298+
# Initialize per-thread sample tracking
299+
if thread_id is not None:
300+
if thread_id not in self.per_thread_samples:
301+
self.per_thread_samples[thread_id] = 0
302+
if thread_id not in self.per_thread_gc_samples:
303+
self.per_thread_gc_samples[thread_id] = 0
232304

233305
# Process frames (respecting skip_idle)
234306
if self.skip_idle:
@@ -237,16 +309,34 @@ def collect(self, stack_frames):
237309
if not (has_gil or on_cpu):
238310
continue
239311

240-
frames = getattr(thread_info, 'frame_info', None)
312+
frames = getattr(thread_info, "frame_info", None)
241313
if frames:
242-
self._process_frames(frames)
314+
self._process_frames(frames, thread_id=thread_id)
315+
316+
# Track thread IDs only for threads that actually have samples
317+
if (
318+
thread_id is not None
319+
and thread_id not in self.thread_ids
320+
):
321+
self.thread_ids.append(thread_id)
322+
323+
# Increment per-thread sample count
324+
if thread_id is not None:
325+
self.per_thread_samples[thread_id] += 1
326+
243327
# Check if any frame is in GC
328+
thread_has_gc_frame = False
244329
for frame in frames:
245-
funcname = getattr(frame, 'funcname', '')
246-
if '<GC>' in funcname or 'gc_collect' in funcname:
330+
funcname = getattr(frame, "funcname", "")
331+
if "<GC>" in funcname or "gc_collect" in funcname:
247332
has_gc_frame = True
333+
thread_has_gc_frame = True
248334
break
249335

336+
# Track per-thread GC samples
337+
if thread_has_gc_frame and thread_id is not None:
338+
self.per_thread_gc_samples[thread_id] += 1
339+
250340
# Update cumulative thread status counts
251341
for key, count in temp_status_counts.items():
252342
self._thread_status_counts[key] += count
@@ -483,7 +573,20 @@ def _setup_colors(self):
483573
def _build_stats_list(self):
484574
"""Build and sort the statistics list."""
485575
stats_list = []
486-
for func, call_counts in self.result.items():
576+
577+
# Determine which data source to use based on view mode
578+
if self.view_mode == "ALL":
579+
# ALL threads - use aggregated result
580+
result_source = self.result
581+
else:
582+
# PER_THREAD mode - use specific thread result
583+
if self.current_thread_index < len(self.thread_ids):
584+
thread_id = self.thread_ids[self.current_thread_index]
585+
result_source = self.per_thread_result.get(thread_id, {})
586+
else:
587+
result_source = self.result
588+
589+
for func, call_counts in result_source.items():
487590
# Apply filter if set (using substring matching)
488591
if self.filter_pattern:
489592
filename, lineno, funcname = func
@@ -501,8 +604,8 @@ def _build_stats_list(self):
501604
if not matched:
502605
continue
503606

504-
direct_calls = call_counts["direct_calls"]
505-
cumulative_calls = call_counts["cumulative_calls"]
607+
direct_calls = call_counts.get("direct_calls", 0)
608+
cumulative_calls = call_counts.get("cumulative_calls", 0)
506609
total_time = direct_calls * self.sample_interval_sec
507610
cumulative_time = cumulative_calls * self.sample_interval_sec
508611

@@ -545,16 +648,23 @@ def _build_stats_list(self):
545648
def reset_stats(self):
546649
"""Reset all collected statistics."""
547650
self.result.clear()
651+
self.per_thread_result.clear()
652+
self.per_thread_status.clear()
653+
self.per_thread_samples.clear()
654+
self.per_thread_gc_samples.clear()
655+
self.thread_ids.clear()
656+
self.view_mode = "ALL"
657+
self.current_thread_index = 0
548658
self.total_samples = 0
549659
self._successful_samples = 0
550660
self._failed_samples = 0
551661
self._max_sample_rate = 0
552662
self._thread_status_counts = {
553-
'has_gil': 0,
554-
'on_cpu': 0,
555-
'gil_requested': 0,
556-
'unknown': 0,
557-
'total': 0,
663+
"has_gil": 0,
664+
"on_cpu": 0,
665+
"gil_requested": 0,
666+
"unknown": 0,
667+
"total": 0,
558668
}
559669
self._gc_frame_samples = 0
560670
self.start_time = time.perf_counter()
@@ -718,7 +828,9 @@ def _handle_input(self):
718828

719829
elif ch == ord("-") or ch == ord("_"):
720830
# Increase update interval (slower refresh)
721-
new_interval = min(1.0, constants.DISPLAY_UPDATE_INTERVAL + 0.05) # Max 1Hz
831+
new_interval = min(
832+
1.0, constants.DISPLAY_UPDATE_INTERVAL + 0.05
833+
) # Max 1Hz
722834
constants.DISPLAY_UPDATE_INTERVAL = new_interval
723835

724836
elif ch == ord("c") or ch == ord("C"):
@@ -729,6 +841,29 @@ def _handle_input(self):
729841
self.filter_input_mode = True
730842
self.filter_input_buffer = self.filter_pattern or ""
731843

844+
elif ch == ord("t") or ch == ord("T"):
845+
# Toggle between ALL and PER_THREAD modes
846+
if self.view_mode == "ALL":
847+
if len(self.thread_ids) > 0:
848+
self.view_mode = "PER_THREAD"
849+
self.current_thread_index = 0
850+
else:
851+
self.view_mode = "ALL"
852+
853+
elif ch == curses.KEY_LEFT or ch == curses.KEY_UP:
854+
# Navigate to previous thread in PER_THREAD mode
855+
if self.view_mode == "PER_THREAD" and len(self.thread_ids) > 0:
856+
self.current_thread_index = (
857+
self.current_thread_index - 1
858+
) % len(self.thread_ids)
859+
860+
elif ch == curses.KEY_RIGHT or ch == curses.KEY_DOWN:
861+
# Navigate to next thread in PER_THREAD mode
862+
if self.view_mode == "PER_THREAD" and len(self.thread_ids) > 0:
863+
self.current_thread_index = (
864+
self.current_thread_index + 1
865+
) % len(self.thread_ids)
866+
732867
def init_curses(self, stdscr):
733868
"""Initialize curses display and suppress stdout/stderr."""
734869
self.stdscr = stdscr

0 commit comments

Comments
 (0)