Skip to content

Commit 00dc8e7

Browse files
committed
Simplify
1 parent 1543ce5 commit 00dc8e7

2 files changed

Lines changed: 68 additions & 122 deletions

File tree

Lib/profiling/sampling/live_collector/collector.py

Lines changed: 60 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,28 @@ class ThreadData:
6666
sample_count: int = 0
6767
gc_frame_samples: int = 0
6868

69+
def increment_status_flag(self, status_flags):
70+
"""Update status counts based on status bit flags."""
71+
if status_flags & THREAD_STATUS_HAS_GIL:
72+
self.has_gil += 1
73+
if status_flags & THREAD_STATUS_ON_CPU:
74+
self.on_cpu += 1
75+
if status_flags & THREAD_STATUS_GIL_REQUESTED:
76+
self.gil_requested += 1
77+
if status_flags & THREAD_STATUS_UNKNOWN:
78+
self.unknown += 1
79+
self.total += 1
80+
81+
def as_status_dict(self):
82+
"""Return status counts as a dict for compatibility."""
83+
return {
84+
"has_gil": self.has_gil,
85+
"on_cpu": self.on_cpu,
86+
"gil_requested": self.gil_requested,
87+
"unknown": self.unknown,
88+
"total": self.total,
89+
}
90+
6991

7092
class LiveStatsCollector(Collector):
7193
"""Collector that displays live top-like statistics using ncurses."""
@@ -156,6 +178,26 @@ def __init__(
156178
# Color mode
157179
self._can_colorize = _colorize.can_colorize()
158180

181+
def _get_or_create_thread_data(self, thread_id):
182+
"""Get or create ThreadData for a thread ID."""
183+
if thread_id not in self.per_thread_data:
184+
self.per_thread_data[thread_id] = ThreadData(thread_id=thread_id)
185+
return self.per_thread_data[thread_id]
186+
187+
def _get_current_thread_data(self):
188+
"""Get ThreadData for currently selected thread in PER_THREAD mode."""
189+
if self.view_mode == "PER_THREAD" and self.current_thread_index < len(self.thread_ids):
190+
thread_id = self.thread_ids[self.current_thread_index]
191+
return self.per_thread_data.get(thread_id)
192+
return None
193+
194+
def _get_current_result_source(self):
195+
"""Get result dict for current view mode (aggregated or per-thread)."""
196+
if self.view_mode == "ALL":
197+
return self.result
198+
thread_data = self._get_current_thread_data()
199+
return thread_data.result if thread_data else {}
200+
159201
def _get_common_path_prefixes(self):
160202
"""Get common path prefixes to strip from file paths."""
161203
prefixes = []
@@ -215,30 +257,21 @@ def _process_frames(self, frames, thread_id=None):
215257
if not frames:
216258
return
217259

260+
# Get per-thread data if tracking per-thread
261+
thread_data = self._get_or_create_thread_data(thread_id) if thread_id is not None else None
262+
218263
# Process each frame in the stack to track cumulative calls
219264
for frame in frames:
220265
location = (frame.filename, frame.lineno, frame.funcname)
221266
self.result[location]["cumulative_calls"] += 1
222-
223-
# Also track per-thread if thread_id is provided
224-
if thread_id is not None:
225-
if thread_id not in self.per_thread_data:
226-
self.per_thread_data[thread_id] = ThreadData(thread_id=thread_id)
227-
self.per_thread_data[thread_id].result[location]["cumulative_calls"] += 1
267+
if thread_data:
268+
thread_data.result[location]["cumulative_calls"] += 1
228269

229270
# The top frame gets counted as an inline call (directly executing)
230-
top_location = (
231-
frames[0].filename,
232-
frames[0].lineno,
233-
frames[0].funcname,
234-
)
271+
top_location = (frames[0].filename, frames[0].lineno, frames[0].funcname)
235272
self.result[top_location]["direct_calls"] += 1
236-
237-
# Also track per-thread
238-
if thread_id is not None:
239-
if thread_id not in self.per_thread_data:
240-
self.per_thread_data[thread_id] = ThreadData(thread_id=thread_id)
241-
self.per_thread_data[thread_id].result[top_location]["direct_calls"] += 1
273+
if thread_data:
274+
thread_data.result[top_location]["direct_calls"] += 1
242275

243276
def collect_failed_sample(self):
244277
self._failed_samples += 1
@@ -271,31 +304,20 @@ def collect(self, stack_frames):
271304
status_flags = getattr(thread_info, "status", 0)
272305
thread_id = getattr(thread_info, "thread_id", None)
273306

274-
# Initialize per-thread data if needed
275-
if thread_id is not None and thread_id not in self.per_thread_data:
276-
self.per_thread_data[thread_id] = ThreadData(thread_id=thread_id)
277-
278307
# Update aggregated counts
279308
if status_flags & THREAD_STATUS_HAS_GIL:
280309
temp_status_counts["has_gil"] += 1
281-
if thread_id is not None:
282-
self.per_thread_data[thread_id].has_gil += 1
283310
if status_flags & THREAD_STATUS_ON_CPU:
284311
temp_status_counts["on_cpu"] += 1
285-
if thread_id is not None:
286-
self.per_thread_data[thread_id].on_cpu += 1
287312
if status_flags & THREAD_STATUS_GIL_REQUESTED:
288313
temp_status_counts["gil_requested"] += 1
289-
if thread_id is not None:
290-
self.per_thread_data[thread_id].gil_requested += 1
291314
if status_flags & THREAD_STATUS_UNKNOWN:
292315
temp_status_counts["unknown"] += 1
293-
if thread_id is not None:
294-
self.per_thread_data[thread_id].unknown += 1
295316

296-
# Update per-thread total count
317+
# Update per-thread status counts
297318
if thread_id is not None:
298-
self.per_thread_data[thread_id].total += 1
319+
thread_data = self._get_or_create_thread_data(thread_id)
320+
thread_data.increment_status_flag(status_flags)
299321

300322
# Process frames (respecting skip_idle)
301323
if self.skip_idle:
@@ -315,11 +337,7 @@ def collect(self, stack_frames):
315337
):
316338
self.thread_ids.append(thread_id)
317339

