Skip to content

Commit ec58e2f

Browse files
committed
feat: add interactive chunk classification mode
Add -I/--interactive flag to enable manual chunk classification during generation. Users can now classify each chunk as CARD_ONLY, FULL, CONTEXT_ONLY, or SKIP before LLM processing. - Add run_interactive_session() in pipeline/interactive.py - Integrate interactive mode into CLI and processor pipeline - Support passing pre-classified nodes to process_pipeline() Bump version to 0.2.0
1 parent 77e6608 commit ec58e2f

6 files changed

Lines changed: 365 additions & 7 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "doc2anki"
3-
version = "0.1.2"
3+
version = "0.2.0"
44
description = "Convert knowledge base documents to Anki flashcards"
55
readme = "README.md"
66
requires-python = ">=3.12"

src/doc2anki/cli.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,12 @@ def generate_cmd(
272272
"--dry-run",
273273
help="Parse and chunk only, don't call LLM",
274274
),
275+
interactive: bool = typer.Option(
276+
False,
277+
"-I",
278+
"--interactive",
279+
help="Interactively classify each chunk",
280+
),
275281
verbose: bool = typer.Option(
276282
False,
277283
"--verbose",
@@ -352,6 +358,22 @@ def generate_cmd(
352358
console.print(f"[blue]Document tree:[/blue] {tree}")
353359
console.print(f"[blue]Chunk level:[/blue] {actual_level}")
354360

361+
# Interactive classification if requested
362+
classified_nodes = None
363+
if interactive:
364+
from .pipeline import run_interactive_session
365+
366+
classified_nodes = run_interactive_session(
367+
tree=tree,
368+
level=actual_level,
369+
console=console,
370+
filename=str(file_path.name),
371+
)
372+
373+
if not classified_nodes:
374+
console.print(f"[yellow]No chunks to process for {file_path}[/yellow]")
375+
continue
376+
355377
# Process through pipeline
356378
try:
357379
chunk_contexts = process_pipeline(
@@ -360,6 +382,7 @@ def generate_cmd(
360382
max_tokens=max_tokens,
361383
global_context=global_context,
362384
include_parent_chain=include_parent_chain,
385+
classified_nodes=classified_nodes,
363386
)
364387
except Exception as e:
365388
fatal_exit(f"Failed to process pipeline for {file_path}: {e}")

src/doc2anki/pipeline/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .classifier import ChunkType, ClassifiedNode
44
from .context import ChunkWithContext
5+
from .interactive import run_interactive_session
56
from .processor import auto_detect_level, process_pipeline
67

78
__all__ = [
@@ -10,4 +11,5 @@
1011
"ChunkWithContext",
1112
"auto_detect_level",
1213
"process_pipeline",
14+
"run_interactive_session",
1315
]
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
"""Interactive chunk classification for the processing pipeline."""
2+
3+
from dataclasses import dataclass, field
4+
5+
from rich.console import Console
6+
from rich.panel import Panel
7+
from rich.table import Table
8+
from rich.syntax import Syntax
9+
10+
from doc2anki.parser.tree import DocumentTree, HeadingNode
11+
from doc2anki.parser.chunker import count_tokens
12+
13+
from .classifier import ChunkType, ClassifiedNode
14+
15+
16+
# Mapping from user input to ChunkType
17+
INPUT_MAP: dict[str, ChunkType] = {
18+
"f": ChunkType.FULL,
19+
"c": ChunkType.CARD_ONLY,
20+
"x": ChunkType.CONTEXT_ONLY,
21+
"s": ChunkType.SKIP,
22+
}
23+
24+
# Warning threshold for large chunks (tokens)
25+
LARGE_CHUNK_THRESHOLD = 2000
26+
27+
28+
@dataclass
29+
class InteractiveSession:
30+
"""Manages interactive chunk classification."""
31+
32+
tree: DocumentTree
33+
level: int
34+
nodes: list[HeadingNode] = field(default_factory=list)
35+
classified: list[ClassifiedNode] = field(default_factory=list)
36+
current_index: int = 0
37+
accumulated_tokens: int = 0
38+
39+
def __post_init__(self) -> None:
40+
"""Initialize nodes and classified list from tree."""
41+
self.nodes = self.tree.get_nodes_at_level(self.level)
42+
# Initialize all as CARD_ONLY (default)
43+
self.classified = [
44+
ClassifiedNode(node=n, chunk_type=ChunkType.CARD_ONLY)
45+
for n in self.nodes
46+
]
47+
48+
@property
49+
def total(self) -> int:
50+
"""Total number of nodes to classify."""
51+
return len(self.nodes)
52+
53+
@property
54+
def is_complete(self) -> bool:
55+
"""Check if all nodes have been classified."""
56+
return self.current_index >= self.total
57+
58+
@property
59+
def remaining(self) -> int:
60+
"""Number of remaining nodes to classify."""
61+
return self.total - self.current_index
62+
63+
def classify_current(self, chunk_type: ChunkType) -> int:
64+
"""
65+
Classify the current node and advance.
66+
67+
Returns the token count of the classified chunk.
68+
"""
69+
if self.is_complete:
70+
return 0
71+
72+
node = self.nodes[self.current_index]
73+
tokens = count_tokens(node.full_content)
74+
75+
self.classified[self.current_index].chunk_type = chunk_type
76+
77+
# Track accumulated context tokens
78+
if chunk_type in (ChunkType.FULL, ChunkType.CONTEXT_ONLY):
79+
self.accumulated_tokens += tokens
80+
81+
self.current_index += 1
82+
return tokens
83+
84+
def classify_all_remaining(self, chunk_type: ChunkType) -> None:
85+
"""Classify all remaining nodes with the same type."""
86+
while not self.is_complete:
87+
self.classify_current(chunk_type)
88+
89+
def reset(self) -> None:
90+
"""Reset classification to start over."""
91+
self.current_index = 0
92+
self.accumulated_tokens = 0
93+
for cn in self.classified:
94+
cn.chunk_type = ChunkType.CARD_ONLY
95+
96+
def get_current_node(self) -> HeadingNode | None:
97+
"""Get the current node being classified."""
98+
if self.is_complete:
99+
return None
100+
return self.nodes[self.current_index]
101+
102+
103+
def display_section_summary(
104+
console: Console,
105+
nodes: list[HeadingNode],
106+
filename: str,
107+
level: int,
108+
) -> None:
109+
"""Display a summary table of all sections."""
110+
console.print()
111+
console.print(
112+
Panel(
113+
f"Found [cyan]{len(nodes)}[/cyan] sections at level [cyan]{level}[/cyan]",
114+
title=f"[bold]Processing: {filename}[/bold]",
115+
border_style="blue",
116+
)
117+
)
118+
119+
table = Table(show_header=True, header_style="bold")
120+
table.add_column("#", style="dim", width=4, justify="right")
121+
table.add_column("Section", style="cyan")
122+
table.add_column("Tokens", justify="right")
123+
124+
for i, node in enumerate(nodes, 1):
125+
tokens = count_tokens(node.full_content)
126+
# Add warning indicator for large chunks
127+
token_str = f"{tokens:,}"
128+
if tokens > LARGE_CHUNK_THRESHOLD:
129+
token_str = f"[yellow]{tokens:,}[/yellow] [yellow]![/yellow]"
130+
table.add_row(str(i), node.title, token_str)
131+
132+
console.print(table)
133+
console.print()
134+
135+
136+
def display_classification_help(console: Console) -> None:
137+
"""Display classification options."""
138+
console.print(
139+
"[bold]Classification:[/bold] "
140+
"[green][F]ull[/green] "
141+
"[blue][C]ard only[/blue] "
142+
"[yellow]conte[X]t only[/yellow] "
143+
"[red][S]kip[/red]"
144+
)
145+
console.print()
146+
147+
148+
def preview_chunk(console: Console, node: HeadingNode) -> None:
149+
"""Display a preview of the chunk content."""
150+
content = node.full_content
151+
# Truncate if too long
152+
max_preview = 2000
153+
if len(content) > max_preview:
154+
content = content[:max_preview] + "\n... [dim](truncated)[/dim]"
155+
156+
# Detect syntax for highlighting
157+
syntax = Syntax(content, "markdown", theme="monokai", line_numbers=True)
158+
console.print(Panel(syntax, title=f"[bold]{node.title}[/bold]", border_style="cyan"))
159+
160+
161+
def prompt_classification(
162+
console: Console,
163+
session: InteractiveSession,
164+
) -> ChunkType | str:
165+
"""
166+
Prompt user for classification of current node.
167+
168+
Returns:
169+
ChunkType if valid classification, or str for special commands.
170+
"""
171+
node = session.get_current_node()
172+
if node is None:
173+
return "done"
174+
175+
tokens = count_tokens(node.full_content)
176+
idx = session.current_index + 1
177+
total = session.total
178+
179+
# Build prompt
180+
token_info = f"[dim]({tokens:,} tokens)[/dim]"
181+
if tokens > LARGE_CHUNK_THRESHOLD:
182+
token_info = f"[yellow]({tokens:,} tokens)[/yellow]"
183+
184+
console.print(
185+
f"Section [bold]{idx}/{total}[/bold] "
186+
f"[cyan]\"{node.title}\"[/cyan] {token_info}"
187+
)
188+
189+
try:
190+
response = console.input(
191+
"[dim][F/C/X/S/preview/all:?/done][/dim] (default: C): "
192+
).strip().lower()
193+
except (KeyboardInterrupt, EOFError):
194+
console.print("\n[yellow]Interrupted. Exiting...[/yellow]")
195+
raise SystemExit(1)
196+
197+
if not response:
198+
return ChunkType.CARD_ONLY
199+
200+
# Check for single-letter classification
201+
if response in INPUT_MAP:
202+
return INPUT_MAP[response]
203+
204+
# Check for special commands
205+
if response in ("p", "preview"):
206+
return "preview"
207+
if response == "reset":
208+
return "reset"
209+
if response == "done":
210+
return "done"
211+
if response.startswith("all:"):
212+
return response # Pass through for batch handling
213+
214+
# Invalid input, default to CARD_ONLY
215+
console.print(f"[dim]Unknown input '{response}', using Card only[/dim]")
216+
return ChunkType.CARD_ONLY
217+
218+
219+
def handle_batch_command(command: str) -> ChunkType | None:
220+
"""
221+
Parse batch command like 'all:C'.
222+
223+
Returns ChunkType if valid, None otherwise.
224+
"""
225+
if not command.startswith("all:"):
226+
return None
227+
228+
type_char = command[4:].lower()
229+
return INPUT_MAP.get(type_char)
230+
231+
232+
def show_token_info(
233+
console: Console,
234+
chunk_tokens: int,
235+
accumulated_tokens: int,
236+
) -> None:
237+
"""Show token info after adding to context."""
238+
console.print(
239+
f" [dim]This chunk:[/dim] {chunk_tokens:,} tokens\n"
240+
f" [dim]Accumulated context:[/dim] {accumulated_tokens:,} tokens"
241+
)
242+
243+
244+
def run_interactive_session(
245+
tree: DocumentTree,
246+
level: int,
247+
console: Console,
248+
filename: str = "",
249+
) -> list[ClassifiedNode]:
250+
"""
251+
Run an interactive classification session.
252+
253+
Args:
254+
tree: DocumentTree to classify
255+
level: Heading level to classify at
256+
console: Rich console for output
257+
filename: Source filename for display
258+
259+
Returns:
260+
List of ClassifiedNode with user classifications
261+
"""
262+
session = InteractiveSession(tree=tree, level=level)
263+
264+
if session.total == 0:
265+
console.print("[yellow]No sections found at this level.[/yellow]")
266+
return []
267+
268+
# Display summary
269+
display_section_summary(console, session.nodes, filename, level)
270+
display_classification_help(console)
271+
272+
# Classification loop
273+
while not session.is_complete:
274+
result = prompt_classification(console, session)
275+
276+
if isinstance(result, ChunkType):
277+
chunk_tokens = session.classify_current(result)
278+
# Show token info if adding to context
279+
if result in (ChunkType.FULL, ChunkType.CONTEXT_ONLY):
280+
show_token_info(console, chunk_tokens, session.accumulated_tokens)
281+
console.print()
282+
283+
elif result == "preview":
284+
node = session.get_current_node()
285+
if node:
286+
preview_chunk(console, node)
287+
288+
elif result == "reset":
289+
session.reset()
290+
console.print("[yellow]Reset. Starting over...[/yellow]\n")
291+
display_section_summary(console, session.nodes, filename, level)
292+
display_classification_help(console)
293+
294+
elif result == "done":
295+
# Mark all remaining as default (CARD_ONLY)
296+
console.print(
297+
f"[dim]Marking remaining {session.remaining} sections as Card only[/dim]"
298+
)
299+
session.classify_all_remaining(ChunkType.CARD_ONLY)
300+
301+
elif result.startswith("all:"):
302+
chunk_type = handle_batch_command(result)
303+
if chunk_type:
304+
console.print(
305+
f"[dim]Classifying remaining {session.remaining} sections as {chunk_type.value}[/dim]"
306+
)
307+
session.classify_all_remaining(chunk_type)
308+
else:
309+
console.print(f"[red]Invalid batch command: {result}[/red]")
310+
311+
# Summary
312+
console.print(Panel("[green]Classification complete![/green]", border_style="green"))
313+
314+
# Count classifications
315+
counts = {t: 0 for t in ChunkType}
316+
for cn in session.classified:
317+
counts[cn.chunk_type] += 1
318+
319+
console.print(
320+
f" [green]Full:[/green] {counts[ChunkType.FULL]} "
321+
f"[blue]Card only:[/blue] {counts[ChunkType.CARD_ONLY]} "
322+
f"[yellow]Context only:[/yellow] {counts[ChunkType.CONTEXT_ONLY]} "
323+
f"[red]Skip:[/red] {counts[ChunkType.SKIP]}"
324+
)
325+
console.print()
326+
327+
return session.classified

0 commit comments

Comments
 (0)