Skip to content
Merged
16 changes: 16 additions & 0 deletions Lib/test/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2135,6 +2135,22 @@ def test_basics(self):
self.assertEqual(c.setdefault('e', 5), 5)
self.assertEqual(c['e'], 5)

def test_update_reentrant_add_clears_counter(self):
c = Counter()
key = object()

class Evil(int):
def __new__(cls):
return int.__new__(cls, 0)
Comment thread
skirpichev marked this conversation as resolved.
Outdated

Comment thread
skirpichev marked this conversation as resolved.
Outdated
def __add__(self, other):
c.clear()
return NotImplemented

c[key] = Evil()
c.update([key])
self.assertEqual(c[key], 1)

def test_init(self):
self.assertEqual(list(Counter(self=42).items()), [('self', 42)])
self.assertEqual(list(Counter(iterable=42).items()), [('iterable', 42)])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
gh-143004: Fix a potential use-after-free in :meth:`collections.Counter.update` when user code mutates the :class:`~collections.Counter` during an update.
Comment thread
skirpichev marked this conversation as resolved.
Outdated
5 changes: 5 additions & 0 deletions Modules/_collectionsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -2577,7 +2577,12 @@ _collections__count_elements_impl(PyObject *module, PyObject *mapping,
if (_PyDict_SetItem_KnownHash(mapping, key, one, hash) < 0)
goto done;
} else {
/* oldval is a borrowed reference. Keep it alive across
PyNumber_Add(), which can execute arbitrary user code and
mutate (or even clear) the underlying dict. */
Py_INCREF(oldval);
newval = PyNumber_Add(oldval, one);
Py_DECREF(oldval);
if (newval == NULL)
goto done;
if (_PyDict_SetItem_KnownHash(mapping, key, newval, hash) < 0)
Expand Down
110 changes: 110 additions & 0 deletions Tools/scripts/bench_counter_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Microbenchmarks for collections.Counter.update() iterable fast path.

This is intended for quick before/after comparisons of small C-level changes.
It avoids third-party deps (e.g. pyperf) and prints simple, stable-enough stats.
Comment thread
Kaushalt2004 marked this conversation as resolved.
Outdated

Run (from repo root):
PCbuild\\amd64\\python.exe Tools\\scripts\\bench_counter_update.py

You can also override sizes:
... bench_counter_update.py --n-keys 1000 --n-elems 200000
"""

from __future__ import annotations

import argparse
import statistics
import sys
import time
from collections import Counter


def _run_timer(func, *, inner_loops: int, repeats: int) -> dict[str, float]:
# Warmup
for _ in range(5):
func()

samples = []
for _ in range(repeats):
t0 = time.perf_counter()
for _ in range(inner_loops):
func()
t1 = time.perf_counter()
samples.append(t1 - t0)

return {
"min_s": min(samples),
"mean_s": statistics.mean(samples),
"stdev_s": statistics.pstdev(samples) if len(samples) > 1 else 0.0,
"repeats": float(repeats),
"inner_loops": float(inner_loops),
}


def _format_line(name: str, stats: dict[str, float]) -> str:
# Report per-call based on min (least noisy for microbench comparisons).
per_call_ns = (stats["min_s"] / stats["inner_loops"]) * 1e9
return (
f"{name:32s} {per_call_ns:10.1f} ns/call"
f" (min={stats['min_s']:.6f}s, mean={stats['mean_s']:.6f}s, "
f"stdev={stats['stdev_s']:.6f}s, loops={int(stats['inner_loops'])}, reps={int(stats['repeats'])})"
)


def main(argv: list[str]) -> int:
    """Run the Counter.update() microbenchmarks and print one line per case.

    *argv* is the argument list without the program name.  Recognised
    options are ``--n-keys``, ``--n-elems``, ``--repeats`` and
    ``--inner-loops``.

    Returns 0 on success.  Invalid option values are reported through
    ``argparse`` (usage message on stderr, ``SystemExit`` with status 2).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-keys", type=int, default=1000)
    parser.add_argument("--n-elems", type=int, default=100_000)
    parser.add_argument("--repeats", type=int, default=25)
    parser.add_argument("--inner-loops", type=int, default=50)
    args = parser.parse_args(argv)

    # Reject values that would otherwise surface later as a bare traceback:
    # n_keys is used as a modulus, repeats must yield at least one sample
    # for min()/mean(), and inner_loops is the divisor of the per-call time.
    for option, value in (
        ("--n-keys", args.n_keys),
        ("--repeats", args.repeats),
        ("--inner-loops", args.inner_loops),
    ):
        if value < 1:
            parser.error(f"{option} must be a positive integer (got {value})")

    n_keys = args.n_keys
    n_elems = args.n_elems

    # Data sets
    keys_unique = list(range(n_keys))

    # Many duplicates; all keys are within [0, n_keys)
    keys_dupes = [i % n_keys for i in range(n_elems)]

    # All elements hit the "oldval != NULL" branch by pre-seeding.
    seeded = Counter({k: 1 for k in range(n_keys)})

    def bench_unique_from_empty() -> None:
        c = Counter()
        c.update(keys_unique)

    def bench_dupes_from_empty() -> None:
        c = Counter()
        c.update(keys_dupes)

    def bench_dupes_all_preseeded() -> None:
        # Copy so every run starts from the same pre-seeded counts.
        c = seeded.copy()
        c.update(keys_dupes)

    # A string-like workload (common Counter use): update over a repeated alphabet.
    alpha = ("abcdefghijklmnopqrstuvwxyz" * (n_elems // 26 + 1))[:n_elems]

    def bench_string_dupes_from_empty() -> None:
        c = Counter()
        c.update(alpha)

    print(sys.version.replace("\n", " "))
    print(f"n_keys={n_keys}, n_elems={n_elems}, repeats={args.repeats}, inner_loops={args.inner_loops}")
    print()

    for name, fn in (
        ("update(unique) from empty", bench_unique_from_empty),
        ("update(dupes) from empty", bench_dupes_from_empty),
        ("update(dupes) preseeded", bench_dupes_all_preseeded),
        ("update(string dupes) empty", bench_string_dupes_from_empty),
    ):
        stats = _run_timer(fn, inner_loops=args.inner_loops, repeats=args.repeats)
        print(_format_line(name, stats))

    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main(sys.argv[1:]))
Loading