318-
# Increment per-thread sample count
319-
if thread_id is not None:
320-
self.per_thread_data[thread_id].sample_count += 1
321-
322-
# Check if any frame is in GC
340+
# Increment per-thread sample count and check for GC frames
323341
thread_has_gc_frame = False
324342
for frame in frames:
325343
funcname = getattr(frame, "funcname", "")
@@ -328,9 +346,11 @@ def collect(self, stack_frames):
328346
thread_has_gc_frame = True
329347
break
330348

331-
# Track per-thread GC samples
332-
if thread_has_gc_frame and thread_id is not None:
333-
self.per_thread_data[thread_id].gc_frame_samples += 1
349+
if thread_id is not None:
350+
thread_data = self._get_or_create_thread_data(thread_id)
351+
thread_data.sample_count += 1
352+
if thread_has_gc_frame:
353+
thread_data.gc_frame_samples += 1
334354

335355
# Update cumulative thread status counts
336356
for key, count in temp_status_counts.items():
@@ -568,21 +588,7 @@ def _setup_colors(self):
568588
def _build_stats_list(self):
569589
"""Build and sort the statistics list."""
570590
stats_list = []
571-
572-
# Determine which data source to use based on view mode
573-
if self.view_mode == "ALL":
574-
# ALL threads - use aggregated result
575-
result_source = self.result
576-
else:
577-
# PER_THREAD mode - use specific thread result
578-
if self.current_thread_index < len(self.thread_ids):
579-
thread_id = self.thread_ids[self.current_thread_index]
580-
if thread_id in self.per_thread_data:
581-
result_source = self.per_thread_data[thread_id].result
582-
else:
583-
result_source = {}
584-
else:
585-
result_source = self.result
591+
result_source = self._get_current_result_source()
586592

587593
for func, call_counts in result_source.items():
588594
# Apply filter if set (using substring matching)

Lib/profiling/sampling/live_collector/widgets.py

Lines changed: 8 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -379,68 +379,20 @@ def _add_percentage_stat(
379379

380380
def draw_thread_status(self, line, width):
381381
"""Draw thread status statistics and GC information."""
382-
# Determine which status counts to use based on view mode
383-
if (
384-
self.collector.view_mode == "PER_THREAD"
385-
and len(self.collector.thread_ids) > 0
386-
):
387-
# Use per-thread stats for the selected thread
388-
if self.collector.current_thread_index < len(
389-
self.collector.thread_ids
390-
):
391-
thread_id = self.collector.thread_ids[
392-
self.collector.current_thread_index
393-
]
394-
if thread_id in self.collector.per_thread_data:
395-
thread_data = self.collector.per_thread_data[thread_id]
396-
status_counts = {
397-
"has_gil": thread_data.has_gil,
398-
"on_cpu": thread_data.on_cpu,
399-
"gil_requested": thread_data.gil_requested,
400-
"unknown": thread_data.unknown,
401-
"total": thread_data.total,
402-
}
403-
else:
404-
status_counts = self.collector._thread_status_counts
405-
else:
406-
status_counts = self.collector._thread_status_counts
407-
else:
408-
# Use aggregated stats
409-
status_counts = self.collector._thread_status_counts
382+
# Get status counts for current view mode
383+
thread_data = self.collector._get_current_thread_data()
384+
status_counts = thread_data.as_status_dict() if thread_data else self.collector._thread_status_counts
410385

411386
# Calculate percentages
412387
total_threads = max(1, status_counts["total"])
413388
pct_on_gil = (status_counts["has_gil"] / total_threads) * 100
414389
pct_off_gil = 100.0 - pct_on_gil
415-
pct_gil_requested = (
416-
status_counts["gil_requested"] / total_threads
417-
) * 100
390+
pct_gil_requested = (status_counts["gil_requested"] / total_threads) * 100
418391

419392
# Get GC percentage based on view mode
420-
if (
421-
self.collector.view_mode == "PER_THREAD"
422-
and len(self.collector.thread_ids) > 0
423-
):
424-
if self.collector.current_thread_index < len(
425-
self.collector.thread_ids
426-
):
427-
thread_id = self.collector.thread_ids[
428-
self.collector.current_thread_index
429-
]
430-
if thread_id in self.collector.per_thread_data:
431-
thread_data = self.collector.per_thread_data[thread_id]
432-
thread_samples = thread_data.sample_count
433-
thread_gc_samples = thread_data.gc_frame_samples
434-
else:
435-
thread_samples = 1
436-
thread_gc_samples = 0
437-
total_samples = max(1, thread_samples)
438-
pct_gc = (thread_gc_samples / total_samples) * 100
439-
else:
440-
total_samples = max(1, self.collector.total_samples)
441-
pct_gc = (
442-
self.collector._gc_frame_samples / total_samples
443-
) * 100
393+
if thread_data:
394+
total_samples = max(1, thread_data.sample_count)
395+
pct_gc = (thread_data.gc_frame_samples / total_samples) * 100
444396
else:
445397
total_samples = max(1, self.collector.total_samples)
446398
pct_gc = (self.collector._gc_frame_samples / total_samples) * 100
@@ -489,19 +441,7 @@ def draw_thread_status(self, line, width):
489441

490442
def draw_function_stats(self, line, width, stats_list):
491443
"""Draw function statistics summary."""
492-
# Determine which result set to use based on view mode
493-
if self.collector.view_mode == "PER_THREAD" and len(self.collector.thread_ids) > 0:
494-
if self.collector.current_thread_index < len(self.collector.thread_ids):
495-
thread_id = self.collector.thread_ids[self.collector.current_thread_index]
496-
if thread_id in self.collector.per_thread_data:
497-
result_set = self.collector.per_thread_data[thread_id].result
498-
else:
499-
result_set = {}
500-
else:
501-
result_set = self.collector.result
502-
else:
503-
result_set = self.collector.result
504-
444+
result_set = self.collector._get_current_result_source()
505445
total_funcs = len(result_set)
506446
funcs_shown = len(stats_list)
507447
executing_funcs = sum(

0 commit comments

Comments
 (0)